mirror of
				https://github.com/juanfont/headscale.git
				synced 2025-10-31 08:01:34 +01:00 
			
		
		
		
	handle register auth errors (#2435)
* handle register auth errors This commit handles register auth errors as the Tailscale clients expect. It returns the error as part of a tailcfg.RegisterResponse and not as a http error. In addition it fixes a nil pointer panic triggered by not handling the errors as part of this chain. Fixes #2434 Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> * changelog Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com> --------- Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
		
							parent
							
								
									b220fb7d51
								
							
						
					
					
						commit
						bcff0eaae7
					
				| @ -14,6 +14,13 @@ | ||||
|   - View of config, policy, filter, ssh policy per node, connected nodes and | ||||
|     DERPmap | ||||
| 
 | ||||
| ## 0.25.1 (2025-02-18) | ||||
| 
 | ||||
| ### Changes | ||||
| 
 | ||||
| - Fix issue where registration errors are sent correctly | ||||
|   [#2435](https://github.com/juanfont/headscale/pull/2435) | ||||
| 
 | ||||
| ## 0.25.0 (2025-02-11) | ||||
| 
 | ||||
| ### BREAKING | ||||
|  | ||||
| @ -230,6 +230,10 @@ func (ns *noiseServer) NoisePollNetMapHandler( | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func regErr(err error) *tailcfg.RegisterResponse { | ||||
| 	return &tailcfg.RegisterResponse{Error: err.Error()} | ||||
| } | ||||
| 
 | ||||
| // NoiseRegistrationHandler handles the actual registration process of a node. | ||||
| func (ns *noiseServer) NoiseRegistrationHandler( | ||||
| 	writer http.ResponseWriter, | ||||
| @ -241,52 +245,47 @@ func (ns *noiseServer) NoiseRegistrationHandler( | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	registerRequest, registerResponse, err := func() (*tailcfg.RegisterRequest, []byte, error) { | ||||
| 	registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { | ||||
| 		var resp *tailcfg.RegisterResponse | ||||
| 		body, err := io.ReadAll(req.Body) | ||||
| 		if err != nil { | ||||
| 			return nil, nil, err | ||||
| 			return &tailcfg.RegisterRequest{}, regErr(err) | ||||
| 		} | ||||
| 		var registerRequest tailcfg.RegisterRequest | ||||
| 		if err := json.Unmarshal(body, ®isterRequest); err != nil { | ||||
| 			return nil, nil, err | ||||
| 		var regReq tailcfg.RegisterRequest | ||||
| 		if err := json.Unmarshal(body, ®Req); err != nil { | ||||
| 			return ®Req, regErr(err) | ||||
| 		} | ||||
| 
 | ||||
| 		ns.nodeKey = registerRequest.NodeKey | ||||
| 		ns.nodeKey = regReq.NodeKey | ||||
| 
 | ||||
| 		resp, err := ns.headscale.handleRegister(req.Context(), registerRequest, ns.conn.Peer()) | ||||
| 		// TODO(kradalby): Here we could have two error types, one that is surfaced to the client | ||||
| 		// and one that returns 500. | ||||
| 		resp, err = ns.headscale.handleRegister(req.Context(), regReq, ns.conn.Peer()) | ||||
| 		if err != nil { | ||||
| 			return nil, nil, err | ||||
| 			var httpErr HTTPError | ||||
| 			if errors.As(err, &httpErr) { | ||||
| 				resp = &tailcfg.RegisterResponse{ | ||||
| 					Error: httpErr.Msg, | ||||
| 				} | ||||
| 				return ®Req, resp | ||||
| 			} else { | ||||
| 			} | ||||
| 			return ®Req, regErr(err) | ||||
| 		} | ||||
| 
 | ||||
| 		respBody, err := json.Marshal(resp) | ||||
| 		if err != nil { | ||||
| 			return nil, nil, err | ||||
| 		} | ||||
| 
 | ||||
| 		return ®isterRequest, respBody, nil | ||||
| 		return ®Req, resp | ||||
| 	}() | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| 			Err(err). | ||||
| 			Msg("Error handling registration") | ||||
| 		http.Error(writer, "Internal server error", http.StatusInternalServerError) | ||||
| 	} | ||||
| 
 | ||||
| 	// Reject unsupported versions | ||||
| 	if rejectUnsupported(writer, registerRequest.Version, ns.machineKey, registerRequest.NodeKey) { | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	respBody, err := json.Marshal(registerResponse) | ||||
| 	if err != nil { | ||||
| 		httpError(writer, err) | ||||
| 		return | ||||
| 	} | ||||
| 
 | ||||
| 	writer.Header().Set("Content-Type", "application/json; charset=utf-8") | ||||
| 	writer.WriteHeader(http.StatusOK) | ||||
| 	_, err = writer.Write(registerResponse) | ||||
| 	if err != nil { | ||||
| 		log.Error(). | ||||
| 			Caller(). | ||||
| 			Err(err). | ||||
| 			Msg("Failed to write response") | ||||
| 	} | ||||
| 	writer.Write(respBody) | ||||
| } | ||||
|  | ||||
| @ -228,3 +228,99 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { | ||||
| 		assert.Equal(t, "user1@test.no", status.User[status.Self.UserID].LoginName) | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { | ||||
| 	IntegrationSkip(t) | ||||
| 	t.Parallel() | ||||
| 
 | ||||
| 	for _, https := range []bool{true, false} { | ||||
| 		t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { | ||||
| 			scenario, err := NewScenario(dockertestMaxWait()) | ||||
| 			assertNoErr(t, err) | ||||
| 			defer scenario.ShutdownAssertNoPanics(t) | ||||
| 
 | ||||
| 			spec := map[string]int{ | ||||
| 				"user1": len(MustTestVersions), | ||||
| 				"user2": len(MustTestVersions), | ||||
| 			} | ||||
| 
 | ||||
| 			opts := []hsic.Option{hsic.WithTestName("pingallbyip")} | ||||
| 			if https { | ||||
| 				opts = append(opts, []hsic.Option{ | ||||
| 					hsic.WithTLS(), | ||||
| 				}...) | ||||
| 			} | ||||
| 
 | ||||
| 			err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, opts...) | ||||
| 			assertNoErrHeadscaleEnv(t, err) | ||||
| 
 | ||||
| 			allClients, err := scenario.ListTailscaleClients() | ||||
| 			assertNoErrListClients(t, err) | ||||
| 
 | ||||
| 			err = scenario.WaitForTailscaleSync() | ||||
| 			assertNoErrSync(t, err) | ||||
| 
 | ||||
| 			// assertClientsState(t, allClients) | ||||
| 
 | ||||
| 			clientIPs := make(map[TailscaleClient][]netip.Addr) | ||||
| 			for _, client := range allClients { | ||||
| 				ips, err := client.IPs() | ||||
| 				if err != nil { | ||||
| 					t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) | ||||
| 				} | ||||
| 				clientIPs[client] = ips | ||||
| 			} | ||||
| 
 | ||||
| 			headscale, err := scenario.Headscale() | ||||
| 			assertNoErrGetHeadscale(t, err) | ||||
| 
 | ||||
| 			listNodes, err := headscale.ListNodes() | ||||
| 			assert.Equal(t, len(listNodes), len(allClients)) | ||||
| 			nodeCountBeforeLogout := len(listNodes) | ||||
| 			t.Logf("node count before logout: %d", nodeCountBeforeLogout) | ||||
| 
 | ||||
| 			for _, client := range allClients { | ||||
| 				err := client.Logout() | ||||
| 				if err != nil { | ||||
| 					t.Fatalf("failed to logout client %s: %s", client.Hostname(), err) | ||||
| 				} | ||||
| 			} | ||||
| 
 | ||||
| 			err = scenario.WaitForTailscaleLogout() | ||||
| 			assertNoErrLogout(t, err) | ||||
| 
 | ||||
| 			t.Logf("all clients logged out") | ||||
| 
 | ||||
| 			// if the server is not running with HTTPS, we have to wait a bit before | ||||
| 			// reconnection as the newest Tailscale client has a measure that will only | ||||
| 			// reconnect over HTTPS if they saw a noise connection previously. | ||||
| 			// https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38 | ||||
| 			// https://github.com/juanfont/headscale/issues/2164 | ||||
| 			if !https { | ||||
| 				time.Sleep(5 * time.Minute) | ||||
| 			} | ||||
| 
 | ||||
| 			for userName := range spec { | ||||
| 				key, err := scenario.CreatePreAuthKey(userName, true, false) | ||||
| 				if err != nil { | ||||
| 					t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) | ||||
| 				} | ||||
| 
 | ||||
| 				// Expire the key so it can't be used | ||||
| 				_, err = headscale.Execute( | ||||
| 					[]string{ | ||||
| 						"headscale", | ||||
| 						"preauthkeys", | ||||
| 						"--user", | ||||
| 						userName, | ||||
| 						"expire", | ||||
| 						key.Key, | ||||
| 					}) | ||||
| 				assertNoErr(t, err) | ||||
| 
 | ||||
| 				err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) | ||||
| 				assert.ErrorContains(t, err, "authkey expired") | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user