vault/api/auth/cert/cert_test.go
Steven Clark e31d45514d
Rework certificate authentication client api (#29931)
* Rework certificate authentication api

 - Use the passed in Vault api client to perform the connection
    - This provides namespace support, retry behaviors and uses
      the existing secret parsing logic instead of re-implementing it
 - Change the cert auth role to be an optional argument
 - Allow users to use a different cert auth mount point

* Clean up test name

Co-authored-by: Amir Aslamov <amir.aslamov@hashicorp.com>

---------

Co-authored-by: Amir Aslamov <amir.aslamov@hashicorp.com>
2025-03-14 15:48:52 -04:00

407 lines
11 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package cert
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"encoding/pem"
"fmt"
"io"
"math/big"
mathrand "math/rand"
"net"
"net/http"
"os"
"path/filepath"
"regexp"
"testing"
"time"
"github.com/hashicorp/go-rootcerts"
"github.com/hashicorp/vault/api"
)
// extractMountRe captures the auth mount name from a login request path of the
// form /v1/auth/<mount>/login. Only mount names made up of lowercase ASCII
// letters match.
var extractMountRe = regexp.MustCompile(`^/v1/auth/([a-z]+)/login$`)
// TestLogin tests the Login method of the CertAuth struct with the various overrides we support.
// The values the client sends to the server influence the returned client token, so
// the test can validate that we sent what we expected to.
func TestLogin(t *testing.T) {
	certs := createCertificates(t)

	// respondError writes a plain-text error response with the given status.
	respondError := func(w http.ResponseWriter, status int, msg string) {
		w.WriteHeader(status)
		_, _ = w.Write([]byte(msg))
	}

	// Build up a server that will influence the returned token based on the request.
	ln := runTestServer(t, certs, func(w http.ResponseWriter, r *http.Request) {
		defer r.Body.Close()

		// Make sure we have a client certificate.
		if r.TLS == nil || len(r.TLS.PeerCertificates) == 0 {
			respondError(w, http.StatusBadRequest, "missing peer certificate")
			return
		}

		// Make sure we get the expected client certificate.
		if r.TLS.PeerCertificates[0].SerialNumber.Cmp(certs.parsedClientCert.SerialNumber) != 0 {
			respondError(w, http.StatusBadRequest, "incorrect peer certificate provided")
			return
		}

		// Extract the mount path from the URL we were called with. It stays
		// empty when the path does not match the login endpoint pattern, which
		// makes the token comparison below fail.
		var mount string
		if matches := extractMountRe.FindStringSubmatch(r.URL.Path); len(matches) == 2 {
			mount = matches[1]
		}

		// Parse out the role name from the request body if provided.
		rawBody, err := io.ReadAll(r.Body)
		if err != nil {
			respondError(w, http.StatusInternalServerError, "failed reading body: "+err.Error())
			return
		}
		data := make(map[string]interface{})
		if err := json.Unmarshal(rawBody, &data); err != nil {
			respondError(w, http.StatusInternalServerError, "failed parsing body: "+err.Error())
			return
		}
		roleName := "none"
		if role, ok := data["name"].(string); ok {
			roleName = role
		}

		// Build up our token that includes the mount of our cert auth and the role name.
		token := fmt.Sprintf("%s-%s-token", mount, roleName)
		w.WriteHeader(http.StatusOK)
		_, _ = fmt.Fprintf(w, `{"auth": {"client_token": "%s"}}`, token)
	})

	addr := fmt.Sprintf("https://%s", ln.Addr())

	// Each client gets its own TLS config value so neither constructor can
	// observe changes made through the other.
	newTLSConfig := func() *api.TLSConfig {
		return &api.TLSConfig{
			CACert:     certs.caCert,
			ClientCert: certs.clientCert,
			ClientKey:  certs.clientKey,
		}
	}

	defaultAuthClient, err := NewDefaultCertAuthClient(addr, newTLSConfig())
	if err != nil {
		t.Fatalf("failed to create default CertAuth: %v", err)
	}

	defConfig := api.DefaultConfig()
	defConfig.Address = addr
	baseClient, err := api.NewClient(defConfig)
	if err != nil {
		t.Fatalf("failed to create base api client: %v", err)
	}
	basedOnClient, err := NewCertAuthClient(baseClient, newTLSConfig())
	if err != nil {
		t.Fatalf("failed to create CertAuth based on another client: %v", err)
	}

	// An ordered slice (rather than a map) keeps the subtest order stable
	// from run to run.
	clients := []struct {
		name   string
		client *api.Client
	}{
		{"default", defaultAuthClient},
		{"based-on-client", basedOnClient},
	}

	tests := []struct {
		name          string
		opts          []LoginOption
		expectedToken string
	}{
		{"default", []LoginOption{}, "cert-none-token"},
		{"with mount path", []LoginOption{WithMountPath("mycert")}, "mycert-none-token"},
		{"with role", []LoginOption{WithRole("myrole")}, "cert-myrole-token"},
		{"with mount and role", []LoginOption{WithMountPath("mycert"), WithRole("myrole")}, "mycert-myrole-token"},
	}

	for _, c := range clients {
		for _, test := range tests {
			t.Run(fmt.Sprintf("%s-%s", c.name, test.name), func(t *testing.T) {
				auth, err := NewCertAuth(test.opts...)
				if err != nil {
					t.Fatalf("failed to create CertAuth: %v", err)
				}
				secret, err := auth.Login(context.Background(), c.client)
				if err != nil {
					t.Fatalf("failed to login: %v", err)
				}
				if secret == nil || secret.Auth == nil || secret.Auth.ClientToken != test.expectedToken {
					t.Fatalf("unexpected response: %v", secret)
				}
			})
		}
	}
}
// TestNewDefaultCertAuthClient tests that the NewDefaultCertAuthClient function validates its inputs properly.
func TestNewDefaultCertAuthClient(t *testing.T) {
	certs := createCertificates(t)
	addr := "https://127.0.0.1:8200"
	// certs.caCert is a single PEM file, so it is passed as CACert (as
	// TestLogin does) rather than CAPath, which is the directory option.
	tlsConfig := &api.TLSConfig{
		CACert:     certs.caCert,
		ClientCert: certs.clientCert,
		ClientKey:  certs.clientKey,
	}

	type args struct {
		address   string
		tlsConfig *api.TLSConfig
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{"missing address", args{"", tlsConfig}, true},
		{"missing config", args{addr, nil}, true},
		{"missing client cert", args{addr, &api.TLSConfig{}}, true},
		{"ok", args{addr, tlsConfig}, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := NewDefaultCertAuthClient(tt.args.address, tt.args.tlsConfig)
			if (err != nil) != tt.wantErr {
				t.Fatalf("NewDefaultCertAuthClient() error = %v, wantErr %v", err, tt.wantErr)
			}
			if tt.wantErr {
				// Nothing else to verify when an error was expected.
				return
			}
			if got == nil {
				t.Fatalf("NewDefaultCertAuthClient() returned nil client")
			}
			if got.Address() != tt.args.address {
				t.Errorf("NewDefaultCertAuthClient() returned client with address %s, expected %s", got.Address(), tt.args.address)
			}
			// The Login test will validate that our TLS config was passed in properly.
		})
	}
}
// TestNewCertAuthClient tests that the NewCertAuthClient function validates its inputs properly.
func TestNewCertAuthClient(t *testing.T) {
	certs := createCertificates(t)
	addr := "https://127.0.0.1:8200"
	// certs.caCert is a single PEM file, so it is passed as CACert (as
	// TestLogin does) rather than CAPath, which is the directory option.
	tlsConfig := &api.TLSConfig{
		CACert:     certs.caCert,
		ClientCert: certs.clientCert,
		ClientKey:  certs.clientKey,
	}

	defConfig := api.DefaultConfig()
	defConfig.Address = addr
	client, err := api.NewClient(defConfig)
	if err != nil {
		t.Fatalf("failed to create base api client: %v", err)
	}

	type args struct {
		client    *api.Client
		tlsConfig *api.TLSConfig
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{"missing client", args{nil, tlsConfig}, true},
		{"missing config", args{client, nil}, true},
		{"missing client cert", args{client, &api.TLSConfig{}}, true},
		{"ok", args{client, tlsConfig}, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := NewCertAuthClient(tt.args.client, tt.args.tlsConfig)
			if (err != nil) != tt.wantErr {
				t.Fatalf("NewCertAuthClient() error = %v, wantErr %v", err, tt.wantErr)
			}
			if tt.wantErr {
				// Nothing else to verify when an error was expected.
				return
			}
			if got == nil {
				t.Fatalf("NewCertAuthClient() returned nil client")
			}
			if got.Address() != addr {
				t.Errorf("NewCertAuthClient() returned client with address %s, expected %s", got.Address(), addr)
			}
			// The Login test will validate that our TLS config was passed in properly.
		})
	}
}
// runTestServer starts an HTTPS server on a random loopback port that requires
// and verifies client certificates against the CA in cert, and serves every
// request with fn. The server and its listener are closed through t.Cleanup.
// The returned listener exposes the address the server is bound to.
func runTestServer(t *testing.T, cert testCerts, fn http.HandlerFunc) net.Listener {
	t.Helper()

	// cert.caCert is a single PEM file, so load it through CAFile; CAPath is
	// the option for a directory of certificates.
	certPool, err := rootcerts.LoadCACerts(&rootcerts.Config{
		CAFile: cert.caCert,
	})
	if err != nil {
		t.Fatalf("failed to load CA certs: %v", err)
	}

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("failed to create listener: %v", err)
	}

	sv := &http.Server{
		TLSConfig: &tls.Config{
			ClientAuth: tls.RequireAndVerifyClientCert,
			ClientCAs:  certPool,
			RootCAs:    certPool,
		},
		Handler: fn,
	}
	t.Cleanup(func() {
		// Best-effort shutdown; close errors are not actionable at teardown.
		_ = sv.Close()
		_ = ln.Close()
	})

	go func() {
		// ServeTLS returns once the server is closed by the cleanup above;
		// its error is deliberately ignored.
		_ = sv.ServeTLS(ln, cert.serverCert, cert.serverKey)
	}()

	return ln
}
// testCerts holds the locations of the PEM files generated by
// createCertificates, along with the parsed client certificate.
type testCerts struct {
	// Paths to the PEM files written into the test's temporary directory.
	caCert     string
	caKey      string
	serverCert string
	serverKey  string
	clientCert string
	clientKey  string

	// parsedClientCert is the parsed form of the certificate stored at clientCert.
	parsedClientCert *x509.Certificate
}

