tailscale/cmd/k8s-operator/tsclient_test.go
Tom Proctor d4c5b278b3 cmd/k8s-operator: support workload identity federation
The feature is currently in private alpha, so requires a tailnet feature
flag. Initially focuses on supporting the operator's own auth, because the
operator is the only device we maintain that uses static long-lived
credentials. All other operator-created devices use single-use auth keys.

Testing steps:

* Create a cluster with an API server accessible over public internet
* kubectl get --raw /.well-known/openid-configuration | jq '.issuer'
* Create a federated OAuth client in the Tailscale admin console with:
  * The issuer from the previous step
  * Subject claim `system:serviceaccount:tailscale:operator`
  * Write scopes services, devices:core, auth_keys
  * Tag tag:k8s-operator
* Allow the Tailscale control plane to get the public portion of
  the ServiceAccount token signing key without authentication:
  * kubectl create clusterrolebinding oidc-discovery \
      --clusterrole=system:service-account-issuer-discovery \
      --group=system:unauthenticated
* helm install --set oauth.clientId=... --set oauth.audience=...

Updates #17457

Change-Id: Ib29c85ba97b093c70b002f4f41793ffc02e6c6e9
Signed-off-by: Tom Proctor <tomhjp@users.noreply.github.com>
2025-11-07 14:24:24 +00:00

136 lines
3.9 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !plan9
package main
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"go.uber.org/zap"
"golang.org/x/oauth2"
)
// TestNewStaticClient verifies that a client built from client ID and client
// secret files authenticates against the /api/v2/oauth/token endpoint and
// attaches the resulting bearer token to outgoing requests.
func TestNewStaticClient(t *testing.T) {
	const (
		clientIDFile     = "client-id"
		clientSecretFile = "client-secret"
	)

	dir := t.TempDir()
	// writeCred stores value in a file called name inside the temp dir and
	// returns the full path, aborting the test if the write fails.
	writeCred := func(name, value string) string {
		p := filepath.Join(dir, name)
		if err := os.WriteFile(p, []byte(value), 0600); err != nil {
			t.Fatalf("error writing test file %q: %v", p, err)
		}
		return p
	}
	idPath := writeCred(clientIDFile, "test-client-id")
	secretPath := writeCred(clientSecretFile, "test-client-secret")

	api := testAPI(t, 3600)
	client, err := newTSClient(zap.NewNop().Sugar(), "", idPath, secretPath, api.URL)
	if err != nil {
		t.Fatalf("error creating Tailscale client: %v", err)
	}

	// The test server's root path echoes the Authorization header it received.
	resp, err := client.HTTPClient.Get(api.URL)
	if err != nil {
		t.Fatalf("error making test API call: %v", err)
	}
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("error reading response body: %v", err)
	}

	want := "Bearer " + testToken("/api/v2/oauth/token", "test-client-id", "test-client-secret", "")
	if got := string(raw); got != want {
		t.Errorf("got %q; want %q", got, want)
	}
}
// TestNewWorkloadIdentityClient verifies that a client created with only a
// client ID (no credential file paths) obtains its access token from the
// /api/v2/oauth/token-exchange endpoint using the JWT read from disk, and that
// a rewritten JWT file is picked up on the next request.
func TestNewWorkloadIdentityClient(t *testing.T) {
	// 5 seconds is within expiryDelta leeway, so the access token will
	// immediately be considered expired and get refreshed on each access.
	srv := testAPI(t, 5)
	cl, err := newTSClient(zap.NewNop().Sugar(), "test-client-id", "", "", srv.URL)
	if err != nil {
		t.Fatalf("error creating Tailscale client: %v", err)
	}

	// Modify the path where the JWT will be read from.
	transport, ok := cl.HTTPClient.Transport.(*oauth2.Transport)
	if !ok {
		t.Fatalf("expected oauth2.Transport, got %T", cl.HTTPClient.Transport)
	}
	src, ok := transport.Source.(*jwtTokenSource)
	if !ok {
		t.Fatalf("expected jwtTokenSource, got %T", transport.Source)
	}
	jwtPath := filepath.Join(t.TempDir(), "token")
	src.jwtPath = jwtPath

	for _, jwt := range []string{"test-jwt", "updated-test-jwt"} {
		if err := os.WriteFile(jwtPath, []byte(jwt), 0600); err != nil {
			t.Fatalf("error writing test file %q: %v", jwtPath, err)
		}

		resp, err := cl.HTTPClient.Get(srv.URL)
		if err != nil {
			t.Fatalf("error making test API call: %v", err)
		}
		got, err := io.ReadAll(resp.Body)
		// Close the body in this iteration rather than deferring: a defer
		// inside the loop would keep every response open until the test
		// function returns.
		resp.Body.Close()
		if err != nil {
			t.Fatalf("error reading response body: %v", err)
		}

		want := "Bearer " + testToken("/api/v2/oauth/token-exchange", "test-client-id", "", jwt)
		if string(got) != want {
			t.Errorf("got %q; want %q", got, want)
		}
	}
}
// testAPI starts an httptest server standing in for the API the client under
// test talks to, and registers the server's shutdown with t.Cleanup.
//
// Requests to /api/v2/oauth/token and /api/v2/oauth/token-exchange must carry
// HTTP basic auth. The JSON response is a Bearer token whose access_token is
// testToken(path, basic-auth user, basic-auth password, "jwt" form value) and
// whose expires_in is expirationSeconds. Requests to / get their Authorization
// header echoed back as the body. Any other path gets a 404.
func testAPI(t *testing.T, expirationSeconds int) *httptest.Server {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on the server's goroutines, not the goroutine
		// running the test, so failures are reported with t.Error/t.Errorf:
		// t.Fatal must only be called from the test goroutine.
		t.Logf("test server got request: %s %s", r.Method, r.URL.Path)
		switch r.URL.Path {
		case "/api/v2/oauth/token", "/api/v2/oauth/token-exchange":
			id, secret, ok := r.BasicAuth()
			if !ok {
				t.Error("missing or invalid basic auth")
				http.Error(w, "missing or invalid basic auth", http.StatusUnauthorized)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(map[string]any{
				"access_token": testToken(r.URL.Path, id, secret, r.FormValue("jwt")),
				"token_type":   "Bearer",
				"expires_in":   expirationSeconds,
			}); err != nil {
				t.Errorf("error writing response: %v", err)
			}
		case "/":
			// Echo back the authz header for test assertions.
			if _, err := w.Write([]byte(r.Header.Get("Authorization"))); err != nil {
				t.Errorf("error writing response: %v", err)
			}
		default:
			w.WriteHeader(http.StatusNotFound)
		}
	}))
	t.Cleanup(srv.Close)
	return srv
}
// testToken builds the fake access token used by the test server and by the
// tests' expectations: the four inputs joined with "|" in the order
// path, id, secret, jwt.
func testToken(path, id, secret, jwt string) string {
	const layout = "%s|%s|%s|%s"
	return fmt.Sprintf(layout, path, id, secret, jwt)
}