2024-10-02 13:41:55 -04:00

317 lines
8.6 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package cert
import (
"context"
"fmt"
"os"
"path"
"path/filepath"
"reflect"
"testing"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/command/agentproxyshared/auth"
)
// TestCertAuthMethod_Authenticate verifies that Authenticate returns the
// login path under the configured mount and a login payload containing only
// the configured "name" value.
func TestCertAuthMethod_Authenticate(t *testing.T) {
	config := &auth.AuthConfig{
		Logger:    hclog.NewNullLogger(),
		MountPath: "cert-test",
		Config: map[string]interface{}{
			"name": "foo",
		},
	}

	method, err := NewCertAuthMethod(config)
	if err != nil {
		t.Fatal(err)
	}
	defer method.Shutdown()

	client, err := api.NewClient(nil)
	if err != nil {
		t.Fatal(err)
	}

	loginPath, _, authMap, err := method.Authenticate(context.Background(), client)
	if err != nil {
		t.Fatal(err)
	}

	expectedLoginPath := path.Join(config.MountPath, "/login")
	if loginPath != expectedLoginPath {
		t.Fatalf("mismatch on login path: got: %s, expected: %s", loginPath, expectedLoginPath)
	}

	expectedAuthMap := map[string]interface{}{
		"name": config.Config["name"],
	}
	if !reflect.DeepEqual(authMap, expectedAuthMap) {
		// This compares the login payload, not the path, so say so on failure.
		t.Fatalf("mismatch on auth map:\ngot:\n\t%v\nexpected:\n\t%v", authMap, expectedAuthMap)
	}
}
func TestCertAuthMethod_AuthClient_withoutCerts(t *testing.T) {
config := &auth.AuthConfig{
Logger: hclog.NewNullLogger(),
MountPath: "cert-test",
Config: map[string]interface{}{
"name": "without-certs",
},
}
method, err := NewCertAuthMethod(config)
if err != nil {
t.Fatal(err)
}
defer method.Shutdown()
client, err := api.NewClient(api.DefaultConfig())
if err != nil {
t.Fatal(err)
}
clientToUse, err := method.(auth.AuthMethodWithClient).AuthClient(client)
if err != nil {
t.Fatal(err)
}
if client != clientToUse {
t.Fatal("error: expected AuthClient to return back original client")
}
}
// TestCertAuthMethod_AuthClient_withCerts checks that configuring a client
// certificate/key pair makes AuthClient build a dedicated client, and that a
// repeated call returns that same cached client.
func TestCertAuthMethod_AuthClient_withCerts(t *testing.T) {
	// Open the fixtures up front so a missing file fails the test early; only
	// their names are passed to the auth method.
	certFile, err := os.Open("./test-fixtures/keys/cert.pem")
	if err != nil {
		t.Fatal(err)
	}
	defer certFile.Close()

	keyFile, err := os.Open("./test-fixtures/keys/key.pem")
	if err != nil {
		t.Fatal(err)
	}
	defer keyFile.Close()

	cfg := &auth.AuthConfig{
		Logger:    hclog.NewNullLogger(),
		MountPath: "cert-test",
		Config: map[string]interface{}{
			"name":        "with-certs",
			"client_cert": certFile.Name(),
			"client_key":  keyFile.Name(),
		},
	}

	method, err := NewCertAuthMethod(cfg)
	if err != nil {
		t.Fatal(err)
	}
	defer method.Shutdown()

	base, err := api.NewClient(nil)
	if err != nil {
		t.Fatal(err)
	}

	withClient := method.(auth.AuthMethodWithClient)

	first, err := withClient.AuthClient(base)
	if err != nil {
		t.Fatal(err)
	}
	if first == base {
		t.Fatal("expected client from AuthClient to be different from original client")
	}

	// A second call should be answered from the cache.
	second, err := withClient.AuthClient(base)
	if err != nil {
		t.Fatal(err)
	}
	if second != first {
		t.Fatal("expected client from AuthClient to return back a cached client")
	}
}
// copyFile reads the whole file at from and writes its contents to to. If to
// does not exist it is created with 0600 permissions (before umask); if it
// does exist it is truncated and its permissions are left unchanged. The read
// or write error, if any, is returned unwrapped.
func copyFile(from, to string) (err error) {
	var contents []byte
	if contents, err = os.ReadFile(from); err != nil {
		return err
	}
	return os.WriteFile(to, contents, 0o600)
}
// TestCertAuthMethod_AuthClient_withCertsReload makes the file change and ensures the cert auth method deliver the event.
func TestCertAuthMethod_AuthClient_withCertsReload(t *testing.T) {
// Initial the cert/key pair to temp path
certPath := filepath.Join(os.TempDir(), "app.crt")
keyPath := filepath.Join(os.TempDir(), "app.key")
if err := copyFile("./test-fixtures/keys/cert.pem", certPath); err != nil {
t.Fatal("copy cert file failed", err)
}
defer os.Remove(certPath)
if err := copyFile("./test-fixtures/keys/key.pem", keyPath); err != nil {
t.Fatal("copy key file failed", err)
}
defer os.Remove(keyPath)
config := &auth.AuthConfig{
Logger: hclog.NewNullLogger(),
MountPath: "cert-test",
Config: map[string]interface{}{
"name": "with-certs-reloaded",
"client_cert": certPath,
"client_key": keyPath,
"reload": true,
"reload_period": 1,
},
}
method, err := NewCertAuthMethod(config)
if err != nil {
t.Fatal(err)
}
defer method.Shutdown()
client, err := api.NewClient(nil)
if err != nil {
t.Fatal(err)
}
clientToUse, err := method.(auth.AuthMethodWithClient).AuthClient(client)
if err != nil {
t.Fatal(err)
}
if client == clientToUse {
t.Fatal("expected client from AuthClient to be different from original client")
}
// Call AuthClient again to get back a new client with reloaded certificates
reloadedClient, err := method.(auth.AuthMethodWithClient).AuthClient(client)
if err != nil {
t.Fatal(err)
}
if reloadedClient == clientToUse {
t.Fatal("expected client from AuthClient to return back a new client")
}
method.CredSuccess()
// Only make a change to the cert file, it doesn't match the key file so the client won't pick and load them.
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
if err = copyFile("./test-fixtures/keys/cert1.pem", certPath); err != nil {
t.Fatal("update cert file failed", err)
}
select {
case <-ctx.Done():
case <-method.NewCreds():
cancel()
t.Fatal("malformed cert should not be observed as a change")
}
// Make a change to the key file and now they are good to be picked.
if err = copyFile("./test-fixtures/keys/key1.pem", keyPath); err != nil {
t.Fatal("update key file failed", err)
}
ctx, cancel = context.WithTimeout(context.Background(), 3*time.Second)
select {
case <-ctx.Done():
t.Fatal("failed to watch the cert change: timeout")
case <-method.NewCreds():
cancel()
}
}
// TestCertAuthMethod_hashCert_withEmptyPaths tests hashCert() if it works well with optional options.
// TestCertAuthMethod_hashCert_withEmptyPaths checks that hashCert() treats
// every path as optional. Empty paths still produce a hash without error. A
// client cert given without its key is ignored, so it yields the same hash as
// passing no paths at all.
func TestCertAuthMethod_hashCert_withEmptyPaths(t *testing.T) {
	c := &certMethod{logger: hclog.NewNullLogger()}

	// Baseline: every path omitted.
	baseline, err := c.hashCert("", "", "")
	if err != nil || baseline == "" {
		t.Fatal("hashCert() should skip empty file paths and succeed.")
	}

	// CA cert on its own.
	if sum, err := c.hashCert("", "", "./test-fixtures/root/rootcacert.pem"); err != nil || sum == "" {
		t.Fatal("hashCert() should succeed when only present ca cert.")
	}

	// Client cert/key pair on its own.
	if sum, err := c.hashCert("./test-fixtures/keys/cert.pem", "./test-fixtures/keys/key.pem", ""); err != nil || sum == "" {
		fmt.Println(sum, err)
		t.Fatal("hashCert() should succeed when only present client cert/key.")
	}

	// A cert without its key must be skipped, so the hash matches the baseline.
	sum, err := c.hashCert("./test-fixtures/keys/cert.pem", "", "")
	switch {
	case err != nil || sum == "":
		t.Fatal("hashCert() should succeed when only present client cert.")
	case sum != baseline:
		t.Fatal("hashCert() should skip the client cert/key when only present client cert.")
	}
}
// TestCertAuthMethod_hashCert_withInvalidClientCert adds test cases for invalid input for hashCert().
// TestCertAuthMethod_hashCert_withInvalidClientCert checks that hashCert()
// rejects unusable client credentials by returning an error and an empty hash.
func TestCertAuthMethod_hashCert_withInvalidClientCert(t *testing.T) {
	c := &certMethod{logger: hclog.NewNullLogger()}

	cases := []struct {
		cert, key string
		failMsg   string
	}{
		// cert1.pem and key.pem do not form a matching pair.
		{
			cert:    "./test-fixtures/keys/cert1.pem",
			key:     "./test-fixtures/keys/key.pem",
			failMsg: "hashCert() should fail with invalid client cert.",
		},
		// cert2.pem is a path that does not exist.
		{
			cert:    "./test-fixtures/keys/cert2.pem",
			key:     "./test-fixtures/keys/key.pem",
			failMsg: "hashCert() should fail with non-existed client cert path.",
		},
	}

	for _, tc := range cases {
		if sum, err := c.hashCert(tc.cert, tc.key, ""); err == nil || sum != "" {
			t.Fatal(tc.failMsg)
		}
	}
}
// TestCertAuthMethod_hashCert_withChange tests hashCert() if it detects changes from both client cert/key and ca cert.
// TestCertAuthMethod_hashCert_withChange checks that hashCert() produces a
// different hash when either the CA cert or the client cert/key pair changes.
func TestCertAuthMethod_hashCert_withChange(t *testing.T) {
	c := &certMethod{logger: hclog.NewNullLogger()}

	const (
		certA  = "./test-fixtures/keys/cert.pem"
		keyA   = "./test-fixtures/keys/key.pem"
		certB  = "./test-fixtures/keys/cert1.pem"
		keyB   = "./test-fixtures/keys/key1.pem"
		rootCA = "./test-fixtures/root/rootcacert.pem"
	)

	// Baseline: pair A with the root CA.
	base, err := c.hashCert(certA, keyA, rootCA)
	if err != nil || base == "" {
		t.Fatal("hashCert() shouldn't fail with a valid pair of cert/key.")
	}

	// Same pair, different CA file (cert.pem reused as the CA).
	caChanged, err := c.hashCert(certA, keyA, certA)
	if err != nil || caChanged == "" {
		t.Fatal("hashCert() shouldn't fail with valid pair of cert/key.")
	}
	if caChanged == base {
		t.Fatal("The hash should be different with a different ca cert.")
	}

	// Different pair, original CA.
	pairChanged, err := c.hashCert(certB, keyB, rootCA)
	if err != nil || pairChanged == "" {
		t.Fatal("hashCert() shouldn't fail with a valid cert/key pair")
	}
	if pairChanged == base || pairChanged == caChanged {
		t.Fatal("The hash should be different with a different pair of cert/key.")
	}
}