mirror of
https://github.com/siderolabs/omni.git
synced 2025-08-09 02:56:59 +02:00
Using so-called phantom types, we can use the types themselves as keys directly without losing performance. You no longer need to remember which type was attached to the value you put into the context, and you can access all fields directly. Part of #37 Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
286 lines
7.0 KiB
Go
286 lines
7.0 KiB
Go
// Copyright (c) 2024 Sidero Labs, Inc.
|
|
//
|
|
// Use of this software is governed by the Business Source License
|
|
// included in the LICENSE file.
|
|
|
|
package k8sproxy_test
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/cosi-project/runtime/pkg/resource"
|
|
"github.com/golang-jwt/jwt/v4"
|
|
"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/zap/zaptest"
|
|
"k8s.io/client-go/transport"
|
|
|
|
"github.com/siderolabs/omni/internal/backend/k8sproxy"
|
|
"github.com/siderolabs/omni/internal/pkg/ctxstore"
|
|
)
|
|
|
|
var mockClusterUUIDResolver = func(_ context.Context, clusterID resource.ID) (string, error) {
|
|
upper := strings.ToUpper(clusterID)
|
|
if upper == clusterID {
|
|
return "", fmt.Errorf("invalid test cluster ID - does not contain lowercase: %s", clusterID)
|
|
}
|
|
|
|
return upper, nil
|
|
}
|
|
|
|
type mockClaims struct {
|
|
ExpiresAt *jwt.NumericDate `json:"exp,omitempty"`
|
|
Cluster string `json:"cluster,omitempty"`
|
|
ClusterUUID string `json:"cluster_uuid,omitempty"`
|
|
Subject string `json:"sub,omitempty"`
|
|
Groups []string `json:"groups,omitempty"`
|
|
}
|
|
|
|
func (c *mockClaims) Valid() error {
|
|
return nil
|
|
}
|
|
|
|
func TestAuthorize(t *testing.T) {
|
|
reqCh := make(chan *http.Request, 1)
|
|
|
|
coreHandler := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
|
reqCh <- r.Clone(r.Context())
|
|
})
|
|
|
|
key1, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
require.NoError(t, err)
|
|
|
|
key2, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
require.NoError(t, err)
|
|
|
|
keyFunc := func(_ context.Context, keyID string) (any, error) {
|
|
switch keyID {
|
|
case "1":
|
|
return &key1.PublicKey, nil
|
|
case "2":
|
|
return &key2.PublicKey, nil
|
|
default:
|
|
return nil, errors.New("unknown key")
|
|
}
|
|
}
|
|
|
|
ts := httptest.NewServer(k8sproxy.AuthorizeRequest(coreHandler, keyFunc, mockClusterUUIDResolver))
|
|
defer ts.Close()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
t.Cleanup(cancel)
|
|
|
|
logger := zaptest.NewLogger(t)
|
|
|
|
ctx = ctxzap.ToContext(ctx, logger)
|
|
|
|
type testCase struct { //nolint:govet
|
|
name string
|
|
|
|
claims mockClaims
|
|
kid string
|
|
signingKey *rsa.PrivateKey
|
|
extraHeaders map[string]string
|
|
|
|
expectedCode int
|
|
|
|
expectedImpersonateUser string
|
|
expectedImpersonateGroups []string
|
|
expectedCluster string
|
|
}
|
|
|
|
testCases := []testCase{
|
|
{
|
|
name: "valid key - legacy with cluster name",
|
|
claims: mockClaims{
|
|
Cluster: "cluster1",
|
|
Subject: "user1",
|
|
Groups: []string{"group1", "group2"},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
kid: "1",
|
|
signingKey: key1,
|
|
|
|
expectedCode: http.StatusOK,
|
|
|
|
expectedImpersonateUser: "user1",
|
|
expectedImpersonateGroups: []string{"group1", "group2"},
|
|
expectedCluster: "cluster1",
|
|
},
|
|
{
|
|
name: "valid key1",
|
|
claims: mockClaims{
|
|
Cluster: "cluster1",
|
|
ClusterUUID: "CLUSTER1",
|
|
Subject: "user1",
|
|
Groups: []string{"group1", "group2"},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
kid: "1",
|
|
signingKey: key1,
|
|
|
|
expectedCode: http.StatusOK,
|
|
|
|
expectedImpersonateUser: "user1",
|
|
expectedImpersonateGroups: []string{"group1", "group2"},
|
|
expectedCluster: "cluster1",
|
|
},
|
|
{
|
|
name: "valid key2 + extra headers",
|
|
claims: mockClaims{
|
|
Cluster: "cluster2",
|
|
ClusterUUID: "CLUSTER2",
|
|
Subject: "user2",
|
|
Groups: []string{"group2"},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
kid: "2",
|
|
signingKey: key2,
|
|
extraHeaders: map[string]string{
|
|
transport.ImpersonateUserHeader: "foo",
|
|
transport.ImpersonateGroupHeader: "bar",
|
|
},
|
|
|
|
expectedCode: http.StatusOK,
|
|
|
|
expectedImpersonateUser: "user2",
|
|
expectedImpersonateGroups: []string{"group2"},
|
|
expectedCluster: "cluster2",
|
|
},
|
|
{
|
|
name: "cluster-uuid mismatch",
|
|
claims: mockClaims{
|
|
Cluster: "cluster-1",
|
|
ClusterUUID: "CLUSTER2",
|
|
Subject: "user2",
|
|
Groups: []string{"group2"},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
kid: "2",
|
|
signingKey: key1,
|
|
|
|
expectedCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "kid mismatch",
|
|
claims: mockClaims{
|
|
Cluster: "cluster2",
|
|
ClusterUUID: "CLUSTER2",
|
|
Subject: "user2",
|
|
Groups: []string{"group2"},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
kid: "2",
|
|
signingKey: key1,
|
|
|
|
expectedCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "wrong kid",
|
|
claims: mockClaims{
|
|
Cluster: "cluster2",
|
|
ClusterUUID: "CLUSTER2",
|
|
Subject: "user2",
|
|
Groups: []string{"group2"},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
kid: "3",
|
|
signingKey: key1,
|
|
|
|
expectedCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "malformed claims 1",
|
|
claims: mockClaims{
|
|
Subject: "user2",
|
|
Groups: []string{"group2"},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
kid: "1",
|
|
signingKey: key1,
|
|
|
|
expectedCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "malformed claims 2",
|
|
claims: mockClaims{
|
|
Cluster: "cluster2",
|
|
ClusterUUID: "CLUSTER2",
|
|
Groups: []string{"group2"},
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
kid: "1",
|
|
signingKey: key1,
|
|
|
|
expectedCode: http.StatusUnauthorized,
|
|
},
|
|
{
|
|
name: "malformed claims 2",
|
|
claims: mockClaims{
|
|
Cluster: "cluster2",
|
|
ClusterUUID: "CLUSTER2",
|
|
Subject: "foo",
|
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour)),
|
|
},
|
|
kid: "1",
|
|
signingKey: key1,
|
|
|
|
expectedCode: http.StatusUnauthorized,
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, nil)
|
|
require.NoError(t, err)
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, &tc.claims)
|
|
token.Header["kid"] = tc.kid
|
|
|
|
auth, err := token.SignedString(tc.signingKey)
|
|
require.NoError(t, err)
|
|
|
|
req.Header["Authorization"] = []string{"Bearer " + auth}
|
|
|
|
for k, v := range tc.extraHeaders {
|
|
req.Header.Set(k, v)
|
|
}
|
|
|
|
resp, err := http.DefaultClient.Do(req) //nolint:bodyclose
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, tc.expectedCode, resp.StatusCode)
|
|
|
|
if tc.expectedCode != http.StatusOK {
|
|
return
|
|
}
|
|
|
|
var receivedReq *http.Request
|
|
|
|
select {
|
|
case receivedReq = <-reqCh:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout")
|
|
}
|
|
|
|
assert.Equal(t, []string{tc.expectedImpersonateUser}, receivedReq.Header.Values(transport.ImpersonateUserHeader))
|
|
assert.Equal(t, tc.expectedImpersonateGroups, receivedReq.Header.Values(transport.ImpersonateGroupHeader))
|
|
assert.Nil(t, receivedReq.Header.Values("Authorization"))
|
|
|
|
v, ok := ctxstore.Value[k8sproxy.ClusterContextKey](receivedReq.Context()) //nolint:contextcheck
|
|
assert.True(t, ok)
|
|
assert.Equal(t, tc.expectedCluster, v.ClusterName)
|
|
})
|
|
}
|
|
}
|