mirror of
https://github.com/tailscale/tailscale.git
synced 2025-09-21 05:31:36 +02:00
Add a ternary flag that unless set explicitly to false keeps the insecure behavior of TSIDP. If the flag is false, add functionality on startup to migrate oidc-funnel-clients.json to oauth-clients.json if it doesn’t exist. If the flag is false, modify endpoints to behave similarly regardless of funnel, tailnet, or localhost. They will all verify client ID & secret when appropriate per RFC 6749. The authorize endpoint will no longer change based on funnel status or nodeID. Add extra tests verifying TSIDP endpoints behave as expected with the new flag. Safely create the redirect URL from what's passed into the authorize endpoint. Fixes #16880 Signed-off-by: Remy Guercio <remy@tailscale.com>
2064 lines
57 KiB
Go
2064 lines
57 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
// Package main tests for tsidp focus on OAuth security boundaries and
|
|
// correct implementation of the OpenID Connect identity provider.
|
|
//
|
|
// Test Strategy:
|
|
// - Tests are intentionally granular to provide clear failure signals when
|
|
// security-critical logic breaks
|
|
// - OAuth flow tests cover both strict mode (registered clients only) and
|
|
// legacy mode (local funnel clients) to ensure proper access controls
|
|
// - Helper functions like normalizeMap ensure deterministic comparisons
|
|
// despite JSON marshaling order variations
|
|
// - The privateKey global is reused across tests for performance (RSA key
|
|
// generation is expensive)
|
|
|
|
package main
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/netip"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"gopkg.in/square/go-jose.v2"
|
|
"gopkg.in/square/go-jose.v2/jwt"
|
|
"tailscale.com/client/local"
|
|
"tailscale.com/client/tailscale/apitype"
|
|
"tailscale.com/tailcfg"
|
|
"tailscale.com/types/key"
|
|
"tailscale.com/types/opt"
|
|
"tailscale.com/types/views"
|
|
)
|
|
|
|
// normalizeMap returns a copy of m in which every []any value is replaced by
// a sorted []any holding the fmt "%v" string form of each element. Values of
// any other type are copied through untouched; nested maps are not descended
// into. Tests use it so that claim comparisons do not depend on the order in
// which slice elements were produced.
func normalizeMap(t *testing.T, m map[string]any) map[string]any {
	t.Helper()

	out := make(map[string]any, len(m))
	for key, value := range m {
		items, isSlice := value.([]any)
		if !isSlice {
			out[key] = value
			continue
		}

		// Stringify every element so that mixed-type slices can be ordered.
		strs := make([]string, 0, len(items))
		for _, item := range items {
			strs = append(strs, fmt.Sprintf("%v", item))
		}
		sort.Strings(strs)

		// Store the result as []any again so it compares equal to
		// expectations written as []any literals.
		asAny := make([]any, 0, len(strs))
		for _, s := range strs {
			asAny = append(asAny, s)
		}
		out[key] = asAny
	}
	return out
}
|
|
|
|
func mustMarshalJSON(t *testing.T, v any) tailcfg.RawMessage {
|
|
t.Helper()
|
|
b, err := json.Marshal(v)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
return tailcfg.RawMessage(b)
|
|
}
|
|
|
|
// privateKey holds the RSA key shared by the tests in this file. It starts
// out nil and is populated on first use by mustGeneratePrivateKey, so the
// expensive key generation is not repeated for every test. privateKeyMu
// serializes access to privateKey.
var (
	privateKeyMu sync.Mutex
	privateKey   *rsa.PrivateKey
)
|
|
|
|
func oidcTestingSigner(t *testing.T) jose.Signer {
|
|
t.Helper()
|
|
privKey := mustGeneratePrivateKey(t)
|
|
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: privKey}, nil)
|
|
if err != nil {
|
|
t.Fatalf("failed to create signer: %v", err)
|
|
}
|
|
return sig
|
|
}
|
|
|
|
func oidcTestingPublicKey(t *testing.T) *rsa.PublicKey {
|
|
t.Helper()
|
|
privKey := mustGeneratePrivateKey(t)
|
|
return &privKey.PublicKey
|
|
}
|
|
|
|
func mustGeneratePrivateKey(t *testing.T) *rsa.PrivateKey {
|
|
t.Helper()
|
|
privateKeyMu.Lock()
|
|
defer privateKeyMu.Unlock()
|
|
|
|
if privateKey != nil {
|
|
return privateKey
|
|
}
|
|
|
|
var err error
|
|
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatalf("failed to generate key: %v", err)
|
|
}
|
|
|
|
return privateKey
|
|
}
|
|
|
|
func TestFlattenExtraClaims(t *testing.T) {
|
|
log.SetOutput(io.Discard) // suppress log output during tests
|
|
|
|
tests := []struct {
|
|
name string
|
|
input []capRule
|
|
expected map[string]any
|
|
}{
|
|
{
|
|
name: "empty extra claims",
|
|
input: []capRule{
|
|
{ExtraClaims: map[string]any{}},
|
|
},
|
|
expected: map[string]any{},
|
|
},
|
|
{
|
|
name: "string and number values",
|
|
input: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"featureA": "read",
|
|
"featureB": 42,
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"featureA": "read",
|
|
"featureB": "42",
|
|
},
|
|
},
|
|
{
|
|
name: "slice of strings and ints",
|
|
input: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"roles": []any{"admin", "user", 1},
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"roles": []any{"admin", "user", "1"},
|
|
},
|
|
},
|
|
{
|
|
name: "duplicate values deduplicated (slice input)",
|
|
input: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"foo": []string{"bar", "baz"},
|
|
},
|
|
},
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"foo": []any{"bar", "qux"},
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"foo": []any{"bar", "baz", "qux"},
|
|
},
|
|
},
|
|
{
|
|
name: "ignore unsupported map type, keep valid scalar",
|
|
input: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"invalid": map[string]any{"bad": "yes"},
|
|
"valid": "ok",
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"valid": "ok",
|
|
},
|
|
},
|
|
{
|
|
name: "scalar first, slice second",
|
|
input: []capRule{
|
|
{ExtraClaims: map[string]any{"foo": "bar"}},
|
|
{ExtraClaims: map[string]any{"foo": []any{"baz"}}},
|
|
},
|
|
expected: map[string]any{
|
|
"foo": []any{"bar", "baz"}, // converts to slice when any rule provides a slice
|
|
},
|
|
},
|
|
{
|
|
name: "conflicting scalar and unsupported map",
|
|
input: []capRule{
|
|
{ExtraClaims: map[string]any{"foo": "bar"}},
|
|
{ExtraClaims: map[string]any{"foo": map[string]any{"bad": "entry"}}},
|
|
},
|
|
expected: map[string]any{
|
|
"foo": "bar", // map should be ignored
|
|
},
|
|
},
|
|
{
|
|
name: "multiple slices with overlap",
|
|
input: []capRule{
|
|
{ExtraClaims: map[string]any{"roles": []any{"admin", "user"}}},
|
|
{ExtraClaims: map[string]any{"roles": []any{"admin", "guest"}}},
|
|
},
|
|
expected: map[string]any{
|
|
"roles": []any{"admin", "user", "guest"},
|
|
},
|
|
},
|
|
{
|
|
name: "slice with unsupported values",
|
|
input: []capRule{
|
|
{ExtraClaims: map[string]any{
|
|
"mixed": []any{"ok", 42, map[string]string{"oops": "fail"}},
|
|
}},
|
|
},
|
|
expected: map[string]any{
|
|
"mixed": []any{"ok", "42"}, // map is ignored
|
|
},
|
|
},
|
|
{
|
|
name: "duplicate scalar value",
|
|
input: []capRule{
|
|
{ExtraClaims: map[string]any{"env": "prod"}},
|
|
{ExtraClaims: map[string]any{"env": "prod"}},
|
|
},
|
|
expected: map[string]any{
|
|
"env": "prod", // not converted to slice
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := flattenExtraClaims(tt.input)
|
|
|
|
gotNormalized := normalizeMap(t, got)
|
|
expectedNormalized := normalizeMap(t, tt.expected)
|
|
|
|
if !reflect.DeepEqual(gotNormalized, expectedNormalized) {
|
|
t.Errorf("mismatch\nGot:\n%s\nWant:\n%s", gotNormalized, expectedNormalized)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtraClaims(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
claim tailscaleClaims
|
|
extraClaims []capRule
|
|
expected map[string]any
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "extra claim",
|
|
claim: tailscaleClaims{
|
|
Claims: jwt.Claims{},
|
|
Nonce: "foobar",
|
|
Key: key.NodePublic{},
|
|
Addresses: views.Slice[netip.Prefix]{},
|
|
NodeID: 0,
|
|
NodeName: "test-node",
|
|
Tailnet: "test.ts.net",
|
|
Email: "test@example.com",
|
|
UserID: 0,
|
|
UserName: "test",
|
|
},
|
|
extraClaims: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"foo": []string{"bar"},
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"nonce": "foobar",
|
|
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
|
"addresses": nil,
|
|
"nid": float64(0),
|
|
"node": "test-node",
|
|
"tailnet": "test.ts.net",
|
|
"email": "test@example.com",
|
|
"username": "test",
|
|
"foo": []any{"bar"},
|
|
},
|
|
},
|
|
{
|
|
name: "duplicate claim distinct values",
|
|
claim: tailscaleClaims{
|
|
Claims: jwt.Claims{},
|
|
Nonce: "foobar",
|
|
Key: key.NodePublic{},
|
|
Addresses: views.Slice[netip.Prefix]{},
|
|
NodeID: 0,
|
|
NodeName: "test-node",
|
|
Tailnet: "test.ts.net",
|
|
Email: "test@example.com",
|
|
UserID: 0,
|
|
UserName: "test",
|
|
},
|
|
extraClaims: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"foo": []string{"bar"},
|
|
},
|
|
},
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"foo": []string{"foobar"},
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"nonce": "foobar",
|
|
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
|
"addresses": nil,
|
|
"nid": float64(0),
|
|
"node": "test-node",
|
|
"tailnet": "test.ts.net",
|
|
"email": "test@example.com",
|
|
"username": "test",
|
|
"foo": []any{"foobar", "bar"},
|
|
},
|
|
},
|
|
{
|
|
name: "multiple extra claims",
|
|
claim: tailscaleClaims{
|
|
Claims: jwt.Claims{},
|
|
Nonce: "foobar",
|
|
Key: key.NodePublic{},
|
|
Addresses: views.Slice[netip.Prefix]{},
|
|
NodeID: 0,
|
|
NodeName: "test-node",
|
|
Tailnet: "test.ts.net",
|
|
Email: "test@example.com",
|
|
UserID: 0,
|
|
UserName: "test",
|
|
},
|
|
extraClaims: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"foo": []string{"bar"},
|
|
},
|
|
},
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"bar": []string{"foo"},
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"nonce": "foobar",
|
|
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
|
"addresses": nil,
|
|
"nid": float64(0),
|
|
"node": "test-node",
|
|
"tailnet": "test.ts.net",
|
|
"email": "test@example.com",
|
|
"username": "test",
|
|
"foo": []any{"bar"},
|
|
"bar": []any{"foo"},
|
|
},
|
|
},
|
|
{
|
|
name: "overwrite claim",
|
|
claim: tailscaleClaims{
|
|
Claims: jwt.Claims{},
|
|
Nonce: "foobar",
|
|
Key: key.NodePublic{},
|
|
Addresses: views.Slice[netip.Prefix]{},
|
|
NodeID: 0,
|
|
NodeName: "test-node",
|
|
Tailnet: "test.ts.net",
|
|
Email: "test@example.com",
|
|
UserID: 0,
|
|
UserName: "test",
|
|
},
|
|
extraClaims: []capRule{
|
|
{
|
|
ExtraClaims: map[string]any{
|
|
"username": "foobar",
|
|
},
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"nonce": "foobar",
|
|
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
|
"addresses": nil,
|
|
"nid": float64(0),
|
|
"node": "test-node",
|
|
"tailnet": "test.ts.net",
|
|
"email": "test@example.com",
|
|
"username": "foobar",
|
|
},
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "empty extra claims",
|
|
claim: tailscaleClaims{
|
|
Claims: jwt.Claims{},
|
|
Nonce: "foobar",
|
|
Key: key.NodePublic{},
|
|
Addresses: views.Slice[netip.Prefix]{},
|
|
NodeID: 0,
|
|
NodeName: "test-node",
|
|
Tailnet: "test.ts.net",
|
|
Email: "test@example.com",
|
|
UserID: 0,
|
|
UserName: "test",
|
|
},
|
|
extraClaims: []capRule{{ExtraClaims: map[string]any{}}},
|
|
expected: map[string]any{
|
|
"nonce": "foobar",
|
|
"key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000",
|
|
"addresses": nil,
|
|
"nid": float64(0),
|
|
"node": "test-node",
|
|
"tailnet": "test.ts.net",
|
|
"email": "test@example.com",
|
|
"username": "test",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
claims, err := withExtraClaims(tt.claim, tt.extraClaims)
|
|
if err != nil && !tt.expectError {
|
|
t.Fatalf("claim.withExtraClaims() unexpected error = %v", err)
|
|
} else if err == nil && tt.expectError {
|
|
t.Fatalf("expected error, got nil")
|
|
} else if err != nil && tt.expectError {
|
|
return // just as expected
|
|
}
|
|
|
|
// Marshal to JSON then unmarshal back to map[string]any
|
|
gotClaims, err := json.Marshal(claims)
|
|
if err != nil {
|
|
t.Errorf("json.Marshal(claims) error = %v", err)
|
|
}
|
|
|
|
var gotClaimsMap map[string]any
|
|
if err := json.Unmarshal(gotClaims, &gotClaimsMap); err != nil {
|
|
t.Fatalf("json.Unmarshal(gotClaims) error = %v", err)
|
|
}
|
|
|
|
gotNormalized := normalizeMap(t, gotClaimsMap)
|
|
expectedNormalized := normalizeMap(t, tt.expected)
|
|
|
|
if !reflect.DeepEqual(gotNormalized, expectedNormalized) {
|
|
t.Errorf("claims mismatch:\n got: %#v\nwant: %#v", gotNormalized, expectedNormalized)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestServeToken(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
caps tailcfg.PeerCapMap
|
|
method string
|
|
grantType string
|
|
code string
|
|
omitCode bool
|
|
redirectURI string
|
|
remoteAddr string
|
|
strictMode bool
|
|
expectError bool
|
|
expected map[string]any
|
|
}{
|
|
{
|
|
name: "GET not allowed",
|
|
method: "GET",
|
|
grantType: "authorization_code",
|
|
strictMode: false,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "unsupported grant type",
|
|
method: "POST",
|
|
grantType: "pkcs",
|
|
strictMode: false,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "invalid code",
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "invalid-code",
|
|
strictMode: false,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "omit code from form",
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
omitCode: true,
|
|
strictMode: false,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "invalid redirect uri",
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
redirectURI: "https://invalid.example.com/callback",
|
|
remoteAddr: "127.0.0.1:12345",
|
|
strictMode: false,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "invalid remoteAddr",
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
code: "valid-code",
|
|
remoteAddr: "192.168.0.1:12345",
|
|
strictMode: false,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "extra claim included (non-strict)",
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
code: "valid-code",
|
|
remoteAddr: "127.0.0.1:12345",
|
|
strictMode: false,
|
|
caps: tailcfg.PeerCapMap{
|
|
tailcfg.PeerCapabilityTsIDP: {
|
|
mustMarshalJSON(t, capRule{
|
|
IncludeInUserInfo: true,
|
|
ExtraClaims: map[string]any{
|
|
"foo": "bar",
|
|
},
|
|
}),
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"foo": "bar",
|
|
},
|
|
},
|
|
{
|
|
name: "attempt to overwrite protected claim (non-strict)",
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
code: "valid-code",
|
|
strictMode: false,
|
|
caps: tailcfg.PeerCapMap{
|
|
tailcfg.PeerCapabilityTsIDP: {
|
|
mustMarshalJSON(t, capRule{
|
|
IncludeInUserInfo: true,
|
|
ExtraClaims: map[string]any{
|
|
"sub": "should-not-overwrite",
|
|
},
|
|
}),
|
|
},
|
|
},
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
now := time.Now()
|
|
|
|
// Use setupTestServer helper
|
|
s := setupTestServer(t, tt.strictMode)
|
|
|
|
// Fake user/node
|
|
profile := &tailcfg.UserProfile{
|
|
LoginName: "alice@example.com",
|
|
DisplayName: "Alice Example",
|
|
ProfilePicURL: "https://example.com/alice.jpg",
|
|
}
|
|
node := &tailcfg.Node{
|
|
ID: 123,
|
|
Name: "test-node.test.ts.net.",
|
|
User: 456,
|
|
Key: key.NodePublic{},
|
|
Cap: 1,
|
|
DiscoKey: key.DiscoPublic{},
|
|
}
|
|
|
|
remoteUser := &apitype.WhoIsResponse{
|
|
Node: node,
|
|
UserProfile: profile,
|
|
CapMap: tt.caps,
|
|
}
|
|
|
|
// Setup auth request with appropriate configuration for strict mode
|
|
var funnelClientPtr *funnelClient
|
|
if tt.strictMode {
|
|
funnelClientPtr = &funnelClient{
|
|
ID: "client-id",
|
|
Secret: "test-secret",
|
|
Name: "Test Client",
|
|
RedirectURI: "https://rp.example.com/callback",
|
|
}
|
|
s.funnelClients["client-id"] = funnelClientPtr
|
|
}
|
|
|
|
s.code["valid-code"] = &authRequest{
|
|
clientID: "client-id",
|
|
nonce: "nonce123",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
validTill: now.Add(5 * time.Minute),
|
|
remoteUser: remoteUser,
|
|
localRP: !tt.strictMode,
|
|
funnelRP: funnelClientPtr,
|
|
}
|
|
|
|
form := url.Values{}
|
|
form.Set("grant_type", tt.grantType)
|
|
form.Set("redirect_uri", tt.redirectURI)
|
|
if !tt.omitCode {
|
|
form.Set("code", tt.code)
|
|
}
|
|
// Add client credentials for strict mode
|
|
if tt.strictMode {
|
|
form.Set("client_id", "client-id")
|
|
form.Set("client_secret", "test-secret")
|
|
}
|
|
|
|
req := httptest.NewRequest(tt.method, "/token", strings.NewReader(form.Encode()))
|
|
req.RemoteAddr = tt.remoteAddr
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
rr := httptest.NewRecorder()
|
|
|
|
s.serveToken(rr, req)
|
|
|
|
if tt.expectError {
|
|
if rr.Code == http.StatusOK {
|
|
t.Fatalf("expected error, got 200 OK: %s", rr.Body.String())
|
|
}
|
|
return
|
|
}
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String())
|
|
}
|
|
|
|
var resp struct {
|
|
IDToken string `json:"id_token"`
|
|
}
|
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
|
t.Fatalf("failed to unmarshal response: %v", err)
|
|
}
|
|
|
|
tok, err := jwt.ParseSigned(resp.IDToken)
|
|
if err != nil {
|
|
t.Fatalf("failed to parse ID token: %v", err)
|
|
}
|
|
|
|
out := make(map[string]any)
|
|
if err := tok.Claims(oidcTestingPublicKey(t), &out); err != nil {
|
|
t.Fatalf("failed to extract claims: %v", err)
|
|
}
|
|
|
|
for k, want := range tt.expected {
|
|
got, ok := out[k]
|
|
if !ok {
|
|
t.Errorf("missing expected claim %q", k)
|
|
continue
|
|
}
|
|
if !reflect.DeepEqual(got, want) {
|
|
t.Errorf("claim %q: got %v, want %v", k, got, want)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtraUserInfo(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
caps tailcfg.PeerCapMap
|
|
tokenValidTill time.Time
|
|
expected map[string]any
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "extra claim",
|
|
tokenValidTill: time.Now().Add(1 * time.Minute),
|
|
caps: tailcfg.PeerCapMap{
|
|
tailcfg.PeerCapabilityTsIDP: {
|
|
mustMarshalJSON(t, capRule{
|
|
IncludeInUserInfo: true,
|
|
ExtraClaims: map[string]any{
|
|
"foo": []string{"bar"},
|
|
},
|
|
}),
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"foo": []any{"bar"},
|
|
},
|
|
},
|
|
{
|
|
name: "duplicate claim distinct values",
|
|
tokenValidTill: time.Now().Add(1 * time.Minute),
|
|
caps: tailcfg.PeerCapMap{
|
|
tailcfg.PeerCapabilityTsIDP: {
|
|
mustMarshalJSON(t, capRule{
|
|
IncludeInUserInfo: true,
|
|
ExtraClaims: map[string]any{
|
|
"foo": []string{"bar", "foobar"},
|
|
},
|
|
}),
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"foo": []any{"bar", "foobar"},
|
|
},
|
|
},
|
|
{
|
|
name: "multiple extra claims",
|
|
tokenValidTill: time.Now().Add(1 * time.Minute),
|
|
caps: tailcfg.PeerCapMap{
|
|
tailcfg.PeerCapabilityTsIDP: {
|
|
mustMarshalJSON(t, capRule{
|
|
IncludeInUserInfo: true,
|
|
ExtraClaims: map[string]any{
|
|
"foo": "bar",
|
|
"bar": "foo",
|
|
},
|
|
}),
|
|
},
|
|
},
|
|
expected: map[string]any{
|
|
"foo": "bar",
|
|
"bar": "foo",
|
|
},
|
|
},
|
|
{
|
|
name: "empty extra claims",
|
|
caps: tailcfg.PeerCapMap{},
|
|
tokenValidTill: time.Now().Add(1 * time.Minute),
|
|
expected: map[string]any{},
|
|
},
|
|
{
|
|
name: "attempt to overwrite protected claim",
|
|
tokenValidTill: time.Now().Add(1 * time.Minute),
|
|
caps: tailcfg.PeerCapMap{
|
|
tailcfg.PeerCapabilityTsIDP: {
|
|
mustMarshalJSON(t, capRule{
|
|
IncludeInUserInfo: true,
|
|
ExtraClaims: map[string]any{
|
|
"sub": "should-not-overwrite",
|
|
"foo": "ok",
|
|
},
|
|
}),
|
|
},
|
|
},
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "extra claim omitted",
|
|
tokenValidTill: time.Now().Add(1 * time.Minute),
|
|
caps: tailcfg.PeerCapMap{
|
|
tailcfg.PeerCapabilityTsIDP: {
|
|
mustMarshalJSON(t, capRule{
|
|
IncludeInUserInfo: false,
|
|
ExtraClaims: map[string]any{
|
|
"foo": "ok",
|
|
},
|
|
}),
|
|
},
|
|
},
|
|
expected: map[string]any{},
|
|
},
|
|
{
|
|
name: "expired token",
|
|
caps: tailcfg.PeerCapMap{},
|
|
tokenValidTill: time.Now().Add(-1 * time.Minute),
|
|
expected: map[string]any{},
|
|
expectError: true,
|
|
},
|
|
}
|
|
token := "valid-token"
|
|
|
|
// Create a fake tailscale Node
|
|
node := &tailcfg.Node{
|
|
ID: 123,
|
|
Name: "test-node.test.ts.net.",
|
|
User: 456,
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
// Construct the remote user
|
|
profile := tailcfg.UserProfile{
|
|
LoginName: "alice@example.com",
|
|
DisplayName: "Alice Example",
|
|
ProfilePicURL: "https://example.com/alice.jpg",
|
|
}
|
|
|
|
remoteUser := &apitype.WhoIsResponse{
|
|
Node: node,
|
|
UserProfile: &profile,
|
|
CapMap: tt.caps,
|
|
}
|
|
|
|
// Insert a valid token into the idpServer
|
|
s := &idpServer{
|
|
allowInsecureRegistration: true, // Default to allowing insecure registration for backward compatibility
|
|
accessToken: map[string]*authRequest{
|
|
token: {
|
|
validTill: tt.tokenValidTill,
|
|
remoteUser: remoteUser,
|
|
},
|
|
},
|
|
}
|
|
|
|
// Construct request
|
|
req := httptest.NewRequest("GET", "/userinfo", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rr := httptest.NewRecorder()
|
|
|
|
// Call the method under test
|
|
s.serveUserInfo(rr, req)
|
|
|
|
if tt.expectError {
|
|
if rr.Code == http.StatusOK {
|
|
t.Fatalf("expected error, got %d: %s", rr.Code, rr.Body.String())
|
|
}
|
|
return
|
|
}
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String())
|
|
}
|
|
|
|
var resp map[string]any
|
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
|
t.Fatalf("failed to parse JSON response: %v", err)
|
|
}
|
|
|
|
// Construct expected
|
|
tt.expected["sub"] = remoteUser.Node.User.String()
|
|
tt.expected["name"] = profile.DisplayName
|
|
tt.expected["email"] = profile.LoginName
|
|
tt.expected["picture"] = profile.ProfilePicURL
|
|
tt.expected["username"], _, _ = strings.Cut(profile.LoginName, "@")
|
|
|
|
gotNormalized := normalizeMap(t, resp)
|
|
expectedNormalized := normalizeMap(t, tt.expected)
|
|
|
|
if !reflect.DeepEqual(gotNormalized, expectedNormalized) {
|
|
t.Errorf("UserInfo mismatch:\n got: %#v\nwant: %#v", gotNormalized, expectedNormalized)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestFunnelClientsPersistence(t *testing.T) {
|
|
testClients := map[string]*funnelClient{
|
|
"test-client-1": {
|
|
ID: "test-client-1",
|
|
Secret: "test-secret-1",
|
|
Name: "Test Client 1",
|
|
RedirectURI: "https://example.com/callback",
|
|
},
|
|
"test-client-2": {
|
|
ID: "test-client-2",
|
|
Secret: "test-secret-2",
|
|
Name: "Test Client 2",
|
|
RedirectURI: "https://example2.com/callback",
|
|
},
|
|
}
|
|
|
|
testData, err := json.Marshal(testClients)
|
|
if err != nil {
|
|
t.Fatalf("failed to marshal test data: %v", err)
|
|
}
|
|
|
|
tmpFile := t.TempDir() + "/oidc-funnel-clients.json"
|
|
if err := os.WriteFile(tmpFile, testData, 0600); err != nil {
|
|
t.Fatalf("failed to write test file: %v", err)
|
|
}
|
|
|
|
t.Run("load_from_existing_file", func(t *testing.T) {
|
|
srv := &idpServer{}
|
|
|
|
// Simulate the funnel clients loading logic from main()
|
|
srv.funnelClients = make(map[string]*funnelClient)
|
|
f, err := os.Open(tmpFile)
|
|
if err == nil {
|
|
if err := json.NewDecoder(f).Decode(&srv.funnelClients); err != nil {
|
|
t.Fatalf("could not parse %s: %v", tmpFile, err)
|
|
}
|
|
f.Close()
|
|
} else if !errors.Is(err, os.ErrNotExist) {
|
|
t.Fatalf("could not open %s: %v", tmpFile, err)
|
|
}
|
|
|
|
// Verify clients were loaded correctly
|
|
if len(srv.funnelClients) != 2 {
|
|
t.Errorf("expected 2 clients, got %d", len(srv.funnelClients))
|
|
}
|
|
|
|
client1, ok := srv.funnelClients["test-client-1"]
|
|
if !ok {
|
|
t.Error("expected test-client-1 to be loaded")
|
|
} else {
|
|
if client1.Name != "Test Client 1" {
|
|
t.Errorf("expected client name 'Test Client 1', got '%s'", client1.Name)
|
|
}
|
|
if client1.Secret != "test-secret-1" {
|
|
t.Errorf("expected client secret 'test-secret-1', got '%s'", client1.Secret)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("initialize_empty_when_no_file", func(t *testing.T) {
|
|
nonExistentFile := t.TempDir() + "/non-existent.json"
|
|
|
|
srv := &idpServer{}
|
|
|
|
// Simulate the funnel clients loading logic from main()
|
|
srv.funnelClients = make(map[string]*funnelClient)
|
|
f, err := os.Open(nonExistentFile)
|
|
if err == nil {
|
|
if err := json.NewDecoder(f).Decode(&srv.funnelClients); err != nil {
|
|
t.Fatalf("could not parse %s: %v", nonExistentFile, err)
|
|
}
|
|
f.Close()
|
|
} else if !errors.Is(err, os.ErrNotExist) {
|
|
t.Fatalf("could not open %s: %v", nonExistentFile, err)
|
|
}
|
|
|
|
// Verify map is initialized but empty
|
|
if srv.funnelClients == nil {
|
|
t.Error("expected funnelClients map to be initialized")
|
|
}
|
|
if len(srv.funnelClients) != 0 {
|
|
t.Errorf("expected empty map, got %d clients", len(srv.funnelClients))
|
|
}
|
|
})
|
|
|
|
t.Run("persist_and_reload_clients", func(t *testing.T) {
|
|
tmpFile2 := t.TempDir() + "/test-persistence.json"
|
|
|
|
// Create initial server with one client
|
|
srv1 := &idpServer{
|
|
funnelClients: make(map[string]*funnelClient),
|
|
}
|
|
srv1.funnelClients["new-client"] = &funnelClient{
|
|
ID: "new-client",
|
|
Secret: "new-secret",
|
|
Name: "New Client",
|
|
RedirectURI: "https://new.example.com/callback",
|
|
}
|
|
|
|
// Save clients to file (simulating saveFunnelClients)
|
|
data, err := json.Marshal(srv1.funnelClients)
|
|
if err != nil {
|
|
t.Fatalf("failed to marshal clients: %v", err)
|
|
}
|
|
if err := os.WriteFile(tmpFile2, data, 0600); err != nil {
|
|
t.Fatalf("failed to write clients file: %v", err)
|
|
}
|
|
|
|
// Create new server instance and load clients
|
|
srv2 := &idpServer{}
|
|
srv2.funnelClients = make(map[string]*funnelClient)
|
|
f, err := os.Open(tmpFile2)
|
|
if err == nil {
|
|
if err := json.NewDecoder(f).Decode(&srv2.funnelClients); err != nil {
|
|
t.Fatalf("could not parse %s: %v", tmpFile2, err)
|
|
}
|
|
f.Close()
|
|
} else if !errors.Is(err, os.ErrNotExist) {
|
|
t.Fatalf("could not open %s: %v", tmpFile2, err)
|
|
}
|
|
|
|
// Verify the client was persisted correctly
|
|
loadedClient, ok := srv2.funnelClients["new-client"]
|
|
if !ok {
|
|
t.Error("expected new-client to be loaded after persistence")
|
|
} else {
|
|
if loadedClient.Name != "New Client" {
|
|
t.Errorf("expected client name 'New Client', got '%s'", loadedClient.Name)
|
|
}
|
|
if loadedClient.Secret != "new-secret" {
|
|
t.Errorf("expected client secret 'new-secret', got '%s'", loadedClient.Secret)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("strict_mode_file_handling", func(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
|
|
// Test strict mode uses oauth-clients.json
|
|
srv1 := setupTestServer(t, true)
|
|
srv1.rootPath = tmpDir
|
|
srv1.funnelClients["oauth-client"] = &funnelClient{
|
|
ID: "oauth-client",
|
|
Secret: "oauth-secret",
|
|
Name: "OAuth Client",
|
|
RedirectURI: "https://oauth.example.com/callback",
|
|
}
|
|
|
|
// Test storeFunnelClientsLocked in strict mode
|
|
srv1.mu.Lock()
|
|
err := srv1.storeFunnelClientsLocked()
|
|
srv1.mu.Unlock()
|
|
|
|
if err != nil {
|
|
t.Fatalf("failed to store clients in strict mode: %v", err)
|
|
}
|
|
|
|
// Verify oauth-clients.json was created
|
|
oauthPath := tmpDir + "/" + oauthClientsFile
|
|
if _, err := os.Stat(oauthPath); err != nil {
|
|
t.Errorf("expected oauth-clients.json to be created: %v", err)
|
|
}
|
|
|
|
// Verify oidc-funnel-clients.json was NOT created
|
|
funnelPath := tmpDir + "/" + funnelClientsFile
|
|
if _, err := os.Stat(funnelPath); !os.IsNotExist(err) {
|
|
t.Error("expected oidc-funnel-clients.json NOT to be created in strict mode")
|
|
}
|
|
})
|
|
|
|
t.Run("non_strict_mode_file_handling", func(t *testing.T) {
|
|
tmpDir := t.TempDir()
|
|
|
|
// Test non-strict mode uses oidc-funnel-clients.json
|
|
srv1 := setupTestServer(t, false)
|
|
srv1.rootPath = tmpDir
|
|
srv1.funnelClients["funnel-client"] = &funnelClient{
|
|
ID: "funnel-client",
|
|
Secret: "funnel-secret",
|
|
Name: "Funnel Client",
|
|
RedirectURI: "https://funnel.example.com/callback",
|
|
}
|
|
|
|
// Test storeFunnelClientsLocked in non-strict mode
|
|
srv1.mu.Lock()
|
|
err := srv1.storeFunnelClientsLocked()
|
|
srv1.mu.Unlock()
|
|
|
|
if err != nil {
|
|
t.Fatalf("failed to store clients in non-strict mode: %v", err)
|
|
}
|
|
|
|
// Verify oidc-funnel-clients.json was created
|
|
funnelPath := tmpDir + "/" + funnelClientsFile
|
|
if _, err := os.Stat(funnelPath); err != nil {
|
|
t.Errorf("expected oidc-funnel-clients.json to be created: %v", err)
|
|
}
|
|
|
|
// Verify oauth-clients.json was NOT created
|
|
oauthPath := tmpDir + "/" + oauthClientsFile
|
|
if _, err := os.Stat(oauthPath); !os.IsNotExist(err) {
|
|
t.Error("expected oauth-clients.json NOT to be created in non-strict mode")
|
|
}
|
|
})
|
|
}
|
|
|
|
// Test helper functions for strict OAuth mode testing
|
|
func setupTestServer(t *testing.T, strictMode bool) *idpServer {
|
|
return setupTestServerWithClient(t, strictMode, nil)
|
|
}
|
|
|
|
// setupTestServerWithClient creates a test server with an optional LocalClient.
|
|
// If lc is nil, the server will have no LocalClient (original behavior).
|
|
// If lc is provided, it will be used for WhoIs calls during testing.
|
|
func setupTestServerWithClient(t *testing.T, strictMode bool, lc *local.Client) *idpServer {
|
|
t.Helper()
|
|
|
|
srv := &idpServer{
|
|
allowInsecureRegistration: !strictMode,
|
|
code: make(map[string]*authRequest),
|
|
accessToken: make(map[string]*authRequest),
|
|
funnelClients: make(map[string]*funnelClient),
|
|
serverURL: "https://test.ts.net",
|
|
rootPath: t.TempDir(),
|
|
lc: lc,
|
|
}
|
|
|
|
// Add a test client for funnel/strict mode testing
|
|
srv.funnelClients["test-client"] = &funnelClient{
|
|
ID: "test-client",
|
|
Secret: "test-secret",
|
|
Name: "Test Client",
|
|
RedirectURI: "https://rp.example.com/callback",
|
|
}
|
|
|
|
// Inject a working signer for token tests
|
|
srv.lazySigner.Set(oidcTestingSigner(t))
|
|
|
|
return srv
|
|
}
|
|
|
|
func TestGetAllowInsecureRegistration(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
flagSet bool
|
|
flagValue bool
|
|
expectAllowInsecureRegistration bool
|
|
}{
|
|
{
|
|
name: "flag explicitly set to false - insecure registration disabled (strict mode)",
|
|
flagSet: true,
|
|
flagValue: false,
|
|
expectAllowInsecureRegistration: false,
|
|
},
|
|
{
|
|
name: "flag explicitly set to true - insecure registration enabled",
|
|
flagSet: true,
|
|
flagValue: true,
|
|
expectAllowInsecureRegistration: true,
|
|
},
|
|
{
|
|
name: "flag unset - insecure registration enabled (default for backward compatibility)",
|
|
flagSet: false,
|
|
flagValue: false, // not used when unset
|
|
expectAllowInsecureRegistration: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Save original state
|
|
originalFlag := flagAllowInsecureRegistration
|
|
defer func() {
|
|
flagAllowInsecureRegistration = originalFlag
|
|
}()
|
|
|
|
// Set up test state by creating a new BoolFlag and setting values
|
|
var b opt.Bool
|
|
flagAllowInsecureRegistration = opt.BoolFlag{Bool: &b}
|
|
if tt.flagSet {
|
|
flagAllowInsecureRegistration.Bool.Set(tt.flagValue)
|
|
}
|
|
// Note: when tt.flagSet is false, the Bool remains unset (which is what we want)
|
|
|
|
got := getAllowInsecureRegistration()
|
|
if got != tt.expectAllowInsecureRegistration {
|
|
t.Errorf("getAllowInsecureRegistration() = %v, want %v", got, tt.expectAllowInsecureRegistration)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestMigrateOAuthClients verifies the migration from legacy funnel clients
|
|
// to OAuth clients. This migration is necessary when transitioning from
|
|
// non-strict to strict OAuth mode. The migration logic should:
|
|
// - Copy clients from oidc-funnel-clients.json to oauth-clients.json
|
|
// - Rename the old file to mark it as deprecated
|
|
// - Handle cases where files already exist or are missing
|
|
func TestMigrateOAuthClients(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
setupOldFile bool
|
|
setupNewFile bool
|
|
oldFileContent map[string]*funnelClient
|
|
newFileContent map[string]*funnelClient
|
|
expectError bool
|
|
expectNewFileExists bool
|
|
expectOldRenamed bool
|
|
}{
|
|
{
|
|
name: "migrate from old file to new file",
|
|
setupOldFile: true,
|
|
oldFileContent: map[string]*funnelClient{
|
|
"old-client": {
|
|
ID: "old-client",
|
|
Secret: "old-secret",
|
|
Name: "Old Client",
|
|
RedirectURI: "https://old.example.com/callback",
|
|
},
|
|
},
|
|
expectNewFileExists: true,
|
|
expectOldRenamed: true,
|
|
},
|
|
{
|
|
name: "new file already exists - no migration",
|
|
setupNewFile: true,
|
|
newFileContent: map[string]*funnelClient{
|
|
"existing-client": {
|
|
ID: "existing-client",
|
|
Secret: "existing-secret",
|
|
Name: "Existing Client",
|
|
RedirectURI: "https://existing.example.com/callback",
|
|
},
|
|
},
|
|
expectNewFileExists: true,
|
|
expectOldRenamed: false,
|
|
},
|
|
{
|
|
name: "neither file exists - create empty new file",
|
|
expectNewFileExists: true,
|
|
expectOldRenamed: false,
|
|
},
|
|
{
|
|
name: "both files exist - prefer new file",
|
|
setupOldFile: true,
|
|
setupNewFile: true,
|
|
oldFileContent: map[string]*funnelClient{
|
|
"old-client": {
|
|
ID: "old-client",
|
|
Secret: "old-secret",
|
|
Name: "Old Client",
|
|
RedirectURI: "https://old.example.com/callback",
|
|
},
|
|
},
|
|
newFileContent: map[string]*funnelClient{
|
|
"new-client": {
|
|
ID: "new-client",
|
|
Secret: "new-secret",
|
|
Name: "New Client",
|
|
RedirectURI: "https://new.example.com/callback",
|
|
},
|
|
},
|
|
expectNewFileExists: true,
|
|
expectOldRenamed: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
rootPath := t.TempDir()
|
|
|
|
// Setup old file if needed
|
|
if tt.setupOldFile {
|
|
oldData, err := json.Marshal(tt.oldFileContent)
|
|
if err != nil {
|
|
t.Fatalf("failed to marshal old file content: %v", err)
|
|
}
|
|
oldPath := rootPath + "/" + funnelClientsFile
|
|
if err := os.WriteFile(oldPath, oldData, 0600); err != nil {
|
|
t.Fatalf("failed to create old file: %v", err)
|
|
}
|
|
}
|
|
|
|
// Setup new file if needed
|
|
if tt.setupNewFile {
|
|
newData, err := json.Marshal(tt.newFileContent)
|
|
if err != nil {
|
|
t.Fatalf("failed to marshal new file content: %v", err)
|
|
}
|
|
newPath := rootPath + "/" + oauthClientsFile
|
|
if err := os.WriteFile(newPath, newData, 0600); err != nil {
|
|
t.Fatalf("failed to create new file: %v", err)
|
|
}
|
|
}
|
|
|
|
// Call migrateOAuthClients
|
|
resultPath, err := migrateOAuthClients(rootPath)
|
|
|
|
if tt.expectError && err == nil {
|
|
t.Fatalf("expected error but got none")
|
|
}
|
|
if !tt.expectError && err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if tt.expectError {
|
|
return
|
|
}
|
|
|
|
// Verify result path points to oauth-clients.json
|
|
expectedPath := filepath.Join(rootPath, oauthClientsFile)
|
|
if resultPath != expectedPath {
|
|
t.Errorf("expected result path %s, got %s", expectedPath, resultPath)
|
|
}
|
|
|
|
// Verify new file exists if expected
|
|
if tt.expectNewFileExists {
|
|
if _, err := os.Stat(resultPath); err != nil {
|
|
t.Errorf("expected new file to exist at %s: %v", resultPath, err)
|
|
}
|
|
|
|
// Verify content
|
|
data, err := os.ReadFile(resultPath)
|
|
if err != nil {
|
|
t.Fatalf("failed to read new file: %v", err)
|
|
}
|
|
|
|
var clients map[string]*funnelClient
|
|
if err := json.Unmarshal(data, &clients); err != nil {
|
|
t.Fatalf("failed to unmarshal new file: %v", err)
|
|
}
|
|
|
|
// Determine expected content
|
|
var expectedContent map[string]*funnelClient
|
|
if tt.setupNewFile {
|
|
expectedContent = tt.newFileContent
|
|
} else if tt.setupOldFile {
|
|
expectedContent = tt.oldFileContent
|
|
} else {
|
|
expectedContent = make(map[string]*funnelClient)
|
|
}
|
|
|
|
if len(clients) != len(expectedContent) {
|
|
t.Errorf("expected %d clients, got %d", len(expectedContent), len(clients))
|
|
}
|
|
|
|
for id, expectedClient := range expectedContent {
|
|
actualClient, ok := clients[id]
|
|
if !ok {
|
|
t.Errorf("expected client %s not found", id)
|
|
continue
|
|
}
|
|
if actualClient.ID != expectedClient.ID ||
|
|
actualClient.Secret != expectedClient.Secret ||
|
|
actualClient.Name != expectedClient.Name ||
|
|
actualClient.RedirectURI != expectedClient.RedirectURI {
|
|
t.Errorf("client %s mismatch: got %+v, want %+v", id, actualClient, expectedClient)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Verify old file renamed if expected
|
|
if tt.expectOldRenamed {
|
|
deprecatedPath := rootPath + "/" + deprecatedFunnelClientsFile
|
|
if _, err := os.Stat(deprecatedPath); err != nil {
|
|
t.Errorf("expected old file to be renamed to %s: %v", deprecatedPath, err)
|
|
}
|
|
|
|
// Verify original old file is gone
|
|
oldPath := rootPath + "/" + funnelClientsFile
|
|
if _, err := os.Stat(oldPath); !os.IsNotExist(err) {
|
|
t.Errorf("expected old file %s to be removed", oldPath)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestGetConfigFilePath verifies backward compatibility for config file location.
|
|
// The function must check current directory first (legacy deployments) before
|
|
// falling back to rootPath (new installations) to prevent breaking existing
|
|
// tsidp deployments that have config files in unexpected locations.
|
|
func TestGetConfigFilePath(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
fileName string
|
|
createInCwd bool
|
|
createInRoot bool
|
|
expectInCwd bool
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "file exists in current directory - use current directory",
|
|
fileName: "test-config.json",
|
|
createInCwd: true,
|
|
expectInCwd: true,
|
|
},
|
|
{
|
|
name: "file does not exist - use root path",
|
|
fileName: "test-config.json",
|
|
createInCwd: false,
|
|
expectInCwd: false,
|
|
},
|
|
{
|
|
name: "file exists in both - prefer current directory",
|
|
fileName: "test-config.json",
|
|
createInCwd: true,
|
|
createInRoot: true,
|
|
expectInCwd: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create temporary directories
|
|
rootPath := t.TempDir()
|
|
originalWd, err := os.Getwd()
|
|
if err != nil {
|
|
t.Fatalf("failed to get working directory: %v", err)
|
|
}
|
|
|
|
// Create a temporary working directory
|
|
tmpWd := t.TempDir()
|
|
if err := os.Chdir(tmpWd); err != nil {
|
|
t.Fatalf("failed to change to temp directory: %v", err)
|
|
}
|
|
defer func() {
|
|
os.Chdir(originalWd)
|
|
}()
|
|
|
|
// Setup files as needed
|
|
if tt.createInCwd {
|
|
if err := os.WriteFile(tt.fileName, []byte("{}"), 0600); err != nil {
|
|
t.Fatalf("failed to create file in cwd: %v", err)
|
|
}
|
|
}
|
|
if tt.createInRoot {
|
|
rootFilePath := filepath.Join(rootPath, tt.fileName)
|
|
if err := os.WriteFile(rootFilePath, []byte("{}"), 0600); err != nil {
|
|
t.Fatalf("failed to create file in root: %v", err)
|
|
}
|
|
}
|
|
|
|
// Call getConfigFilePath
|
|
resultPath, err := getConfigFilePath(rootPath, tt.fileName)
|
|
|
|
if tt.expectError && err == nil {
|
|
t.Fatalf("expected error but got none")
|
|
}
|
|
if !tt.expectError && err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if tt.expectError {
|
|
return
|
|
}
|
|
|
|
// Verify result
|
|
if tt.expectInCwd {
|
|
if resultPath != tt.fileName {
|
|
t.Errorf("expected path %s, got %s", tt.fileName, resultPath)
|
|
}
|
|
} else {
|
|
expectedPath := filepath.Join(rootPath, tt.fileName)
|
|
if resultPath != expectedPath {
|
|
t.Errorf("expected path %s, got %s", expectedPath, resultPath)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestAuthorizeStrictMode verifies OAuth authorization endpoint security and validation logic.
|
|
// Tests both the security boundary (funnel rejection) and the business logic (strict mode validation).
|
|
func TestAuthorizeStrictMode(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
strictMode bool
|
|
clientID string
|
|
redirectURI string
|
|
state string
|
|
nonce string
|
|
setupClient bool
|
|
clientRedirect string
|
|
useFunnel bool // whether to simulate funnel request
|
|
mockWhoIsError bool // whether to make WhoIs return an error
|
|
expectError bool
|
|
expectCode int
|
|
expectRedirect bool
|
|
}{
|
|
// Security boundary test: funnel rejection
|
|
{
|
|
name: "funnel requests are always rejected for security",
|
|
strictMode: true,
|
|
clientID: "test-client",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
state: "random-state",
|
|
nonce: "random-nonce",
|
|
setupClient: true,
|
|
clientRedirect: "https://rp.example.com/callback",
|
|
useFunnel: true,
|
|
expectError: true,
|
|
expectCode: http.StatusUnauthorized,
|
|
},
|
|
|
|
// Strict mode parameter validation tests (non-funnel)
|
|
{
|
|
name: "strict mode - missing client_id",
|
|
strictMode: true,
|
|
clientID: "",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
useFunnel: false,
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "strict mode - missing redirect_uri",
|
|
strictMode: true,
|
|
clientID: "test-client",
|
|
redirectURI: "",
|
|
useFunnel: false,
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
|
|
// Strict mode client validation tests (non-funnel)
|
|
{
|
|
name: "strict mode - invalid client_id",
|
|
strictMode: true,
|
|
clientID: "invalid-client",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
setupClient: false,
|
|
useFunnel: false,
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "strict mode - redirect_uri mismatch",
|
|
strictMode: true,
|
|
clientID: "test-client",
|
|
redirectURI: "https://wrong.example.com/callback",
|
|
setupClient: true,
|
|
clientRedirect: "https://rp.example.com/callback",
|
|
useFunnel: false,
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
srv := setupTestServer(t, tt.strictMode)
|
|
|
|
// For non-funnel tests, we'll test the parameter validation logic
|
|
// without needing to mock WhoIs, since the validation happens before WhoIs calls
|
|
|
|
// Setup client if needed
|
|
if tt.setupClient {
|
|
srv.funnelClients["test-client"] = &funnelClient{
|
|
ID: "test-client",
|
|
Secret: "test-secret",
|
|
Name: "Test Client",
|
|
RedirectURI: tt.clientRedirect,
|
|
}
|
|
} else if !tt.strictMode {
|
|
// For non-strict mode tests that don't need a specific client setup
|
|
// but might reference one, clear the default client
|
|
delete(srv.funnelClients, "test-client")
|
|
}
|
|
|
|
// Create request
|
|
reqURL := "/authorize"
|
|
if !tt.strictMode {
|
|
// In non-strict mode, use the node-specific endpoint
|
|
reqURL = "/authorize/123"
|
|
}
|
|
|
|
query := url.Values{}
|
|
if tt.clientID != "" {
|
|
query.Set("client_id", tt.clientID)
|
|
}
|
|
if tt.redirectURI != "" {
|
|
query.Set("redirect_uri", tt.redirectURI)
|
|
}
|
|
if tt.state != "" {
|
|
query.Set("state", tt.state)
|
|
}
|
|
if tt.nonce != "" {
|
|
query.Set("nonce", tt.nonce)
|
|
}
|
|
|
|
reqURL += "?" + query.Encode()
|
|
req := httptest.NewRequest("GET", reqURL, nil)
|
|
req.RemoteAddr = "127.0.0.1:12345"
|
|
|
|
// Set funnel header only when explicitly testing funnel behavior
|
|
if tt.useFunnel {
|
|
req.Header.Set("Tailscale-Funnel-Request", "true")
|
|
}
|
|
|
|
rr := httptest.NewRecorder()
|
|
srv.authorize(rr, req)
|
|
|
|
if tt.expectError {
|
|
if rr.Code != tt.expectCode {
|
|
t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String())
|
|
}
|
|
} else if tt.expectRedirect {
|
|
if rr.Code != http.StatusFound {
|
|
t.Errorf("expected redirect (302), got %d: %s", rr.Code, rr.Body.String())
|
|
}
|
|
|
|
location := rr.Header().Get("Location")
|
|
if location == "" {
|
|
t.Error("expected Location header in redirect response")
|
|
} else {
|
|
// Parse the redirect URL to verify it contains a code
|
|
redirectURL, err := url.Parse(location)
|
|
if err != nil {
|
|
t.Errorf("failed to parse redirect URL: %v", err)
|
|
} else {
|
|
code := redirectURL.Query().Get("code")
|
|
if code == "" {
|
|
t.Error("expected 'code' parameter in redirect URL")
|
|
}
|
|
|
|
// Verify state is preserved if provided
|
|
if tt.state != "" {
|
|
returnedState := redirectURL.Query().Get("state")
|
|
if returnedState != tt.state {
|
|
t.Errorf("expected state '%s', got '%s'", tt.state, returnedState)
|
|
}
|
|
}
|
|
|
|
// Verify the auth request was stored
|
|
srv.mu.Lock()
|
|
ar, ok := srv.code[code]
|
|
srv.mu.Unlock()
|
|
|
|
if !ok {
|
|
t.Error("expected authorization request to be stored")
|
|
} else {
|
|
if ar.clientID != tt.clientID {
|
|
t.Errorf("expected clientID '%s', got '%s'", tt.clientID, ar.clientID)
|
|
}
|
|
if ar.redirectURI != tt.redirectURI {
|
|
t.Errorf("expected redirectURI '%s', got '%s'", tt.redirectURI, ar.redirectURI)
|
|
}
|
|
if ar.nonce != tt.nonce {
|
|
t.Errorf("expected nonce '%s', got '%s'", tt.nonce, ar.nonce)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
t.Errorf("unexpected test case: not expecting error or redirect")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestServeTokenWithClientValidation verifies OAuth token endpoint security in both strict and non-strict modes.
|
|
// In strict mode, the token endpoint must:
|
|
// - Require and validate client credentials (client_id + client_secret)
|
|
// - Only accept tokens from registered funnel clients
|
|
// - Validate that redirect_uri matches the registered client
|
|
// - Support both form-based and HTTP Basic authentication for client credentials
|
|
func TestServeTokenWithClientValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
strictMode bool
|
|
method string
|
|
grantType string
|
|
code string
|
|
clientID string
|
|
clientSecret string
|
|
redirectURI string
|
|
useBasicAuth bool
|
|
setupAuthRequest bool
|
|
authRequestClient string
|
|
authRequestRedirect string
|
|
expectError bool
|
|
expectCode int
|
|
expectIDToken bool
|
|
}{
|
|
{
|
|
name: "strict mode - valid token exchange with form credentials",
|
|
strictMode: true,
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
setupAuthRequest: true,
|
|
authRequestClient: "test-client",
|
|
authRequestRedirect: "https://rp.example.com/callback",
|
|
expectIDToken: true,
|
|
},
|
|
{
|
|
name: "strict mode - valid token exchange with basic auth",
|
|
strictMode: true,
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
useBasicAuth: true,
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
setupAuthRequest: true,
|
|
authRequestClient: "test-client",
|
|
authRequestRedirect: "https://rp.example.com/callback",
|
|
expectIDToken: true,
|
|
},
|
|
{
|
|
name: "strict mode - missing client credentials",
|
|
strictMode: true,
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
setupAuthRequest: true,
|
|
authRequestClient: "test-client",
|
|
authRequestRedirect: "https://rp.example.com/callback",
|
|
expectError: true,
|
|
expectCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "strict mode - client_id mismatch",
|
|
strictMode: true,
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
clientID: "wrong-client",
|
|
clientSecret: "test-secret",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
setupAuthRequest: true,
|
|
authRequestClient: "test-client",
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "strict mode - invalid client secret",
|
|
strictMode: true,
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
clientID: "test-client",
|
|
clientSecret: "wrong-secret",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
setupAuthRequest: true,
|
|
authRequestClient: "test-client",
|
|
authRequestRedirect: "https://rp.example.com/callback",
|
|
expectError: true,
|
|
expectCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "strict mode - redirect_uri mismatch",
|
|
strictMode: true,
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
clientID: "test-client",
|
|
clientSecret: "test-secret",
|
|
redirectURI: "https://wrong.example.com/callback",
|
|
setupAuthRequest: true,
|
|
authRequestClient: "test-client",
|
|
authRequestRedirect: "https://rp.example.com/callback",
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "non-strict mode - no client validation required",
|
|
strictMode: false,
|
|
method: "POST",
|
|
grantType: "authorization_code",
|
|
code: "valid-code",
|
|
redirectURI: "https://rp.example.com/callback",
|
|
setupAuthRequest: true,
|
|
authRequestRedirect: "https://rp.example.com/callback",
|
|
expectIDToken: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
srv := setupTestServer(t, tt.strictMode)
|
|
|
|
// Setup authorization request if needed
|
|
if tt.setupAuthRequest {
|
|
now := time.Now()
|
|
profile := &tailcfg.UserProfile{
|
|
LoginName: "alice@example.com",
|
|
DisplayName: "Alice Example",
|
|
ProfilePicURL: "https://example.com/alice.jpg",
|
|
}
|
|
node := &tailcfg.Node{
|
|
ID: 123,
|
|
Name: "test-node.test.ts.net.",
|
|
User: 456,
|
|
Key: key.NodePublic{},
|
|
Cap: 1,
|
|
DiscoKey: key.DiscoPublic{},
|
|
}
|
|
remoteUser := &apitype.WhoIsResponse{
|
|
Node: node,
|
|
UserProfile: profile,
|
|
CapMap: tailcfg.PeerCapMap{},
|
|
}
|
|
|
|
var funnelClientPtr *funnelClient
|
|
if tt.strictMode && tt.authRequestClient != "" {
|
|
funnelClientPtr = &funnelClient{
|
|
ID: tt.authRequestClient,
|
|
Secret: "test-secret",
|
|
Name: "Test Client",
|
|
RedirectURI: tt.authRequestRedirect,
|
|
}
|
|
srv.funnelClients[tt.authRequestClient] = funnelClientPtr
|
|
}
|
|
|
|
srv.code["valid-code"] = &authRequest{
|
|
clientID: tt.authRequestClient,
|
|
nonce: "nonce123",
|
|
redirectURI: tt.authRequestRedirect,
|
|
validTill: now.Add(5 * time.Minute),
|
|
remoteUser: remoteUser,
|
|
localRP: !tt.strictMode,
|
|
funnelRP: funnelClientPtr,
|
|
}
|
|
}
|
|
|
|
// Create form data
|
|
form := url.Values{}
|
|
form.Set("grant_type", tt.grantType)
|
|
form.Set("code", tt.code)
|
|
form.Set("redirect_uri", tt.redirectURI)
|
|
|
|
if !tt.useBasicAuth {
|
|
if tt.clientID != "" {
|
|
form.Set("client_id", tt.clientID)
|
|
}
|
|
if tt.clientSecret != "" {
|
|
form.Set("client_secret", tt.clientSecret)
|
|
}
|
|
}
|
|
|
|
req := httptest.NewRequest(tt.method, "/token", strings.NewReader(form.Encode()))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.RemoteAddr = "127.0.0.1:12345"
|
|
|
|
if tt.useBasicAuth && tt.clientID != "" && tt.clientSecret != "" {
|
|
req.SetBasicAuth(tt.clientID, tt.clientSecret)
|
|
}
|
|
|
|
rr := httptest.NewRecorder()
|
|
srv.serveToken(rr, req)
|
|
|
|
if tt.expectError {
|
|
if rr.Code != tt.expectCode {
|
|
t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String())
|
|
}
|
|
} else if tt.expectIDToken {
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String())
|
|
}
|
|
|
|
var resp struct {
|
|
IDToken string `json:"id_token"`
|
|
AccessToken string `json:"access_token"`
|
|
TokenType string `json:"token_type"`
|
|
ExpiresIn int `json:"expires_in"`
|
|
}
|
|
|
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
|
t.Fatalf("failed to unmarshal response: %v", err)
|
|
}
|
|
|
|
if resp.IDToken == "" {
|
|
t.Error("expected id_token in response")
|
|
}
|
|
if resp.AccessToken == "" {
|
|
t.Error("expected access_token in response")
|
|
}
|
|
if resp.TokenType != "Bearer" {
|
|
t.Errorf("expected token_type 'Bearer', got '%s'", resp.TokenType)
|
|
}
|
|
if resp.ExpiresIn != 300 {
|
|
t.Errorf("expected expires_in 300, got %d", resp.ExpiresIn)
|
|
}
|
|
|
|
// Verify access token was stored
|
|
srv.mu.Lock()
|
|
_, ok := srv.accessToken[resp.AccessToken]
|
|
srv.mu.Unlock()
|
|
|
|
if !ok {
|
|
t.Error("expected access token to be stored")
|
|
}
|
|
|
|
// Verify authorization code was consumed
|
|
srv.mu.Lock()
|
|
_, ok = srv.code[tt.code]
|
|
srv.mu.Unlock()
|
|
|
|
if ok {
|
|
t.Error("expected authorization code to be consumed")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestServeUserInfoWithClientValidation verifies UserInfo endpoint security in both strict and non-strict modes.
|
|
// In strict mode, the UserInfo endpoint must:
|
|
// - Validate that access tokens are associated with registered clients
|
|
// - Reject tokens for clients that have been deleted/unregistered
|
|
// - Enforce token expiration properly
|
|
// - Return appropriate user claims based on client capabilities
|
|
func TestServeUserInfoWithClientValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
strictMode bool
|
|
setupToken bool
|
|
setupClient bool
|
|
clientID string
|
|
token string
|
|
tokenValidTill time.Time
|
|
expectError bool
|
|
expectCode int
|
|
expectUserInfo bool
|
|
}{
|
|
{
|
|
name: "strict mode - valid token with existing client",
|
|
strictMode: true,
|
|
setupToken: true,
|
|
setupClient: true,
|
|
clientID: "test-client",
|
|
token: "valid-token",
|
|
tokenValidTill: time.Now().Add(5 * time.Minute),
|
|
expectUserInfo: true,
|
|
},
|
|
{
|
|
name: "strict mode - valid token but client no longer exists",
|
|
strictMode: true,
|
|
setupToken: true,
|
|
setupClient: false,
|
|
clientID: "deleted-client",
|
|
token: "valid-token",
|
|
tokenValidTill: time.Now().Add(5 * time.Minute),
|
|
expectError: true,
|
|
expectCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "strict mode - expired token",
|
|
strictMode: true,
|
|
setupToken: true,
|
|
setupClient: true,
|
|
clientID: "test-client",
|
|
token: "expired-token",
|
|
tokenValidTill: time.Now().Add(-5 * time.Minute),
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "strict mode - invalid token",
|
|
strictMode: true,
|
|
setupToken: false,
|
|
token: "invalid-token",
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "strict mode - token without client association",
|
|
strictMode: true,
|
|
setupToken: true,
|
|
setupClient: false,
|
|
clientID: "",
|
|
token: "valid-token",
|
|
tokenValidTill: time.Now().Add(5 * time.Minute),
|
|
expectError: true,
|
|
expectCode: http.StatusBadRequest,
|
|
},
|
|
{
|
|
name: "non-strict mode - no client validation required",
|
|
strictMode: false,
|
|
setupToken: true,
|
|
setupClient: false,
|
|
clientID: "",
|
|
token: "valid-token",
|
|
tokenValidTill: time.Now().Add(5 * time.Minute),
|
|
expectUserInfo: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
srv := setupTestServer(t, tt.strictMode)
|
|
|
|
// Setup client if needed
|
|
if tt.setupClient {
|
|
srv.funnelClients[tt.clientID] = &funnelClient{
|
|
ID: tt.clientID,
|
|
Secret: "test-secret",
|
|
Name: "Test Client",
|
|
RedirectURI: "https://rp.example.com/callback",
|
|
}
|
|
}
|
|
|
|
// Setup token if needed
|
|
if tt.setupToken {
|
|
profile := &tailcfg.UserProfile{
|
|
LoginName: "alice@example.com",
|
|
DisplayName: "Alice Example",
|
|
ProfilePicURL: "https://example.com/alice.jpg",
|
|
}
|
|
node := &tailcfg.Node{
|
|
ID: 123,
|
|
Name: "test-node.test.ts.net.",
|
|
User: 456,
|
|
Key: key.NodePublic{},
|
|
Cap: 1,
|
|
DiscoKey: key.DiscoPublic{},
|
|
}
|
|
remoteUser := &apitype.WhoIsResponse{
|
|
Node: node,
|
|
UserProfile: profile,
|
|
CapMap: tailcfg.PeerCapMap{},
|
|
}
|
|
|
|
srv.accessToken[tt.token] = &authRequest{
|
|
clientID: tt.clientID,
|
|
validTill: tt.tokenValidTill,
|
|
remoteUser: remoteUser,
|
|
}
|
|
}
|
|
|
|
// Create request
|
|
req := httptest.NewRequest("GET", "/userinfo", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tt.token)
|
|
req.RemoteAddr = "127.0.0.1:12345"
|
|
|
|
rr := httptest.NewRecorder()
|
|
srv.serveUserInfo(rr, req)
|
|
|
|
if tt.expectError {
|
|
if rr.Code != tt.expectCode {
|
|
t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String())
|
|
}
|
|
} else if tt.expectUserInfo {
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String())
|
|
}
|
|
|
|
var resp map[string]any
|
|
if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
|
|
t.Fatalf("failed to parse JSON response: %v", err)
|
|
}
|
|
|
|
// Check required fields
|
|
expectedFields := []string{"sub", "name", "email", "picture", "username"}
|
|
for _, field := range expectedFields {
|
|
if _, ok := resp[field]; !ok {
|
|
t.Errorf("expected field '%s' in user info response", field)
|
|
}
|
|
}
|
|
|
|
// Verify specific values
|
|
if resp["name"] != "Alice Example" {
|
|
t.Errorf("expected name 'Alice Example', got '%v'", resp["name"])
|
|
}
|
|
if resp["email"] != "alice@example.com" {
|
|
t.Errorf("expected email 'alice@example.com', got '%v'", resp["email"])
|
|
}
|
|
if resp["username"] != "alice" {
|
|
t.Errorf("expected username 'alice', got '%v'", resp["username"])
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|