mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-31 08:01:34 +01:00 
			
		
		
		
	fix oidc test, add tests for migration
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									2fe65624c0
								
							
						
					
					
						commit
						4dd12a2f97
					
				
							
								
								
									
										1
									
								
								.github/workflows/test-integration.yaml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/test-integration.yaml
									
									
									
									
										vendored
									
									
								
							| @ -21,6 +21,7 @@ jobs: | |||||||
|           - TestPolicyUpdateWhileRunningWithCLIInDatabase |           - TestPolicyUpdateWhileRunningWithCLIInDatabase | ||||||
|           - TestOIDCAuthenticationPingAll |           - TestOIDCAuthenticationPingAll | ||||||
|           - TestOIDCExpireNodesBasedOnTokenExpiry |           - TestOIDCExpireNodesBasedOnTokenExpiry | ||||||
|  |           - TestOIDC024UserCreation | ||||||
|           - TestAuthWebFlowAuthenticationPingAll |           - TestAuthWebFlowAuthenticationPingAll | ||||||
|           - TestAuthWebFlowLogoutAndRelogin |           - TestAuthWebFlowLogoutAndRelogin | ||||||
|           - TestUserCommand |           - TestUserCommand | ||||||
|  | |||||||
| @ -1,8 +1,10 @@ | |||||||
| package cli | package cli | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
|  | 	"encoding/json" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"net" | 	"net" | ||||||
|  | 	"net/http" | ||||||
| 	"os" | 	"os" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"time" | 	"time" | ||||||
| @ -64,6 +66,19 @@ func mockOIDC() error { | |||||||
| 		accessTTL = newTTL | 		accessTTL = newTTL | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	userStr := os.Getenv("MOCKOIDC_USERS") | ||||||
|  | 	if userStr == "" { | ||||||
|  | 		return fmt.Errorf("MOCKOIDC_USERS not defined") | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	var users []mockoidc.MockUser | ||||||
|  | 	err := json.Unmarshal([]byte(userStr), &users) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("unmarshalling users: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	log.Info().Interface("users", users).Msg("loading users from JSON") | ||||||
|  | 
 | ||||||
| 	log.Info().Msgf("Access token TTL: %s", accessTTL) | 	log.Info().Msgf("Access token TTL: %s", accessTTL) | ||||||
| 
 | 
 | ||||||
| 	port, err := strconv.Atoi(portStr) | 	port, err := strconv.Atoi(portStr) | ||||||
| @ -71,7 +86,7 @@ func mockOIDC() error { | |||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	mock, err := getMockOIDC(clientID, clientSecret) | 	mock, err := getMockOIDC(clientID, clientSecret, users) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @ -93,12 +108,18 @@ func mockOIDC() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, error) { | func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser) (*mockoidc.MockOIDC, error) { | ||||||
| 	keypair, err := mockoidc.NewKeypair(nil) | 	keypair, err := mockoidc.NewKeypair(nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	userQueue := mockoidc.UserQueue{} | ||||||
|  | 
 | ||||||
|  | 	for _, user := range users { | ||||||
|  | 		userQueue.Push(&user) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	mock := mockoidc.MockOIDC{ | 	mock := mockoidc.MockOIDC{ | ||||||
| 		ClientID:                      clientID, | 		ClientID:                      clientID, | ||||||
| 		ClientSecret:                  clientSecret, | 		ClientSecret:                  clientSecret, | ||||||
| @ -107,9 +128,19 @@ func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, erro | |||||||
| 		CodeChallengeMethodsSupported: []string{"plain", "S256"}, | 		CodeChallengeMethodsSupported: []string{"plain", "S256"}, | ||||||
| 		Keypair:                       keypair, | 		Keypair:                       keypair, | ||||||
| 		SessionStore:                  mockoidc.NewSessionStore(), | 		SessionStore:                  mockoidc.NewSessionStore(), | ||||||
| 		UserQueue:                     &mockoidc.UserQueue{}, | 		UserQueue:                     &userQueue, | ||||||
| 		ErrorQueue:                    &mockoidc.ErrorQueue{}, | 		ErrorQueue:                    &mockoidc.ErrorQueue{}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	mock.AddMiddleware(func(h http.Handler) http.Handler { | ||||||
|  | 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 			log.Info().Msgf("Request: %+v", r) | ||||||
|  | 			h.ServeHTTP(w, r) | ||||||
|  | 			if r.Response != nil { | ||||||
|  | 				log.Info().Msgf("Response: %+v", r.Response) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
| 	return &mock, nil | 	return &mock, nil | ||||||
| } | } | ||||||
|  | |||||||
| @ -436,7 +436,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( | |||||||
| ) (*types.User, error) { | ) (*types.User, error) { | ||||||
| 	var user *types.User | 	var user *types.User | ||||||
| 	var err error | 	var err error | ||||||
| 	user, err = a.db.GetUserByOIDCIdentifier(claims.Sub) | 	user, err = a.db.GetUserByOIDCIdentifier(claims.Identifier()) | ||||||
| 	if err != nil && !errors.Is(err, db.ErrUserNotFound) { | 	if err != nil && !errors.Is(err, db.ErrUserNotFound) { | ||||||
| 		return nil, fmt.Errorf("creating or updating user: %w", err) | 		return nil, fmt.Errorf("creating or updating user: %w", err) | ||||||
| 	} | 	} | ||||||
| @ -448,10 +448,12 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( | |||||||
| 	// TODO(kradalby): Remove when strip_email_domain and migration is removed | 	// TODO(kradalby): Remove when strip_email_domain and migration is removed | ||||||
| 	// after #2170 is cleaned up. | 	// after #2170 is cleaned up. | ||||||
| 	if a.cfg.MapLegacyUsers && user == nil { | 	if a.cfg.MapLegacyUsers && user == nil { | ||||||
|  | 		log.Trace().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user not found by OIDC identifier, looking up by username") | ||||||
| 		if oldUsername, err := getUserName(claims, a.cfg.StripEmaildomain); err == nil { | 		if oldUsername, err := getUserName(claims, a.cfg.StripEmaildomain); err == nil { | ||||||
|  | 			log.Trace().Str("old_username", oldUsername).Str("sub", claims.Sub).Msg("found username") | ||||||
| 			user, err = a.db.GetUserByName(oldUsername) | 			user, err = a.db.GetUserByName(oldUsername) | ||||||
| 			if err != nil && !errors.Is(err, db.ErrUserNotFound) { | 			if err != nil && !errors.Is(err, db.ErrUserNotFound) { | ||||||
| 				return nil, fmt.Errorf("creating or updating user: %w", err) | 				return nil, fmt.Errorf("getting user: %w", err) | ||||||
| 			} | 			} | ||||||
| 
 | 
 | ||||||
| 			// If the user exists, but it already has a provider identifier (OIDC sub), create a new user. | 			// If the user exists, but it already has a provider identifier (OIDC sub), create a new user. | ||||||
| @ -525,6 +527,9 @@ func getUserName( | |||||||
| 	claims *types.OIDCClaims, | 	claims *types.OIDCClaims, | ||||||
| 	stripEmaildomain bool, | 	stripEmaildomain bool, | ||||||
| ) (string, error) { | ) (string, error) { | ||||||
|  | 	if !claims.EmailVerified { | ||||||
|  | 		return "", fmt.Errorf("email not verified") | ||||||
|  | 	} | ||||||
| 	userName, err := util.NormalizeToFQDNRules( | 	userName, err := util.NormalizeToFQDNRules( | ||||||
| 		claims.Email, | 		claims.Email, | ||||||
| 		stripEmaildomain, | 		stripEmaildomain, | ||||||
|  | |||||||
| @ -908,6 +908,9 @@ func LoadServerConfig() (*Config, error) { | |||||||
| 				} | 				} | ||||||
| 			}(), | 			}(), | ||||||
| 			UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"), | 			UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"), | ||||||
|  | 			// TODO(kradalby): Remove when strip_email_domain is removed | ||||||
|  | 			// after #2170 is cleaned up | ||||||
|  | 			StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), | ||||||
| 			MapLegacyUsers:   viper.GetBool("oidc.map_legacy_users"), | 			MapLegacyUsers:   viper.GetBool("oidc.map_legacy_users"), | ||||||
| 		}, | 		}, | ||||||
| 
 | 
 | ||||||
|  | |||||||
| @ -3,7 +3,6 @@ package types | |||||||
| import ( | import ( | ||||||
| 	"cmp" | 	"cmp" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"strings" |  | ||||||
| 
 | 
 | ||||||
| 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| @ -39,7 +38,7 @@ type User struct { | |||||||
| 	// Unique identifier of the user from OIDC, | 	// Unique identifier of the user from OIDC, | ||||||
| 	// comes from `sub` claim in the OIDC token | 	// comes from `sub` claim in the OIDC token | ||||||
| 	// and is used to lookup the user. | 	// and is used to lookup the user. | ||||||
| 	ProviderIdentifier string `gorm:"index,uniqueIndex:idx_name_provider_identifier"` | 	ProviderIdentifier string `gorm:"unique,index,uniqueIndex:idx_name_provider_identifier"` | ||||||
| 
 | 
 | ||||||
| 	// Provider is the origin of the user account, | 	// Provider is the origin of the user account, | ||||||
| 	// same as RegistrationMethod, without authkey. | 	// same as RegistrationMethod, without authkey. | ||||||
| @ -58,9 +57,10 @@ type User struct { | |||||||
| // If the username does not contain an '@' it will be added to the end. | // If the username does not contain an '@' it will be added to the end. | ||||||
| func (u *User) Username() string { | func (u *User) Username() string { | ||||||
| 	username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) | 	username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) | ||||||
| 	if !strings.Contains(username, "@") { | 	// TODO(kradalby): Wire up all of this for the future | ||||||
| 		username = username + "@" | 	// if !strings.Contains(username, "@") { | ||||||
| 	} | 	// 	username = username + "@" | ||||||
|  | 	// } | ||||||
| 
 | 
 | ||||||
| 	return username | 	return username | ||||||
| } | } | ||||||
| @ -138,10 +138,14 @@ type OIDCClaims struct { | |||||||
| 	Username          string   `json:"preferred_username,omitempty"` | 	Username          string   `json:"preferred_username,omitempty"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | func (c *OIDCClaims) Identifier() string { | ||||||
|  | 	return c.Iss + "/" + c.Sub | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // FromClaim overrides a User from OIDC claims. | // FromClaim overrides a User from OIDC claims. | ||||||
| // All fields will be updated, except for the ID. | // All fields will be updated, except for the ID. | ||||||
| func (u *User) FromClaim(claims *OIDCClaims) { | func (u *User) FromClaim(claims *OIDCClaims) { | ||||||
| 	u.ProviderIdentifier = claims.Iss + "/" + claims.Sub | 	u.ProviderIdentifier = claims.Identifier() | ||||||
| 	u.DisplayName = claims.Name | 	u.DisplayName = claims.Name | ||||||
| 	if claims.EmailVerified { | 	if claims.EmailVerified { | ||||||
| 		u.Email = claims.Email | 		u.Email = claims.Email | ||||||
|  | |||||||
| @ -3,6 +3,7 @@ package integration | |||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
|  | 	"encoding/json" | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"io" | 	"io" | ||||||
| @ -10,14 +11,19 @@ import ( | |||||||
| 	"net" | 	"net" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 	"net/netip" | 	"net/netip" | ||||||
|  | 	"sort" | ||||||
| 	"strconv" | 	"strconv" | ||||||
| 	"testing" | 	"testing" | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
|  | 	"github.com/google/go-cmp/cmp" | ||||||
|  | 	"github.com/google/go-cmp/cmp/cmpopts" | ||||||
|  | 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/types" | 	"github.com/juanfont/headscale/hscontrol/types" | ||||||
| 	"github.com/juanfont/headscale/hscontrol/util" | 	"github.com/juanfont/headscale/hscontrol/util" | ||||||
| 	"github.com/juanfont/headscale/integration/dockertestutil" | 	"github.com/juanfont/headscale/integration/dockertestutil" | ||||||
| 	"github.com/juanfont/headscale/integration/hsic" | 	"github.com/juanfont/headscale/integration/hsic" | ||||||
|  | 	"github.com/oauth2-proxy/mockoidc" | ||||||
| 	"github.com/ory/dockertest/v3" | 	"github.com/ory/dockertest/v3" | ||||||
| 	"github.com/ory/dockertest/v3/docker" | 	"github.com/ory/dockertest/v3/docker" | ||||||
| 	"github.com/samber/lo" | 	"github.com/samber/lo" | ||||||
| @ -50,18 +56,32 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { | |||||||
| 	} | 	} | ||||||
| 	defer scenario.ShutdownAssertNoPanics(t) | 	defer scenario.ShutdownAssertNoPanics(t) | ||||||
| 
 | 
 | ||||||
|  | 	// Logins to MockOIDC is served by a queue with a strict order, | ||||||
|  | 	// if we use more than one node per user, the order of the logins | ||||||
|  | 	// will not be deterministic and the test will fail. | ||||||
| 	spec := map[string]int{ | 	spec := map[string]int{ | ||||||
| 		"user1": len(MustTestVersions), | 		"user1": 1, | ||||||
|  | 		"user2": 1, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL) | 	mockusers := []mockoidc.MockUser{ | ||||||
|  | 		oidcMockUser("user1", true), | ||||||
|  | 		oidcMockUser("user2", false), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) | ||||||
| 	assertNoErrf(t, "failed to run mock OIDC server: %s", err) | 	assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||||
|  | 	defer scenario.mockOIDC.Close() | ||||||
| 
 | 
 | ||||||
| 	oidcMap := map[string]string{ | 	oidcMap := map[string]string{ | ||||||
| 		"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer, | 		"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer, | ||||||
| 		"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID, | 		"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID, | ||||||
| 		"CREDENTIALS_DIRECTORY_TEST":        "/tmp", | 		"CREDENTIALS_DIRECTORY_TEST":        "/tmp", | ||||||
| 		"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", | 		"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", | ||||||
|  | 		// TODO(kradalby): Remove when strip_email_domain is removed | ||||||
|  | 		// after #2170 is cleaned up | ||||||
|  | 		"HEADSCALE_OIDC_MAP_LEGACY_USERS":   "0", | ||||||
|  | 		"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	err = scenario.CreateHeadscaleEnv( | 	err = scenario.CreateHeadscaleEnv( | ||||||
| @ -91,6 +111,55 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { | |||||||
| 
 | 
 | ||||||
| 	success := pingAllHelper(t, allClients, allAddrs) | 	success := pingAllHelper(t, allClients, allAddrs) | ||||||
| 	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) | 	t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) | ||||||
|  | 
 | ||||||
|  | 	headscale, err := scenario.Headscale() | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  | 
 | ||||||
|  | 	var listUsers []v1.User | ||||||
|  | 	err = executeAndUnmarshal(headscale, | ||||||
|  | 		[]string{ | ||||||
|  | 			"headscale", | ||||||
|  | 			"users", | ||||||
|  | 			"list", | ||||||
|  | 			"--output", | ||||||
|  | 			"json", | ||||||
|  | 		}, | ||||||
|  | 		&listUsers, | ||||||
|  | 	) | ||||||
|  | 	assertNoErr(t, err) | ||||||
|  | 
 | ||||||
|  | 	want := []v1.User{ | ||||||
|  | 		{ | ||||||
|  | 			Id:   "1", | ||||||
|  | 			Name: "user1", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "2", | ||||||
|  | 			Name:       "user1", | ||||||
|  | 			Email:      "user1@headscale.net", | ||||||
|  | 			Provider:   "oidc", | ||||||
|  | 			ProviderId: oidcConfig.Issuer + "/user1", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:   "3", | ||||||
|  | 			Name: "user2", | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			Id:         "4", | ||||||
|  | 			Name:       "user2", | ||||||
|  | 			Email:      "", // Unverified | ||||||
|  | 			Provider:   "oidc", | ||||||
|  | 			ProviderId: oidcConfig.Issuer + "/user2", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	sort.Slice(listUsers, func(i, j int) bool { | ||||||
|  | 		return listUsers[i].Id < listUsers[j].Id | ||||||
|  | 	}) | ||||||
|  | 
 | ||||||
|  | 	if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { | ||||||
|  | 		t.Fatalf("unexpected users: %s", diff) | ||||||
|  | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // This test is really flaky. | // This test is really flaky. | ||||||
| @ -111,11 +180,16 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { | |||||||
| 	defer scenario.ShutdownAssertNoPanics(t) | 	defer scenario.ShutdownAssertNoPanics(t) | ||||||
| 
 | 
 | ||||||
| 	spec := map[string]int{ | 	spec := map[string]int{ | ||||||
| 		"user1": 3, | 		"user1": 1, | ||||||
|  | 		"user2": 1, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	oidcConfig, err := scenario.runMockOIDC(shortAccessTTL) | 	oidcConfig, err := scenario.runMockOIDC(shortAccessTTL, []mockoidc.MockUser{ | ||||||
|  | 		oidcMockUser("user1", true), | ||||||
|  | 		oidcMockUser("user2", false), | ||||||
|  | 	}) | ||||||
| 	assertNoErrf(t, "failed to run mock OIDC server: %s", err) | 	assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||||
|  | 	defer scenario.mockOIDC.Close() | ||||||
| 
 | 
 | ||||||
| 	oidcMap := map[string]string{ | 	oidcMap := map[string]string{ | ||||||
| 		"HEADSCALE_OIDC_ISSUER":                oidcConfig.Issuer, | 		"HEADSCALE_OIDC_ISSUER":                oidcConfig.Issuer, | ||||||
| @ -159,6 +233,297 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { | |||||||
| 	assertTailscaleNodesLogout(t, allClients) | 	assertTailscaleNodesLogout(t, allClients) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // TODO(kradalby): | ||||||
|  | // - Test that creates a new user when one exists when migration is turned off | ||||||
|  | // - Test that takes over a user when one exists when migration is turned on | ||||||
|  | //   - But email is not verified | ||||||
|  | //   - stripped email domain on/off | ||||||
|  | func TestOIDC024UserCreation(t *testing.T) { | ||||||
|  | 	IntegrationSkip(t) | ||||||
|  | 
 | ||||||
|  | 	tests := []struct { | ||||||
|  | 		name          string | ||||||
|  | 		config        map[string]string | ||||||
|  | 		emailVerified bool | ||||||
|  | 		cliUsers      []string | ||||||
|  | 		oidcUsers     []string | ||||||
|  | 		want          func(iss string) []v1.User | ||||||
|  | 	}{ | ||||||
|  | 		{ | ||||||
|  | 			name: "no-migration-verified-email", | ||||||
|  | 			config: map[string]string{ | ||||||
|  | 				"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", | ||||||
|  | 			}, | ||||||
|  | 			emailVerified: true, | ||||||
|  | 			cliUsers:      []string{"user1", "user2"}, | ||||||
|  | 			oidcUsers:     []string{"user1", "user2"}, | ||||||
|  | 			want: func(iss string) []v1.User { | ||||||
|  | 				return []v1.User{ | ||||||
|  | 					{ | ||||||
|  | 						Id:   "1", | ||||||
|  | 						Name: "user1", | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						Id:         "2", | ||||||
|  | 						Name:       "user1", | ||||||
|  | 						Email:      "user1@headscale.net", | ||||||
|  | 						Provider:   "oidc", | ||||||
|  | 						ProviderId: iss + "/user1", | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						Id:   "3", | ||||||
|  | 						Name: "user2", | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						Id:         "4", | ||||||
|  | 						Name:       "user2", | ||||||
|  | 						Email:      "user2@headscale.net", | ||||||
|  | 						Provider:   "oidc", | ||||||
|  | 						ProviderId: iss + "/user2", | ||||||
|  | 					}, | ||||||
|  | 				} | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "no-migration-not-verified-email", | ||||||
|  | 			config: map[string]string{ | ||||||
|  | 				"HEADSCALE_OIDC_MAP_LEGACY_USERS": "0", | ||||||
|  | 			}, | ||||||
|  | 			emailVerified: false, | ||||||
|  | 			cliUsers:      []string{"user1", "user2"}, | ||||||
|  | 			oidcUsers:     []string{"user1", "user2"}, | ||||||
|  | 			want: func(iss string) []v1.User { | ||||||
|  | 				return []v1.User{ | ||||||
|  | 					{ | ||||||
|  | 						Id:   "1", | ||||||
|  | 						Name: "user1", | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						Id:         "2", | ||||||
|  | 						Name:       "user1", | ||||||
|  | 						Provider:   "oidc", | ||||||
|  | 						ProviderId: iss + "/user1", | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						Id:   "3", | ||||||
|  | 						Name: "user2", | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						Id:         "4", | ||||||
|  | 						Name:       "user2", | ||||||
|  | 						Provider:   "oidc", | ||||||
|  | 						ProviderId: iss + "/user2", | ||||||
|  | 					}, | ||||||
|  | 				} | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "migration-strip-domains-verified-email", | ||||||
|  | 			config: map[string]string{ | ||||||
|  | 				"HEADSCALE_OIDC_MAP_LEGACY_USERS":   "1", | ||||||
|  | 				"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "1", | ||||||
|  | 			}, | ||||||
|  | 			emailVerified: true, | ||||||
|  | 			cliUsers:      []string{"user1", "user2"}, | ||||||
|  | 			oidcUsers:     []string{"user1", "user2"}, | ||||||
|  | 			want: func(iss string) []v1.User { | ||||||
|  | 				return []v1.User{ | ||||||
|  | 					{ | ||||||
|  | 						Id:         "1", | ||||||
|  | 						Name:       "user1", | ||||||
|  | 						Email:      "user1@headscale.net", | ||||||
|  | 						Provider:   "oidc", | ||||||
|  | 						ProviderId: iss + "/user1", | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						Id:         "2", | ||||||
|  | 						Name:       "user2", | ||||||
|  | 						Email:      "user2@headscale.net", | ||||||
|  | 						Provider:   "oidc", | ||||||
|  | 						ProviderId: iss + "/user2", | ||||||
|  | 					}, | ||||||
|  | 				} | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "migration-strip-domains-not-verified-email", | ||||||
|  | 			config: map[string]string{ | ||||||
|  | 				"HEADSCALE_OIDC_MAP_LEGACY_USERS":   "1", | ||||||
|  | 				"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "1", | ||||||
|  | 			}, | ||||||
|  | 			emailVerified: false, | ||||||
|  | 			cliUsers:      []string{"user1", "user2"}, | ||||||
|  | 			oidcUsers:     []string{"user1", "user2"}, | ||||||
|  | 			want: func(iss string) []v1.User { | ||||||
|  | 				return []v1.User{ | ||||||
|  | 					{ | ||||||
|  | 						Id:   "1", | ||||||
|  | 						Name: "user1", | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						Id:         "2", | ||||||
|  | 						Name:       "user1", | ||||||
|  | 						Provider:   "oidc", | ||||||
|  | 						ProviderId: iss + "/user1", | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						Id:   "3", | ||||||
|  | 						Name: "user2", | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						Id:         "4", | ||||||
|  | 						Name:       "user2", | ||||||
|  | 						Provider:   "oidc", | ||||||
|  | 						ProviderId: iss + "/user2", | ||||||
|  | 					}, | ||||||
|  | 				} | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "migration-no-strip-domains-verified-email", | ||||||
|  | 			config: map[string]string{ | ||||||
|  | 				"HEADSCALE_OIDC_MAP_LEGACY_USERS":   "1", | ||||||
|  | 				"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", | ||||||
|  | 			}, | ||||||
|  | 			emailVerified: true, | ||||||
|  | 			cliUsers:      []string{"user1.headscale.net", "user2.headscale.net"}, | ||||||
|  | 			oidcUsers:     []string{"user1", "user2"}, | ||||||
|  | 			want: func(iss string) []v1.User { | ||||||
|  | 				return []v1.User{ | ||||||
|  | 					// Hmm I think we will have to overwrite the initial name here | ||||||
|  | 					// createuser with "user1.headscale.net", but oidc with "user1" | ||||||
|  | 					{ | ||||||
|  | 						Id:         "1", | ||||||
|  | 						Name:       "user1", | ||||||
|  | 						Email:      "user1@headscale.net", | ||||||
|  | 						Provider:   "oidc", | ||||||
|  | 						ProviderId: iss + "/user1", | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						Id:         "2", | ||||||
|  | 						Name:       "user2", | ||||||
|  | 						Email:      "user2@headscale.net", | ||||||
|  | 						Provider:   "oidc", | ||||||
|  | 						ProviderId: iss + "/user2", | ||||||
|  | 					}, | ||||||
|  | 				} | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			name: "migration-no-strip-domains-not-verified-email", | ||||||
|  | 			config: map[string]string{ | ||||||
|  | 				"HEADSCALE_OIDC_MAP_LEGACY_USERS":   "1", | ||||||
|  | 				"HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": "0", | ||||||
|  | 			}, | ||||||
|  | 			emailVerified: false, | ||||||
|  | 			cliUsers:      []string{"user1.headscale.net", "user2.headscale.net"}, | ||||||
|  | 			oidcUsers:     []string{"user1", "user2"}, | ||||||
|  | 			want: func(iss string) []v1.User { | ||||||
|  | 				return []v1.User{ | ||||||
|  | 					{ | ||||||
|  | 						Id:   "1", | ||||||
|  | 						Name: "user1.headscale.net", | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						Id:         "2", | ||||||
|  | 						Name:       "user1", | ||||||
|  | 						Provider:   "oidc", | ||||||
|  | 						ProviderId: iss + "/user1", | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						Id:   "3", | ||||||
|  | 						Name: "user2.headscale.net", | ||||||
|  | 					}, | ||||||
|  | 					{ | ||||||
|  | 						Id:         "4", | ||||||
|  | 						Name:       "user2", | ||||||
|  | 						Provider:   "oidc", | ||||||
|  | 						ProviderId: iss + "/user2", | ||||||
|  | 					}, | ||||||
|  | 				} | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	for _, tt := range tests { | ||||||
|  | 		t.Run(tt.name, func(t *testing.T) { | ||||||
|  | 			baseScenario, err := NewScenario(dockertestMaxWait()) | ||||||
|  | 			assertNoErr(t, err) | ||||||
|  | 
 | ||||||
|  | 			scenario := AuthOIDCScenario{ | ||||||
|  | 				Scenario: baseScenario, | ||||||
|  | 			} | ||||||
|  | 			defer scenario.ShutdownAssertNoPanics(t) | ||||||
|  | 
 | ||||||
|  | 			spec := map[string]int{} | ||||||
|  | 			for _, user := range tt.cliUsers { | ||||||
|  | 				spec[user] = 1 | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			var mockusers []mockoidc.MockUser | ||||||
|  | 			for _, user := range tt.oidcUsers { | ||||||
|  | 				mockusers = append(mockusers, oidcMockUser(user, tt.emailVerified)) | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL, mockusers) | ||||||
|  | 			assertNoErrf(t, "failed to run mock OIDC server: %s", err) | ||||||
|  | 			defer scenario.mockOIDC.Close() | ||||||
|  | 
 | ||||||
|  | 			oidcMap := map[string]string{ | ||||||
|  | 				"HEADSCALE_OIDC_ISSUER":             oidcConfig.Issuer, | ||||||
|  | 				"HEADSCALE_OIDC_CLIENT_ID":          oidcConfig.ClientID, | ||||||
|  | 				"CREDENTIALS_DIRECTORY_TEST":        "/tmp", | ||||||
|  | 				"HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			for k, v := range tt.config { | ||||||
|  | 				oidcMap[k] = v | ||||||
|  | 			} | ||||||
|  | 
 | ||||||
|  | 			err = scenario.CreateHeadscaleEnv( | ||||||
|  | 				spec, | ||||||
|  | 				hsic.WithTestName("oidcmigration"), | ||||||
|  | 				hsic.WithConfigEnv(oidcMap), | ||||||
|  | 				hsic.WithTLS(), | ||||||
|  | 				hsic.WithHostnameAsServerURL(), | ||||||
|  | 				hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), | ||||||
|  | 			) | ||||||
|  | 			assertNoErrHeadscaleEnv(t, err) | ||||||
|  | 
 | ||||||
|  | 			// Ensure that the nodes have logged in, this is what | ||||||
|  | 			// triggers user creation via OIDC. | ||||||
|  | 			err = scenario.WaitForTailscaleSync() | ||||||
|  | 			assertNoErrSync(t, err) | ||||||
|  | 
 | ||||||
|  | 			headscale, err := scenario.Headscale() | ||||||
|  | 			assertNoErr(t, err) | ||||||
|  | 
 | ||||||
|  | 			want := tt.want(oidcConfig.Issuer) | ||||||
|  | 
 | ||||||
|  | 			var listUsers []v1.User | ||||||
|  | 			err = executeAndUnmarshal(headscale, | ||||||
|  | 				[]string{ | ||||||
|  | 					"headscale", | ||||||
|  | 					"users", | ||||||
|  | 					"list", | ||||||
|  | 					"--output", | ||||||
|  | 					"json", | ||||||
|  | 				}, | ||||||
|  | 				&listUsers, | ||||||
|  | 			) | ||||||
|  | 			assertNoErr(t, err) | ||||||
|  | 
 | ||||||
|  | 			sort.Slice(listUsers, func(i, j int) bool { | ||||||
|  | 				return listUsers[i].Id < listUsers[j].Id | ||||||
|  | 			}) | ||||||
|  | 
 | ||||||
|  | 			if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { | ||||||
|  | 				t.Errorf("unexpected users: %s", diff) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
| func (s *AuthOIDCScenario) CreateHeadscaleEnv( | func (s *AuthOIDCScenario) CreateHeadscaleEnv( | ||||||
| 	users map[string]int, | 	users map[string]int, | ||||||
| 	opts ...hsic.Option, | 	opts ...hsic.Option, | ||||||
| @ -174,6 +539,13 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for userName, clientCount := range users { | 	for userName, clientCount := range users { | ||||||
|  | 		if clientCount != 1 { | ||||||
|  | 			// OIDC scenario only supports one client per user. | ||||||
|  | 			// This is because the MockOIDC server can only serve login | ||||||
|  | 			// requests based on a queue it has been given on startup. | ||||||
|  | 			// We currently only populates it with one login request per user. | ||||||
|  | 			return fmt.Errorf("client count must be 1 for OIDC scenario.") | ||||||
|  | 		} | ||||||
| 		log.Printf("creating user %s with %d clients", userName, clientCount) | 		log.Printf("creating user %s with %d clients", userName, clientCount) | ||||||
| 		err = s.CreateUser(userName) | 		err = s.CreateUser(userName) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @ -194,7 +566,7 @@ func (s *AuthOIDCScenario) CreateHeadscaleEnv( | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConfig, error) { | func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) (*types.OIDCConfig, error) { | ||||||
| 	port, err := dockertestutil.RandomFreeHostPort() | 	port, err := dockertestutil.RandomFreeHostPort() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Fatalf("could not find an open port: %s", err) | 		log.Fatalf("could not find an open port: %s", err) | ||||||
| @ -205,6 +577,11 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf | |||||||
| 
 | 
 | ||||||
| 	hostname := fmt.Sprintf("hs-oidcmock-%s", hash) | 	hostname := fmt.Sprintf("hs-oidcmock-%s", hash) | ||||||
| 
 | 
 | ||||||
|  | 	usersJSON, err := json.Marshal(users) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	mockOidcOptions := &dockertest.RunOptions{ | 	mockOidcOptions := &dockertest.RunOptions{ | ||||||
| 		Name:         hostname, | 		Name:         hostname, | ||||||
| 		Cmd:          []string{"headscale", "mockoidc"}, | 		Cmd:          []string{"headscale", "mockoidc"}, | ||||||
| @ -219,6 +596,7 @@ func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConf | |||||||
| 			"MOCKOIDC_CLIENT_ID=superclient", | 			"MOCKOIDC_CLIENT_ID=superclient", | ||||||
| 			"MOCKOIDC_CLIENT_SECRET=supersecret", | 			"MOCKOIDC_CLIENT_SECRET=supersecret", | ||||||
| 			fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), | 			fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), | ||||||
|  | 			fmt.Sprintf("MOCKOIDC_USERS=%s", string(usersJSON)), | ||||||
| 		}, | 		}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| @ -310,7 +688,6 @@ func (s *AuthOIDCScenario) runTailscaleUp( | |||||||
| 
 | 
 | ||||||
| 				log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String()) | 				log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String()) | ||||||
| 
 | 
 | ||||||
| 				if err := s.pool.Retry(func() error { |  | ||||||
| 				log.Printf("%s logging in with url", c.Hostname()) | 				log.Printf("%s logging in with url", c.Hostname()) | ||||||
| 				httpClient := &http.Client{Transport: insecureTransport} | 				httpClient := &http.Client{Transport: insecureTransport} | ||||||
| 				ctx := context.Background() | 				ctx := context.Background() | ||||||
| @ -329,6 +706,8 @@ func (s *AuthOIDCScenario) runTailscaleUp( | |||||||
| 
 | 
 | ||||||
| 				if resp.StatusCode != http.StatusOK { | 				if resp.StatusCode != http.StatusOK { | ||||||
| 					log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) | 					log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) | ||||||
|  | 					body, _ := io.ReadAll(resp.Body) | ||||||
|  | 					log.Printf("body: %s", body) | ||||||
| 
 | 
 | ||||||
| 					return errStatusCodeNotOK | 					return errStatusCodeNotOK | ||||||
| 				} | 				} | ||||||
| @ -342,13 +721,7 @@ func (s *AuthOIDCScenario) runTailscaleUp( | |||||||
| 					return err | 					return err | ||||||
| 				} | 				} | ||||||
| 
 | 
 | ||||||
| 					return nil |  | ||||||
| 				}); err != nil { |  | ||||||
| 					return err |  | ||||||
| 				} |  | ||||||
| 
 |  | ||||||
| 				log.Printf("Finished request for %s to join tailnet", c.Hostname()) | 				log.Printf("Finished request for %s to join tailnet", c.Hostname()) | ||||||
| 
 |  | ||||||
| 				return nil | 				return nil | ||||||
| 			}) | 			}) | ||||||
| 
 | 
 | ||||||
| @ -395,3 +768,12 @@ func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) { | |||||||
| 		assert.Equal(t, "NeedsLogin", status.BackendState) | 		assert.Equal(t, "NeedsLogin", status.BackendState) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser { | ||||||
|  | 	return mockoidc.MockUser{ | ||||||
|  | 		Subject:           username, | ||||||
|  | 		PreferredUsername: username, | ||||||
|  | 		Email:             fmt.Sprintf("%s@headscale.net", username), | ||||||
|  | 		EmailVerified:     emailVerified, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | |||||||
| @ -74,7 +74,7 @@ func ExecuteCommand( | |||||||
| 	select { | 	select { | ||||||
| 	case res := <-resultChan: | 	case res := <-resultChan: | ||||||
| 		if res.err != nil { | 		if res.err != nil { | ||||||
| 			return stdout.String(), stderr.String(), res.err | 			return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), res.err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if res.exitCode != 0 { | 		if res.exitCode != 0 { | ||||||
| @ -83,12 +83,12 @@ func ExecuteCommand( | |||||||
| 			// log.Println("stdout: ", stdout.String()) | 			// log.Println("stdout: ", stdout.String()) | ||||||
| 			// log.Println("stderr: ", stderr.String()) | 			// log.Println("stderr: ", stderr.String()) | ||||||
| 
 | 
 | ||||||
| 			return stdout.String(), stderr.String(), ErrDockertestCommandFailed | 			return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandFailed) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return stdout.String(), stderr.String(), nil | 		return stdout.String(), stderr.String(), nil | ||||||
| 	case <-time.After(execConfig.timeout): | 	case <-time.After(execConfig.timeout): | ||||||
| 
 | 
 | ||||||
| 		return stdout.String(), stderr.String(), ErrDockertestCommandTimeout | 		return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandTimeout) | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user