mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-11-03 17:41:03 +01:00 
			
		
		
		
	refactor OIDC callback aux functions
This commit is contained in:
		
							parent
							
								
									00d2a447f4
								
							
						
					
					
						commit
						a1e7e771ce
					
				
							
								
								
									
										185
									
								
								oidc.go
									
									
									
									
									
								
							
							
						
						
									
										185
									
								
								oidc.go
									
									
									
									
									
								
							@ -21,6 +21,13 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	randomByteSize = 16
 | 
						randomByteSize = 16
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						errEmptyOIDCCallbackParams = Error("empty OIDC callback params")
 | 
				
			||||||
 | 
						errNoOIDCIDToken           = Error("could not extract ID Token for OIDC callback")
 | 
				
			||||||
 | 
						errOIDCAllowedDomains      = Error("authenticated principal does not match any allowed domain")
 | 
				
			||||||
 | 
						errOIDCAllowedUsers        = Error("authenticated principal does not match any allowed user")
 | 
				
			||||||
 | 
						errOIDCInvalidMachineState = Error("requested machine state key expired before authorisation completed")
 | 
				
			||||||
 | 
						errOIDCMachineKeyMissing   = Error("could not get machine key from cache")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type IDTokenClaims struct {
 | 
					type IDTokenClaims struct {
 | 
				
			||||||
@ -136,18 +143,18 @@ func (h *Headscale) OIDCCallback(
 | 
				
			|||||||
	writer http.ResponseWriter,
 | 
						writer http.ResponseWriter,
 | 
				
			||||||
	req *http.Request,
 | 
						req *http.Request,
 | 
				
			||||||
) {
 | 
					) {
 | 
				
			||||||
	code, state, ok := validateOIDCCallbackParams(writer, req)
 | 
						code, state, err := validateOIDCCallbackParams(writer, req)
 | 
				
			||||||
	if !ok {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	rawIDToken, ok := h.getIDTokenForOIDCCallback(writer, code, state)
 | 
						rawIDToken, err := h.getIDTokenForOIDCCallback(writer, code, state)
 | 
				
			||||||
	if !ok {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	idToken, ok := h.verifyIDTokenForOIDCCallback(writer, rawIDToken)
 | 
						idToken, err := h.verifyIDTokenForOIDCCallback(writer, rawIDToken)
 | 
				
			||||||
	if !ok {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -158,43 +165,43 @@ func (h *Headscale) OIDCCallback(
 | 
				
			|||||||
	// 	return
 | 
						// 	return
 | 
				
			||||||
	// }
 | 
						// }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	claims, ok := extractIDTokenClaims(writer, idToken)
 | 
						claims, err := extractIDTokenClaims(writer, idToken)
 | 
				
			||||||
	if !ok {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if ok := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); !ok {
 | 
						if err := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if ok := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); !ok {
 | 
						if err := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	machineKey, ok := h.validateMachineForOIDCCallback(writer, state, claims)
 | 
						machineKey, machineExists, err := h.validateMachineForOIDCCallback(writer, state, claims)
 | 
				
			||||||
	if !ok {
 | 
						if err != nil || machineExists {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	namespaceName, ok := getNamespaceName(writer, claims, h.cfg.OIDC.StripEmaildomain)
 | 
						namespaceName, err := getNamespaceName(writer, claims, h.cfg.OIDC.StripEmaildomain)
 | 
				
			||||||
	if !ok {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// register the machine if it's new
 | 
						// register the machine if it's new
 | 
				
			||||||
	log.Debug().Msg("Registering new machine after successful callback")
 | 
						log.Debug().Msg("Registering new machine after successful callback")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	namespace, ok := h.findOrCreateNewNamespaceForOIDCCallback(writer, namespaceName)
 | 
						namespace, err := h.findOrCreateNewNamespaceForOIDCCallback(writer, namespaceName)
 | 
				
			||||||
	if !ok {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if ok := h.registerMachineForOIDCCallback(writer, namespace, machineKey); !ok {
 | 
						if err := h.registerMachineForOIDCCallback(writer, namespace, machineKey); err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	content, ok := renderOIDCCallbackTemplate(writer, claims)
 | 
						content, err := renderOIDCCallbackTemplate(writer, claims)
 | 
				
			||||||
	if !ok {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -211,7 +218,7 @@ func (h *Headscale) OIDCCallback(
 | 
				
			|||||||
func validateOIDCCallbackParams(
 | 
					func validateOIDCCallbackParams(
 | 
				
			||||||
	writer http.ResponseWriter,
 | 
						writer http.ResponseWriter,
 | 
				
			||||||
	req *http.Request,
 | 
						req *http.Request,
 | 
				
			||||||
) (string, string, bool) {
 | 
					) (string, string, error) {
 | 
				
			||||||
	code := req.URL.Query().Get("code")
 | 
						code := req.URL.Query().Get("code")
 | 
				
			||||||
	state := req.URL.Query().Get("state")
 | 
						state := req.URL.Query().Get("state")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -226,16 +233,16 @@ func validateOIDCCallbackParams(
 | 
				
			|||||||
				Msg("Failed to write response")
 | 
									Msg("Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return "", "", false
 | 
							return "", "", errEmptyOIDCCallbackParams
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return code, state, true
 | 
						return code, state, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *Headscale) getIDTokenForOIDCCallback(
 | 
					func (h *Headscale) getIDTokenForOIDCCallback(
 | 
				
			||||||
	writer http.ResponseWriter,
 | 
						writer http.ResponseWriter,
 | 
				
			||||||
	code, state string,
 | 
						code, state string,
 | 
				
			||||||
) (string, bool) {
 | 
					) (string, error) {
 | 
				
			||||||
	oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
 | 
						oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.Error().
 | 
							log.Error().
 | 
				
			||||||
@ -244,15 +251,15 @@ func (h *Headscale) getIDTokenForOIDCCallback(
 | 
				
			|||||||
			Msg("Could not exchange code for token")
 | 
								Msg("Could not exchange code for token")
 | 
				
			||||||
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
							writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
				
			||||||
		writer.WriteHeader(http.StatusBadRequest)
 | 
							writer.WriteHeader(http.StatusBadRequest)
 | 
				
			||||||
		_, err := writer.Write([]byte("Could not exchange code for token"))
 | 
							_, werr := writer.Write([]byte("Could not exchange code for token"))
 | 
				
			||||||
		if err != nil {
 | 
							if werr != nil {
 | 
				
			||||||
			log.Error().
 | 
								log.Error().
 | 
				
			||||||
				Caller().
 | 
									Caller().
 | 
				
			||||||
				Err(err).
 | 
									Err(werr).
 | 
				
			||||||
				Msg("Failed to write response")
 | 
									Msg("Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return "", false
 | 
							return "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.Trace().
 | 
						log.Trace().
 | 
				
			||||||
@ -273,16 +280,16 @@ func (h *Headscale) getIDTokenForOIDCCallback(
 | 
				
			|||||||
				Msg("Failed to write response")
 | 
									Msg("Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return "", false
 | 
							return "", errNoOIDCIDToken
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return rawIDToken, true
 | 
						return rawIDToken, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *Headscale) verifyIDTokenForOIDCCallback(
 | 
					func (h *Headscale) verifyIDTokenForOIDCCallback(
 | 
				
			||||||
	writer http.ResponseWriter,
 | 
						writer http.ResponseWriter,
 | 
				
			||||||
	rawIDToken string,
 | 
						rawIDToken string,
 | 
				
			||||||
) (*oidc.IDToken, bool) {
 | 
					) (*oidc.IDToken, error) {
 | 
				
			||||||
	verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
 | 
						verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
 | 
				
			||||||
	idToken, err := verifier.Verify(context.Background(), rawIDToken)
 | 
						idToken, err := verifier.Verify(context.Background(), rawIDToken)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
@ -292,24 +299,24 @@ func (h *Headscale) verifyIDTokenForOIDCCallback(
 | 
				
			|||||||
			Msg("failed to verify id token")
 | 
								Msg("failed to verify id token")
 | 
				
			||||||
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
							writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
				
			||||||
		writer.WriteHeader(http.StatusBadRequest)
 | 
							writer.WriteHeader(http.StatusBadRequest)
 | 
				
			||||||
		_, err := writer.Write([]byte("Failed to verify id token"))
 | 
							_, werr := writer.Write([]byte("Failed to verify id token"))
 | 
				
			||||||
		if err != nil {
 | 
							if werr != nil {
 | 
				
			||||||
			log.Error().
 | 
								log.Error().
 | 
				
			||||||
				Caller().
 | 
									Caller().
 | 
				
			||||||
				Err(err).
 | 
									Err(werr).
 | 
				
			||||||
				Msg("Failed to write response")
 | 
									Msg("Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return nil, false
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return idToken, true
 | 
						return idToken, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func extractIDTokenClaims(
 | 
					func extractIDTokenClaims(
 | 
				
			||||||
	writer http.ResponseWriter,
 | 
						writer http.ResponseWriter,
 | 
				
			||||||
	idToken *oidc.IDToken,
 | 
						idToken *oidc.IDToken,
 | 
				
			||||||
) (*IDTokenClaims, bool) {
 | 
					) (*IDTokenClaims, error) {
 | 
				
			||||||
	var claims IDTokenClaims
 | 
						var claims IDTokenClaims
 | 
				
			||||||
	if err := idToken.Claims(claims); err != nil {
 | 
						if err := idToken.Claims(claims); err != nil {
 | 
				
			||||||
		log.Error().
 | 
							log.Error().
 | 
				
			||||||
@ -318,18 +325,18 @@ func extractIDTokenClaims(
 | 
				
			|||||||
			Msg("Failed to decode id token claims")
 | 
								Msg("Failed to decode id token claims")
 | 
				
			||||||
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
							writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
				
			||||||
		writer.WriteHeader(http.StatusBadRequest)
 | 
							writer.WriteHeader(http.StatusBadRequest)
 | 
				
			||||||
		_, err := writer.Write([]byte("Failed to decode id token claims"))
 | 
							_, werr := writer.Write([]byte("Failed to decode id token claims"))
 | 
				
			||||||
		if err != nil {
 | 
							if werr != nil {
 | 
				
			||||||
			log.Error().
 | 
								log.Error().
 | 
				
			||||||
				Caller().
 | 
									Caller().
 | 
				
			||||||
				Err(err).
 | 
									Err(werr).
 | 
				
			||||||
				Msg("Failed to write response")
 | 
									Msg("Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return nil, false
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &claims, true
 | 
						return &claims, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
 | 
					// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
 | 
				
			||||||
@ -338,7 +345,7 @@ func validateOIDCAllowedDomains(
 | 
				
			|||||||
	writer http.ResponseWriter,
 | 
						writer http.ResponseWriter,
 | 
				
			||||||
	allowedDomains []string,
 | 
						allowedDomains []string,
 | 
				
			||||||
	claims *IDTokenClaims,
 | 
						claims *IDTokenClaims,
 | 
				
			||||||
) bool {
 | 
					) error {
 | 
				
			||||||
	if len(allowedDomains) > 0 {
 | 
						if len(allowedDomains) > 0 {
 | 
				
			||||||
		if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
 | 
							if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
 | 
				
			||||||
			!IsStringInSlice(allowedDomains, claims.Email[at+1:]) {
 | 
								!IsStringInSlice(allowedDomains, claims.Email[at+1:]) {
 | 
				
			||||||
@ -353,11 +360,11 @@ func validateOIDCAllowedDomains(
 | 
				
			|||||||
					Msg("Failed to write response")
 | 
										Msg("Failed to write response")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return false
 | 
								return errOIDCAllowedDomains
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return true
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// validateOIDCAllowedUsers checks that if AllowedUsers is provided,
 | 
					// validateOIDCAllowedUsers checks that if AllowedUsers is provided,
 | 
				
			||||||
@ -366,7 +373,7 @@ func validateOIDCAllowedUsers(
 | 
				
			|||||||
	writer http.ResponseWriter,
 | 
						writer http.ResponseWriter,
 | 
				
			||||||
	allowedUsers []string,
 | 
						allowedUsers []string,
 | 
				
			||||||
	claims *IDTokenClaims,
 | 
						claims *IDTokenClaims,
 | 
				
			||||||
) bool {
 | 
					) error {
 | 
				
			||||||
	if len(allowedUsers) > 0 &&
 | 
						if len(allowedUsers) > 0 &&
 | 
				
			||||||
		!IsStringInSlice(allowedUsers, claims.Email) {
 | 
							!IsStringInSlice(allowedUsers, claims.Email) {
 | 
				
			||||||
		log.Error().Msg("authenticated principal does not match any allowed user")
 | 
							log.Error().Msg("authenticated principal does not match any allowed user")
 | 
				
			||||||
@ -380,10 +387,10 @@ func validateOIDCAllowedUsers(
 | 
				
			|||||||
				Msg("Failed to write response")
 | 
									Msg("Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return false
 | 
							return errOIDCAllowedUsers
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return true
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// validateMachine retrieves machine information if it exist
 | 
					// validateMachine retrieves machine information if it exist
 | 
				
			||||||
@ -394,7 +401,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
 | 
				
			|||||||
	writer http.ResponseWriter,
 | 
						writer http.ResponseWriter,
 | 
				
			||||||
	state string,
 | 
						state string,
 | 
				
			||||||
	claims *IDTokenClaims,
 | 
						claims *IDTokenClaims,
 | 
				
			||||||
) (*key.MachinePublic, bool) {
 | 
					) (*key.MachinePublic, bool, error) {
 | 
				
			||||||
	// retrieve machinekey from state cache
 | 
						// retrieve machinekey from state cache
 | 
				
			||||||
	machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
 | 
						machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
 | 
				
			||||||
	if !machineKeyFound {
 | 
						if !machineKeyFound {
 | 
				
			||||||
@ -410,7 +417,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
 | 
				
			|||||||
				Msg("Failed to write response")
 | 
									Msg("Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return nil, false
 | 
							return nil, false, errOIDCInvalidMachineState
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var machineKey key.MachinePublic
 | 
						var machineKey key.MachinePublic
 | 
				
			||||||
@ -423,15 +430,15 @@ func (h *Headscale) validateMachineForOIDCCallback(
 | 
				
			|||||||
			Msg("could not parse machine public key")
 | 
								Msg("could not parse machine public key")
 | 
				
			||||||
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
							writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
				
			||||||
		writer.WriteHeader(http.StatusBadRequest)
 | 
							writer.WriteHeader(http.StatusBadRequest)
 | 
				
			||||||
		_, err := writer.Write([]byte("could not parse public key"))
 | 
							_, werr := writer.Write([]byte("could not parse public key"))
 | 
				
			||||||
		if err != nil {
 | 
							if werr != nil {
 | 
				
			||||||
			log.Error().
 | 
								log.Error().
 | 
				
			||||||
				Caller().
 | 
									Caller().
 | 
				
			||||||
				Err(err).
 | 
									Err(werr).
 | 
				
			||||||
				Msg("Failed to write response")
 | 
									Msg("Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return nil, false
 | 
							return nil, false, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if !machineKeyOK {
 | 
						if !machineKeyOK {
 | 
				
			||||||
@ -446,7 +453,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
 | 
				
			|||||||
				Msg("Failed to write response")
 | 
									Msg("Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return nil, false
 | 
							return nil, false, errOIDCMachineKeyMissing
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// retrieve machine information if it exist
 | 
						// retrieve machine information if it exist
 | 
				
			||||||
@ -469,7 +476,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
 | 
				
			|||||||
				Msg("Failed to refresh machine")
 | 
									Msg("Failed to refresh machine")
 | 
				
			||||||
			http.Error(writer, "Failed to refresh machine", http.StatusInternalServerError)
 | 
								http.Error(writer, "Failed to refresh machine", http.StatusInternalServerError)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return nil, false
 | 
								return nil, true, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		var content bytes.Buffer
 | 
							var content bytes.Buffer
 | 
				
			||||||
@ -485,15 +492,15 @@ func (h *Headscale) validateMachineForOIDCCallback(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
								writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
				
			||||||
			writer.WriteHeader(http.StatusInternalServerError)
 | 
								writer.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
			_, err := writer.Write([]byte("Could not render OIDC callback template"))
 | 
								_, werr := writer.Write([]byte("Could not render OIDC callback template"))
 | 
				
			||||||
			if err != nil {
 | 
								if werr != nil {
 | 
				
			||||||
				log.Error().
 | 
									log.Error().
 | 
				
			||||||
					Caller().
 | 
										Caller().
 | 
				
			||||||
					Err(err).
 | 
										Err(werr).
 | 
				
			||||||
					Msg("Failed to write response")
 | 
										Msg("Failed to write response")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return nil, false
 | 
								return nil, true, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		writer.Header().Set("Content-Type", "text/html; charset=utf-8")
 | 
							writer.Header().Set("Content-Type", "text/html; charset=utf-8")
 | 
				
			||||||
@ -506,17 +513,17 @@ func (h *Headscale) validateMachineForOIDCCallback(
 | 
				
			|||||||
				Msg("Failed to write response")
 | 
									Msg("Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return nil, false
 | 
							return nil, true, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &machineKey, true
 | 
						return &machineKey, false, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getNamespaceName(
 | 
					func getNamespaceName(
 | 
				
			||||||
	writer http.ResponseWriter,
 | 
						writer http.ResponseWriter,
 | 
				
			||||||
	claims *IDTokenClaims,
 | 
						claims *IDTokenClaims,
 | 
				
			||||||
	stripEmaildomain bool,
 | 
						stripEmaildomain bool,
 | 
				
			||||||
) (string, bool) {
 | 
					) (string, error) {
 | 
				
			||||||
	namespaceName, err := NormalizeToFQDNRules(
 | 
						namespaceName, err := NormalizeToFQDNRules(
 | 
				
			||||||
		claims.Email,
 | 
							claims.Email,
 | 
				
			||||||
		stripEmaildomain,
 | 
							stripEmaildomain,
 | 
				
			||||||
@ -525,24 +532,24 @@ func getNamespaceName(
 | 
				
			|||||||
		log.Error().Err(err).Caller().Msgf("couldn't normalize email")
 | 
							log.Error().Err(err).Caller().Msgf("couldn't normalize email")
 | 
				
			||||||
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
							writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
				
			||||||
		writer.WriteHeader(http.StatusInternalServerError)
 | 
							writer.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
		_, err := writer.Write([]byte("couldn't normalize email"))
 | 
							_, werr := writer.Write([]byte("couldn't normalize email"))
 | 
				
			||||||
		if err != nil {
 | 
							if werr != nil {
 | 
				
			||||||
			log.Error().
 | 
								log.Error().
 | 
				
			||||||
				Caller().
 | 
									Caller().
 | 
				
			||||||
				Err(err).
 | 
									Err(werr).
 | 
				
			||||||
				Msg("Failed to write response")
 | 
									Msg("Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return "", false
 | 
							return "", err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return namespaceName, true
 | 
						return namespaceName, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback(
 | 
					func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback(
 | 
				
			||||||
	writer http.ResponseWriter,
 | 
						writer http.ResponseWriter,
 | 
				
			||||||
	namespaceName string,
 | 
						namespaceName string,
 | 
				
			||||||
) (*Namespace, bool) {
 | 
					) (*Namespace, error) {
 | 
				
			||||||
	namespace, err := h.GetNamespace(namespaceName)
 | 
						namespace, err := h.GetNamespace(namespaceName)
 | 
				
			||||||
	if errors.Is(err, errNamespaceNotFound) {
 | 
						if errors.Is(err, errNamespaceNotFound) {
 | 
				
			||||||
		namespace, err = h.CreateNamespace(namespaceName)
 | 
							namespace, err = h.CreateNamespace(namespaceName)
 | 
				
			||||||
@ -554,15 +561,15 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback(
 | 
				
			|||||||
				Msgf("could not create new namespace '%s'", namespaceName)
 | 
									Msgf("could not create new namespace '%s'", namespaceName)
 | 
				
			||||||
			writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
								writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
				
			||||||
			writer.WriteHeader(http.StatusInternalServerError)
 | 
								writer.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
			_, err := writer.Write([]byte("could not create namespace"))
 | 
								_, werr := writer.Write([]byte("could not create namespace"))
 | 
				
			||||||
			if err != nil {
 | 
								if werr != nil {
 | 
				
			||||||
				log.Error().
 | 
									log.Error().
 | 
				
			||||||
					Caller().
 | 
										Caller().
 | 
				
			||||||
					Err(err).
 | 
										Err(werr).
 | 
				
			||||||
					Msg("Failed to write response")
 | 
										Msg("Failed to write response")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			return nil, false
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	} else if err != nil {
 | 
						} else if err != nil {
 | 
				
			||||||
		log.Error().
 | 
							log.Error().
 | 
				
			||||||
@ -572,25 +579,25 @@ func (h *Headscale) findOrCreateNewNamespaceForOIDCCallback(
 | 
				
			|||||||
			Msg("could not find or create namespace")
 | 
								Msg("could not find or create namespace")
 | 
				
			||||||
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
							writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
				
			||||||
		writer.WriteHeader(http.StatusInternalServerError)
 | 
							writer.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
		_, err := writer.Write([]byte("could not find or create namespace"))
 | 
							_, werr := writer.Write([]byte("could not find or create namespace"))
 | 
				
			||||||
		if err != nil {
 | 
							if werr != nil {
 | 
				
			||||||
			log.Error().
 | 
								log.Error().
 | 
				
			||||||
				Caller().
 | 
									Caller().
 | 
				
			||||||
				Err(err).
 | 
									Err(werr).
 | 
				
			||||||
				Msg("Failed to write response")
 | 
									Msg("Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return nil, false
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return namespace, true
 | 
						return namespace, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *Headscale) registerMachineForOIDCCallback(
 | 
					func (h *Headscale) registerMachineForOIDCCallback(
 | 
				
			||||||
	writer http.ResponseWriter,
 | 
						writer http.ResponseWriter,
 | 
				
			||||||
	namespace *Namespace,
 | 
						namespace *Namespace,
 | 
				
			||||||
	machineKey *key.MachinePublic,
 | 
						machineKey *key.MachinePublic,
 | 
				
			||||||
) bool {
 | 
					) error {
 | 
				
			||||||
	machineKeyStr := MachinePublicKeyStripPrefix(*machineKey)
 | 
						machineKeyStr := MachinePublicKeyStripPrefix(*machineKey)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if _, err := h.RegisterMachineFromAuthCallback(
 | 
						if _, err := h.RegisterMachineFromAuthCallback(
 | 
				
			||||||
@ -604,24 +611,24 @@ func (h *Headscale) registerMachineForOIDCCallback(
 | 
				
			|||||||
			Msg("could not register machine")
 | 
								Msg("could not register machine")
 | 
				
			||||||
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
							writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
				
			||||||
		writer.WriteHeader(http.StatusInternalServerError)
 | 
							writer.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
		_, err := writer.Write([]byte("could not register machine"))
 | 
							_, werr := writer.Write([]byte("could not register machine"))
 | 
				
			||||||
		if err != nil {
 | 
							if werr != nil {
 | 
				
			||||||
			log.Error().
 | 
								log.Error().
 | 
				
			||||||
				Caller().
 | 
									Caller().
 | 
				
			||||||
				Err(err).
 | 
									Err(werr).
 | 
				
			||||||
				Msg("Failed to write response")
 | 
									Msg("Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return false
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return true
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func renderOIDCCallbackTemplate(
 | 
					func renderOIDCCallbackTemplate(
 | 
				
			||||||
	writer http.ResponseWriter,
 | 
						writer http.ResponseWriter,
 | 
				
			||||||
	claims *IDTokenClaims,
 | 
						claims *IDTokenClaims,
 | 
				
			||||||
) (*bytes.Buffer, bool) {
 | 
					) (*bytes.Buffer, error) {
 | 
				
			||||||
	var content bytes.Buffer
 | 
						var content bytes.Buffer
 | 
				
			||||||
	if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
 | 
						if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
 | 
				
			||||||
		User: claims.Email,
 | 
							User: claims.Email,
 | 
				
			||||||
@ -635,16 +642,16 @@ func renderOIDCCallbackTemplate(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
							writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
 | 
				
			||||||
		writer.WriteHeader(http.StatusInternalServerError)
 | 
							writer.WriteHeader(http.StatusInternalServerError)
 | 
				
			||||||
		_, err := writer.Write([]byte("Could not render OIDC callback template"))
 | 
							_, werr := writer.Write([]byte("Could not render OIDC callback template"))
 | 
				
			||||||
		if err != nil {
 | 
							if werr != nil {
 | 
				
			||||||
			log.Error().
 | 
								log.Error().
 | 
				
			||||||
				Caller().
 | 
									Caller().
 | 
				
			||||||
				Err(err).
 | 
									Err(werr).
 | 
				
			||||||
				Msg("Failed to write response")
 | 
									Msg("Failed to write response")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		return nil, false
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return &content, true
 | 
						return &content, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user