omni/internal/backend/k8sproxy/middleware.go
Artem Chernyshev ed946b30a6
feat: display OMNI_ENDPOINT in the service account creation UI
Fixes: https://github.com/siderolabs/omni/issues/858

Signed-off-by: Artem Chernyshev <artem.chernyshev@talos-systems.com>
2025-01-29 15:27:36 +03:00

147 lines
4.2 KiB
Go

// Copyright (c) 2025 Sidero Labs, Inc.
//
// Use of this software is governed by the Business Source License
// included in the LICENSE file.
package k8sproxy
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"github.com/cosi-project/runtime/pkg/resource"
"github.com/golang-jwt/jwt/v4"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
"go.uber.org/zap"
"k8s.io/client-go/transport"
"github.com/siderolabs/omni/internal/backend/runtime/omni/audit"
"github.com/siderolabs/omni/internal/pkg/ctxstore"
)
const (
	// authorizationHeader names the HTTP request header from which the bearer JWT is read.
	authorizationHeader = "Authorization"
)
type (
	// KeyProvider looks up the public key registered under keyID. AuthorizeRequest hands the
	// returned value to the JWT library as the key used to verify the token signature.
	KeyProvider func(ctx context.Context, keyID string) (any, error)

	// ClusterUUIDResolver maps a cluster ID (the cluster name carried in the token) to the
	// UUID of that cluster.
	ClusterUUIDResolver func(ctx context.Context, clusterID resource.ID) (string, error)
)
// AuthorizeRequest checks for valid token in the request.
func AuthorizeRequest(next http.Handler, keyFunc KeyProvider, clusterUUIDResolver ClusterUUIDResolver) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
authorization := req.Header.Get(authorizationHeader)
bearer, tok, ok := strings.Cut(authorization, " ")
if !ok || bearer != "Bearer" {
ctxzap.Error(ctx, "invalid authorization header")
w.WriteHeader(http.StatusUnauthorized)
return
}
token, err := jwt.ParseWithClaims(tok, &claims{}, func(token *jwt.Token) (any, error) {
if token.Method.Alg() != jwt.SigningMethodRS256.Alg() {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
keyID, ok := token.Header["kid"].(string)
if !ok {
ctxzap.Error(ctx, "invalid token header", zap.Any("header", token.Header))
return nil, errors.New("invalid token header")
}
return keyFunc(ctx, keyID)
})
if err != nil {
ctxzap.Error(ctx, "failed to validate JWT token", zap.Error(err))
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(err.Error())) //nolint:errcheck
return
}
claims, _ := token.Claims.(*claims) //nolint:errcheck
clusterName := claims.Cluster
clusterUUID := claims.ClusterUUID
grpc_ctxtags.Extract(req.Context()).
Set("cluster", clusterName).
Set("cluster_uuid", clusterUUID).
Set("impersonate.user", claims.Subject).
Set("impersonate.groups", claims.Groups)
if clusterName == "" {
ctxzap.Error(ctx, "cluster name is empty")
w.WriteHeader(http.StatusUnauthorized)
return
}
// Allow JWTs without cluster UUID for backwards compatibility - use their "cluster" claim as the target cluster.
// If this is a newer JWT with the "cluster_uuid" claim, get the matching cluster uuid and validate it against the "cluster_uuid" claim.
if clusterUUID != "" {
resolvedClusterUUID, err := clusterUUIDResolver(ctx, clusterName)
if err != nil {
ctxzap.Error(ctx, "failed to resolve cluster UUID", zap.Error(err))
w.WriteHeader(http.StatusUnauthorized)
return
}
if resolvedClusterUUID != clusterUUID {
ctxzap.Error(ctx, "cluster UUID does not match cluster name")
w.WriteHeader(http.StatusUnauthorized)
return
}
}
// clone the request before modifying it
req = req.WithContext(ctxstore.WithValue(ctx, clusterContextKey{ClusterName: clusterName}))
// clean all headers which are going to be overridden
req.Header.Del(authorizationHeader)
req.Header.Del(transport.ImpersonateUserHeader)
req.Header.Del(transport.ImpersonateGroupHeader)
req.Header.Add(transport.ImpersonateUserHeader, claims.Subject)
//nolint:contextcheck
req = req.WithContext(ctxstore.WithValue(
req.Context(),
&audit.Data{
K8SAccess: &audit.K8SAccess{
FullMethodName: req.Method + " " + req.URL.Path,
Command: req.Header.Get("Kubectl-Command"),
Session: req.Header.Get("Kubectl-Session"),
ClusterName: clusterName,
ClusterUUID: clusterUUID,
},
Session: audit.Session{
UserAgent: req.Header.Get("User-Agent"),
Email: claims.Subject,
},
},
))
for _, group := range claims.Groups {
req.Header.Add(transport.ImpersonateGroupHeader, group)
}
next.ServeHTTP(w, req)
})
}