// createCertificates generates a self-signed CA plus a client and a server
// certificate signed by that CA, and writes each certificate and private key as
// a PEM file into t.TempDir(). It returns the file paths together with the
// parsed client certificate. Any failure aborts the test through t.Fatal.
func createCertificates(t *testing.T) testCerts {
	t.Helper()

	const (
		certFileMode = 0o644
		// Private keys must not be readable by group or other.
		keyFileMode = 0o600
	)

	tempDir := t.TempDir()

	// writePEM PEM-encodes der under blockType and writes it to name inside
	// tempDir, returning the full path of the file.
	writePEM := func(name, blockType string, der []byte, mode os.FileMode) string {
		t.Helper()
		path := filepath.Join(tempDir, name)
		block := &pem.Block{Type: blockType, Bytes: der}
		if err := os.WriteFile(path, pem.EncodeToMemory(block), mode); err != nil {
			t.Fatal(err)
		}
		return path
	}

	// writeKey marshals an EC private key and stores it as a PEM file.
	writeKey := func(name string, key *ecdsa.PrivateKey) string {
		t.Helper()
		der, err := x509.MarshalECPrivateKey(key)
		if err != nil {
			t.Fatal(err)
		}
		return writePEM(name, "EC PRIVATE KEY", der, keyFileMode)
	}

	caCertTemplate := &x509.Certificate{
		Subject: pkix.Name{
			CommonName: "localhost",
		},
		DNSNames:              []string{"localhost"},
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
		SerialNumber:          big.NewInt(mathrand.Int63()),
		NotBefore:             time.Now().Add(-30 * time.Second),
		NotAfter:              time.Now().Add(262980 * time.Hour),
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	// The template doubles as the parent, which makes the CA self-signed.
	caBytes, err := x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caKey.Public(), caKey)
	if err != nil {
		t.Fatal(err)
	}
	caCert, err := x509.ParseCertificate(caBytes)
	if err != nil {
		t.Fatal(err)
	}
	caCertFile := writePEM("ca_cert.pem", "CERTIFICATE", caBytes, certFileMode)
	caKeyFile := writeKey("ca_key.pem", caKey)

	// leaf describes the files and parsed certificate of one CA-signed leaf.
	type leaf struct {
		certFile string
		keyFile  string
		parsed   *x509.Certificate
	}

	// createCertificate issues a leaf certificate signed by the CA, usable for
	// both server and client authentication, and writes <prefix>_cert.pem and
	// <prefix>_key.pem into tempDir.
	createCertificate := func(prefix string) leaf {
		t.Helper()
		key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
		if err != nil {
			t.Fatal(err)
		}
		template := &x509.Certificate{
			Subject: pkix.Name{
				CommonName: "example.com",
			},
			EmailAddresses: []string{"valid@example.com"},
			IPAddresses:    []net.IP{net.ParseIP("127.0.0.1")},
			ExtKeyUsage: []x509.ExtKeyUsage{
				x509.ExtKeyUsageServerAuth,
				x509.ExtKeyUsageClientAuth,
			},
			KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
			SerialNumber: big.NewInt(mathrand.Int63()),
			NotBefore:    time.Now().Add(-30 * time.Second),
			NotAfter:     time.Now().Add(262980 * time.Hour),
		}
		certBytes, err := x509.CreateCertificate(rand.Reader, template, caCert, key.Public(), caKey)
		if err != nil {
			t.Fatal(err)
		}
		parsedCert, err := x509.ParseCertificate(certBytes)
		if err != nil {
			t.Fatal(err)
		}
		return leaf{
			certFile: writePEM(prefix+"_cert.pem", "CERTIFICATE", certBytes, certFileMode),
			keyFile:  writeKey(prefix+"_key.pem", key),
			parsed:   parsedCert,
		}
	}

	client := createCertificate("client")
	server := createCertificate("server")

	return testCerts{
		caCert:           caCertFile,
		caKey:            caKeyFile,
		serverCert:       server.certFile,
		serverKey:        server.keyFile,
		clientCert:       client.certFile,
		clientKey:        client.keyFile,
		parsedClientCert: client.parsed,
	}
}