mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-31 16:11:03 +01:00 
			
		
		
		
	refactor OIDC callback aux functions
This commit is contained in:
		
							parent
							
								
									00d2a447f4
								
							
						
					
					
						commit
						a1e7e771ce
					
				
							
								
								
									
										185
									
								
								oidc.go
									
									
									
									
									
								
							
							
						
						
									
										185
									
								
								oidc.go
									
									
									
									
									
								
							| @ -21,6 +21,13 @@ import ( | ||||
| 
 | ||||
| const ( | ||||
| 	randomByteSize = 16 | ||||
| 
 | ||||
| 	errEmptyOIDCCallbackParams = Error("empty OIDC callback params") | ||||
| 	errNoOIDCIDToken           = Error("could not extract ID Token for OIDC callback") | ||||
| 	errOIDCAllowedDomains      = Error("authenticated principal does not match any allowed domain") | ||||
| 	errOIDCAllowedUsers        = Error("authenticated principal does not match any allowed user") | ||||
| 	errOIDCInvalidMachineState = Error("requested machine state key expired before authorisation completed") | ||||
| 	errOIDCMachineKeyMissing   = Error("could not get machine key from cache") | ||||
| ) | ||||
| 
 | ||||
| type IDTokenClaims struct { | ||||
| @ -136,18 +143,18 @@ func (h *Headscale) OIDCCallback( | ||||
| 	writer http.ResponseWriter, | ||||
| 	req *http.Request, | ||||
| ) { | ||||
| 	code, state, ok := validateOIDCCallbackParams(writer, req) | ||||
| 	if !ok { | ||||
| 	code, state, err := validateOIDCCallbackParams(writer, req) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	rawIDToken, ok := h.getIDTokenForOIDCCallback(writer, code, state) | ||||
| 	if !ok { | ||||
| 	rawIDToken, err := h.getIDTokenForOIDCCallback(writer, code, state) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	idToken, ok := h.verifyIDTokenForOIDCCallback(writer, rawIDToken) | ||||
| 	if !ok { | ||||
| 	idToken, err := h.verifyIDTokenForOIDCCallback(writer, rawIDToken) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| @ -158,43 +165,43 @@ func (h *Headscale) OIDCCallback( | ||||
| 	// 	return | ||||
| 	// } | ||||
| 
 | ||||
| 	claims, ok := extractIDTokenClaims(writer, idToken) | ||||
| 	if !ok { | ||||
| 	claims, err := extractIDTokenClaims(writer, idToken) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if ok := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); !ok { | ||||
| 	if err := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if ok := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); !ok { | ||||
| 	if err := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	machineKey, ok := h.validateMachineForOIDCCallback(writer, state, claims) | ||||
| 	if !ok { | ||||
| 	machineKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims) | ||||
| 	if err != nil || machineExists { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	namespaceName, ok := getNamespaceName(writer, claims, h.cfg.OIDC.StripEmaildomain) | ||||
| 	if !ok { | ||||
| 	namespaceName, err := getNamespaceName(writer, claims, h.cfg.OIDC.StripEmaildomain) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	// register the machine if it's new | ||||
| 	log.Debug().Msg("Registering new machine after successful callback") | ||||
| 
 | ||||
| 	namespace, ok := h.findOrCreateNewNamespaceForOIDCCallback(writer, namespaceName) | ||||
| 	if !ok { | ||||
| 	namespace, err := h.findOrCreateNewNamespaceForOIDCCallback(writer, namespaceName) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	if ok := h.registerMachineForOIDCCallback(writer, namespace, machineKey); !ok { | ||||
| 	if err := h.registerMachineForOIDCCallback(writer, namespace, machineKey); err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	content, ok := renderOIDCCallbackTemplate(writer, claims) | ||||
| 	if !ok { | ||||
| 	content, err := renderOIDCCallbackTemplate(writer, claims) | ||||
| 	if err != nil { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| @ -211,7 +218,7 @@ func (h *Headscale) OIDCCallback( | ||||
| func validateOIDCCallbackParams( | ||||
| 	writer http.ResponseWriter, | ||||
| 	req *http.Request, | ||||
| ) (string, string, bool) { | ||||
| ) (string, string, error) { | ||||
| 	code := req.URL.Query().Get("code") | ||||
| 	state := req.URL.Query().Get("state") | ||||
| 
 | ||||
| @ -226,16 +233,16 @@ func validateOIDCCallbackParams( | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return "", "", false | ||||
| 		return "", "", errEmptyOIDCCallbackParams | ||||
| 	} | ||||
| 
 | ||||
| 	return code, state, true | ||||
| 	return code, state, nil | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) getIDTokenForOIDCCallback( | ||||
| 	writer http.ResponseWriter, | ||||
| 	code, state string, | ||||
| ) (string, bool) { | ||||
| ) (string, error) { | ||||
| 	oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| @ -244,15 +251,15 @@ func (h *Headscale) getIDTokenForOIDCCallback( | ||||
| 			Msg("Could not exchange code for token") | ||||
| 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||
| 		writer.WriteHeader(http.StatusBadRequest) | ||||
| 		_, err := writer.Write([]byte("Could not exchange code for token")) | ||||
| 		if err != nil { | ||||
| 		_, werr := writer.Write([]byte("Could not exchange code for token")) | ||||
| 		if werr != nil { | ||||
| 			log.Error(). | ||||
| 				Caller(). | ||||
| 				Err(err). | ||||
| 				Err(werr). | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return "", false | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	log.Trace(). | ||||
| @ -273,16 +280,16 @@ func (h *Headscale) getIDTokenForOIDCCallback( | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return "", false | ||||
| 		return "", errNoOIDCIDToken | ||||
| 	} | ||||
| 
 | ||||
| 	return rawIDToken, true | ||||
| 	return rawIDToken, nil | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) verifyIDTokenForOIDCCallback( | ||||
| 	writer http.ResponseWriter, | ||||
| 	rawIDToken string, | ||||
| ) (*oidc.IDToken, bool) { | ||||
| ) (*oidc.IDToken, error) { | ||||
| 	verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID}) | ||||
| 	idToken, err := verifier.Verify(context.Background(), rawIDToken) | ||||
| 	if err != nil { | ||||
| @ -292,24 +299,24 @@ func (h *Headscale) verifyIDTokenForOIDCCallback( | ||||
| 			Msg("failed to verify id token") | ||||
| 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||
| 		writer.WriteHeader(http.StatusBadRequest) | ||||
| 		_, err := writer.Write([]byte("Failed to verify id token")) | ||||
| 		if err != nil { | ||||
| 		_, werr := writer.Write([]byte("Failed to verify id token")) | ||||
| 		if werr != nil { | ||||
| 			log.Error(). | ||||
| 				Caller(). | ||||
| 				Err(err). | ||||
| 				Err(werr). | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return nil, false | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return idToken, true | ||||
| 	return idToken, nil | ||||
| } | ||||
| 
 | ||||
| func extractIDTokenClaims( | ||||
| 	writer http.ResponseWriter, | ||||
| 	idToken *oidc.IDToken, | ||||
| ) (*IDTokenClaims, bool) { | ||||
| ) (*IDTokenClaims, error) { | ||||
| 	var claims IDTokenClaims | ||||
| 	if err := idToken.Claims(claims); err != nil { | ||||
| 		log.Error(). | ||||
| @ -318,18 +325,18 @@ func extractIDTokenClaims( | ||||
| 			Msg("Failed to decode id token claims") | ||||
| 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||
| 		writer.WriteHeader(http.StatusBadRequest) | ||||
| 		_, err := writer.Write([]byte("Failed to decode id token claims")) | ||||
| 		if err != nil { | ||||
| 		_, werr := writer.Write([]byte("Failed to decode id token claims")) | ||||
| 		if werr != nil { | ||||
| 			log.Error(). | ||||
| 				Caller(). | ||||
| 				Err(err). | ||||
| 				Err(werr). | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return nil, false | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return &claims, true | ||||
| 	return &claims, nil | ||||
| } | ||||
| 
 | ||||
| // validateOIDCAllowedDomains checks that if AllowedDomains is provided, | ||||
| @ -338,7 +345,7 @@ func validateOIDCAllowedDomains( | ||||
| 	writer http.ResponseWriter, | ||||
| 	allowedDomains []string, | ||||
| 	claims *IDTokenClaims, | ||||
| ) bool { | ||||
| ) error { | ||||
| 	if len(allowedDomains) > 0 { | ||||
| 		if at := strings.LastIndex(claims.Email, "@"); at < 0 || | ||||
| 			!IsStringInSlice(allowedDomains, claims.Email[at+1:]) { | ||||
| @ -353,11 +360,11 @@ func validateOIDCAllowedDomains( | ||||
| 					Msg("Failed to write response") | ||||
| 			} | ||||
| 
 | ||||
| 			return false | ||||
| 			return errOIDCAllowedDomains | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	return true | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // validateOIDCAllowedUsers checks that if AllowedUsers is provided, | ||||
| @ -366,7 +373,7 @@ func validateOIDCAllowedUsers( | ||||
| 	writer http.ResponseWriter, | ||||
| 	allowedUsers []string, | ||||
| 	claims *IDTokenClaims, | ||||
| ) bool { | ||||
| ) error { | ||||
| 	if len(allowedUsers) > 0 && | ||||
| 		!IsStringInSlice(allowedUsers, claims.Email) { | ||||
| 		log.Error().Msg("authenticated principal does not match any allowed user") | ||||
| @ -380,10 +387,10 @@ func validateOIDCAllowedUsers( | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return false | ||||
| 		return errOIDCAllowedUsers | ||||
| 	} | ||||
| 
 | ||||
| 	return true | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| // validateMachine retrieves machine information if it exist | ||||
| @ -394,7 +401,7 @@ func (h *Headscale) validateMachineForOIDCCallback( | ||||
| 	writer http.ResponseWriter, | ||||
| 	state string, | ||||
| 	claims *IDTokenClaims, | ||||
| ) (*key.MachinePublic, bool) { | ||||
| ) (*key.MachinePublic, bool, error) { | ||||
| 	// retrieve machinekey from state cache | ||||
| 	machineKeyIf, machineKeyFound := h.registrationCache.Get(state) | ||||
| 	if !machineKeyFound { | ||||
| @ -410,7 +417,7 @@ func (h *Headscale) validateMachineForOIDCCallback( | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return nil, false | ||||
| 		return nil, false, errOIDCInvalidMachineState | ||||
| 	} | ||||
| 
 | ||||
| 	var machineKey key.MachinePublic | ||||
| @ -423,15 +430,15 @@ func (h *Headscale) validateMachineForOIDCCallback( | ||||
| 			Msg("could not parse machine public key") | ||||
| 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||
| 		writer.WriteHeader(http.StatusBadRequest) | ||||
| 		_, err := writer.Write([]byte("could not parse public key")) | ||||
| 		if err != nil { | ||||
| 		_, werr := writer.Write([]byte("could not parse public key")) | ||||
| 		if werr != nil { | ||||
| 			log.Error(). | ||||
| 				Caller(). | ||||
| 				Err(err). | ||||
| 				Err(werr). | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return nil, false | ||||
| 		return nil, false, err | ||||
| 	} | ||||
| 
 | ||||
| 	if !machineKeyOK { | ||||
| @ -446,7 +453,7 @@ func (h *Headscale) validateMachineForOIDCCallback( | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return nil, false | ||||
| 		return nil, false, errOIDCMachineKeyMissing | ||||
| 	} | ||||
| 
 | ||||
| 	// retrieve machine information if it exist | ||||
| @ -469,7 +476,7 @@ func (h *Headscale) validateMachineForOIDCCallback( | ||||
| 				Msg("Failed to refresh machine") | ||||
| 			http.Error(writer, "Failed to refresh machine", http.StatusInternalServerError) | ||||
| 
 | ||||
| 			return nil, false | ||||
| 			return nil, true, err | ||||
| 		} | ||||
| 
 | ||||
| 		var content bytes.Buffer | ||||
| @ -485,15 +492,15 @@ func (h *Headscale) validateMachineForOIDCCallback( | ||||
| 
 | ||||
| 			writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||
| 			writer.WriteHeader(http.StatusInternalServerError) | ||||
| 			_, err := writer.Write([]byte("Could not render OIDC callback template")) | ||||
| 			if err != nil { | ||||
| 			_, werr := writer.Write([]byte("Could not render OIDC callback template")) | ||||
| 			if werr != nil { | ||||
| 				log.Error(). | ||||
| 					Caller(). | ||||
| 					Err(err). | ||||
| 					Err(werr). | ||||
| 					Msg("Failed to write response") | ||||
| 			} | ||||
| 
 | ||||
| 			return nil, false | ||||
| 			return nil, true, err | ||||
| 		} | ||||
| 
 | ||||
| 		writer.Header().Set("Content-Type", "text/html; charset=utf-8") | ||||
| @ -506,17 +513,17 @@ func (h *Headscale) validateMachineForOIDCCallback( | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return nil, false | ||||
| 		return nil, true, nil | ||||
| 	} | ||||
| 
 | ||||
| 	return &machineKey, true | ||||
| 	return &machineKey, false, nil | ||||
| } | ||||
| 
 | ||||
| func getNamespaceName( | ||||
| 	writer http.ResponseWriter, | ||||
| 	claims *IDTokenClaims, | ||||
| 	stripEmaildomain bool, | ||||
| ) (string, bool) { | ||||
| ) (string, error) { | ||||
| 	namespaceName, err := NormalizeToFQDNRules( | ||||
| 		claims.Email, | ||||
| 		stripEmaildomain, | ||||
| @ -525,24 +532,24 @@ func getNamespaceName( | ||||
| 		log.Error().Err(err).Caller().Msgf("couldn't normalize email") | ||||
| 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||
| 		writer.WriteHeader(http.StatusInternalServerError) | ||||
| 		_, err := writer.Write([]byte("couldn't normalize email")) | ||||
| 		if err != nil { | ||||
| 		_, werr := writer.Write([]byte("couldn't normalize email")) | ||||
| 		if werr != nil { | ||||
| 			log.Error(). | ||||
| 				Caller(). | ||||
| 				Err(err). | ||||
| 				Err(werr). | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return "", false | ||||
| 		return "", err | ||||
| 	} | ||||
| 
 | ||||
| 	return namespaceName, true | ||||
| 	return namespaceName, nil | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback( | ||||
| 	writer http.ResponseWriter, | ||||
| 	namespaceName string, | ||||
| ) (*Namespace, bool) { | ||||
| ) (*Namespace, error) { | ||||
| 	namespace, err := h.GetNamespace(namespaceName) | ||||
| 	if errors.Is(err, errNamespaceNotFound) { | ||||
| 		namespace, err = h.CreateNamespace(namespaceName) | ||||
| @ -554,15 +561,15 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback( | ||||
| 				Msgf("could not create new namespace '%s'", namespaceName) | ||||
| 			writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||
| 			writer.WriteHeader(http.StatusInternalServerError) | ||||
| 			_, err := writer.Write([]byte("could not create namespace")) | ||||
| 			if err != nil { | ||||
| 			_, werr := writer.Write([]byte("could not create namespace")) | ||||
| 			if werr != nil { | ||||
| 				log.Error(). | ||||
| 					Caller(). | ||||
| 					Err(err). | ||||
| 					Err(werr). | ||||
| 					Msg("Failed to write response") | ||||
| 			} | ||||
| 
 | ||||
| 			return nil, false | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} else if err != nil { | ||||
| 		log.Error(). | ||||
| @ -572,25 +579,25 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback( | ||||
| 			Msg("could not find or create namespace") | ||||
| 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||
| 		writer.WriteHeader(http.StatusInternalServerError) | ||||
| 		_, err := writer.Write([]byte("could not find or create namespace")) | ||||
| 		if err != nil { | ||||
| 		_, werr := writer.Write([]byte("could not find or create namespace")) | ||||
| 		if werr != nil { | ||||
| 			log.Error(). | ||||
| 				Caller(). | ||||
| 				Err(err). | ||||
| 				Err(werr). | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return nil, false | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return namespace, true | ||||
| 	return namespace, nil | ||||
| } | ||||
| 
 | ||||
| func (h *Headscale) registerMachineForOIDCCallback( | ||||
| 	writer http.ResponseWriter, | ||||
| 	namespace *Namespace, | ||||
| 	machineKey *key.MachinePublic, | ||||
| ) bool { | ||||
| ) error { | ||||
| 	machineKeyStr := MachinePublicKeyStripPrefix(*machineKey) | ||||
| 
 | ||||
| 	if _, err := h.RegisterMachineFromAuthCallback( | ||||
| @ -604,24 +611,24 @@ func (h *Headscale) registerMachineForOIDCCallback( | ||||
| 			Msg("could not register machine") | ||||
| 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||
| 		writer.WriteHeader(http.StatusInternalServerError) | ||||
| 		_, err := writer.Write([]byte("could not register machine")) | ||||
| 		if err != nil { | ||||
| 		_, werr := writer.Write([]byte("could not register machine")) | ||||
| 		if werr != nil { | ||||
| 			log.Error(). | ||||
| 				Caller(). | ||||
| 				Err(err). | ||||
| 				Err(werr). | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return false | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	return true | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func renderOIDCCallbackTemplate( | ||||
| 	writer http.ResponseWriter, | ||||
| 	claims *IDTokenClaims, | ||||
| ) (*bytes.Buffer, bool) { | ||||
| ) (*bytes.Buffer, error) { | ||||
| 	var content bytes.Buffer | ||||
| 	if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ | ||||
| 		User: claims.Email, | ||||
| @ -635,16 +642,16 @@ func renderOIDCCallbackTemplate( | ||||
| 
 | ||||
| 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||
| 		writer.WriteHeader(http.StatusInternalServerError) | ||||
| 		_, err := writer.Write([]byte("Could not render OIDC callback template")) | ||||
| 		if err != nil { | ||||
| 		_, werr := writer.Write([]byte("Could not render OIDC callback template")) | ||||
| 		if werr != nil { | ||||
| 			log.Error(). | ||||
| 				Caller(). | ||||
| 				Err(err). | ||||
| 				Err(werr). | ||||
| 				Msg("Failed to write response") | ||||
| 		} | ||||
| 
 | ||||
| 		return nil, false | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	return &content, true | ||||
| 	return &content, nil | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user