Merge remote-tracking branch 'remotes/from/ce/main'

This commit is contained in:
hc-github-team-secure-vault-core 2026-05-08 15:29:17 +00:00
commit 2546515307
2 changed files with 185 additions and 43 deletions

View File

@ -19,6 +19,16 @@ import (
"github.com/stretchr/testify/require"
)
type TestServerConfig struct {
DockerNetworkName string
Logger hclog.Logger
Domains []ServerDomain
TSigKeys []TSigKey
}
type ServerDomain struct {
Domain string
TSIGKeyName string
}
type TestServer struct {
t *testing.T
ctx context.Context
@ -28,11 +38,11 @@ type TestServer struct {
network string
startup *docker.Service
lock sync.Mutex
serial int
forwarders []string
domains []string
records map[string]map[string][]string // domain -> record -> value(s).
lock sync.Mutex
serial int
tsigKeys []TSigKey
domains []ServerDomain
records map[string]map[string][]string // domain -> record -> value(s).
cleanup func()
}
@ -46,21 +56,62 @@ func SetupResolverOnNetwork(t *testing.T, domain string, network string) *TestSe
}
func SetupResolverOnNetworkWithLogger(t *testing.T, domain string, network string, logger hclog.Logger) *TestServer {
config := TestServerConfig{
Logger: logger,
DockerNetworkName: network,
Domains: []ServerDomain{
{Domain: domain},
},
}
return SetupResolverWithConfig(t, config)
}
func SetupResolverWithConfig(t *testing.T, config TestServerConfig) *TestServer {
var ts TestServer
ts.t = t
ts.ctx = context.Background()
ts.domains = []string{domain}
ts.tsigKeys = config.TSigKeys
ts.domains = config.Domains
ts.records = map[string]map[string][]string{}
ts.network = network
ts.log = logger
ts.network = config.DockerNetworkName
ts.log = config.Logger
ts.setupRunner(domain, network)
ts.startContainer(network)
validateReferencedTSIGKeysExist(t, config)
if len(ts.domains) == 0 {
t.Fatal("no domains configured")
}
ts.setupRunner(ts.domains[0].Domain, ts.network)
ts.startContainer(ts.network)
ts.PushConfig()
return &ts
}
func validateReferencedTSIGKeysExist(t *testing.T, tsc TestServerConfig) {
var missingTSIGKeys []string
for _, domain := range tsc.Domains {
if len(domain.TSIGKeyName) == 0 {
continue
}
found := false
for _, tsigKey := range tsc.TSigKeys {
if domain.TSIGKeyName == tsigKey.KeyName {
found = true
break
}
}
if !found {
missingTSIGKeys = append(missingTSIGKeys, fmt.Sprintf("TSIG key name: %q referenced by domain %q", domain.TSIGKeyName, domain.Domain))
}
}
if len(missingTSIGKeys) > 0 {
t.Fatalf("missing TSIG keys: %s", missingTSIGKeys)
}
}
func (ts *TestServer) setupRunner(domain string, network string) {
var err error
ts.runner, err = docker.NewServiceRunner(docker.RunOptions{
@ -102,7 +153,7 @@ func (ts *TestServer) startContainer(network string) {
}
result, _, err := ts.runner.StartNewService(ts.ctx, true, true, connUpFunc)
require.NoError(ts.t, err, "failed to start dns resolver for "+ts.domains[0])
require.NoError(ts.t, err, "failed to start dns resolver for "+ts.domains[0].Domain)
ts.startup = result
if ts.startup.StartResult.RealIP == "" {
@ -124,44 +175,46 @@ func (ts *TestServer) startContainer(network string) {
}
func (ts *TestServer) buildNamedConf() string {
forwarders := "\n"
if len(ts.forwarders) > 0 {
forwarders = "\tforwarders {\n"
for _, forwarder := range ts.forwarders {
forwarders += "\t\t" + forwarder + ";\n"
}
forwarders += "\t};\n"
}
zones := "\n"
for _, domain := range ts.domains {
zones += fmt.Sprintf("zone \"%s\" {\n", domain)
for _, ds := range ts.domains {
updateType := "none"
if len(ds.TSIGKeyName) > 0 {
updateType = fmt.Sprintf(`key "%s"`, ds.TSIGKeyName)
}
zones += fmt.Sprintf("zone \"%s\" {\n", ds.Domain)
zones += "\ttype primary;\n"
zones += fmt.Sprintf("\tfile \"%s.zone\";\n", domain)
zones += "\tallow-update {\n\t\tnone;\n\t};\n"
zones += fmt.Sprintf("\tfile \"%s.zone\";\n", ds.Domain)
zones += fmt.Sprintf("\tallow-update {\n\t\t%s;\n\t};\n", updateType)
zones += "\tnotify no;\n"
zones += "};\n\n"
}
// Reverse lookups are not handles as they're not presently necessary.
tsigKeys := "\n"
for _, key := range ts.tsigKeys {
tsigKeys += fmt.Sprintf("key \"%s\"{\n", key.KeyName)
tsigKeys += fmt.Sprintf("\talgorithm \"%s\";\n", key.Algorithm.String())
tsigKeys += fmt.Sprintf("\tsecret \"%s\";\n", key.Secret)
tsigKeys += "};\n\n"
}
cfg := `options {
directory "/var/cache/bind";
dnssec-validation no;
` + forwarders + `
querylog yes;
};
` + zones
` + tsigKeys + zones
return cfg
}
func (ts *TestServer) buildZoneFile(target string) string {
func (ts *TestServer) buildZoneFile(sd ServerDomain) string {
// One second TTL by default to allow quick refreshes.
zone := "$TTL 1;\n"
target := sd.Domain
ts.serial += 1
zone += fmt.Sprintf("@\tIN\tSOA\tns.%v.\troot.%v.\t(\n", target, target)
zone += fmt.Sprintf("\t\t\t%d;\n\t\t\t1;\n\t\t\t1;\n\t\t\t2;\n\t\t\t1;\n\t\t\t)\n\n", ts.serial)
@ -200,10 +253,10 @@ func (ts *TestServer) pushZoneFiles() {
contents := docker.NewBuildContext()
for _, domain := range ts.domains {
path := "/var/cache/bind/" + domain + ".zone"
path := "/var/cache/bind/" + domain.Domain + ".zone"
zoneFile := ts.buildZoneFile(domain)
contents[path] = docker.PathContentsFromString(zoneFile)
contents[path].SetOwners(0, 142) // root, bind
contents[path].SetOwners(142, 142) // bind, bind allow updates through RFC2136
ts.log.Info(fmt.Sprintf("Generated bind9 zone file for %v (%s):\n%v\n", domain, path, zoneFile))
}
@ -238,7 +291,7 @@ func (ts *TestServer) PushConfig() {
// to make sure it has been updated more recently than when the
// last update was written. Then issue a new SIGHUP.
for _, domain := range ts.domains {
path := "/var/cache/bind/" + domain + ".zone"
path := "/var/cache/bind/" + domain.Domain + ".zone"
touchCmd := []string{"touch", path}
_, _, _, err := ts.runner.RunCmdWithOutput(ts.ctx, ts.startup.Container.ID, touchCmd)
@ -264,14 +317,14 @@ func (ts *TestServer) PushConfig() {
// last domain has the given serial number, which also appears in the
// NS record so we can fetch it via Go.
lastDomain := ts.domains[len(ts.domains)-1]
records, err := resolver.LookupNS(ts.ctx, lastDomain)
records, err := resolver.LookupNS(ts.ctx, lastDomain.Domain)
if err != nil {
assert.NoError(ct, err, "failed to lookup NS record for %v", lastDomain)
assert.NoError(ct, err, "failed to lookup NS record for %v", lastDomain.Domain)
return
}
assert.Len(ct, records, 1, "expected only 1 NS record for %v", lastDomain)
assert.Equal(ct, fmt.Sprintf("ns%d.%v.", ts.serial, lastDomain), records[0].Host, "reload hasn't completed")
assert.Len(ct, records, 1, "expected only 1 NS record for %v", lastDomain.Domain)
assert.Equal(ct, fmt.Sprintf("ns%d.%v.", ts.serial, lastDomain.Domain), records[0].Host, "reload hasn't completed")
}, 15*time.Second, 100*time.Millisecond)
}
@ -288,12 +341,12 @@ func (ts *TestServer) AddDomain(domain string) {
defer ts.lock.Unlock()
for _, existing := range ts.domains {
if existing == domain {
if existing.Domain == domain {
return
}
}
ts.domains = append(ts.domains, domain)
ts.domains = append(ts.domains, ServerDomain{Domain: domain})
}
func (ts *TestServer) AddRecord(domain string, record string, value string) {
@ -302,7 +355,7 @@ func (ts *TestServer) AddRecord(domain string, record string, value string) {
foundDomain := false
for _, existing := range ts.domains {
if strings.HasSuffix(domain, existing) {
if strings.HasSuffix(domain, existing.Domain) {
foundDomain = true
break
}
@ -334,7 +387,7 @@ func (ts *TestServer) RemoveRecord(domain string, record string, value string) {
foundDomain := false
for _, existing := range ts.domains {
if strings.HasSuffix(domain, existing) {
if strings.HasSuffix(domain, existing.Domain) {
foundDomain = true
break
}
@ -368,7 +421,7 @@ func (ts *TestServer) RemoveRecordsOfTypeForDomain(domain string, record string)
foundDomain := false
for _, existing := range ts.domains {
if strings.HasSuffix(domain, existing) {
if strings.HasSuffix(domain, existing.Domain) {
foundDomain = true
break
}
@ -392,7 +445,7 @@ func (ts *TestServer) RemoveRecordsForDomain(domain string) {
foundDomain := false
for _, existing := range ts.domains {
if strings.HasSuffix(domain, existing) {
if strings.HasSuffix(domain, existing.Domain) {
foundDomain = true
break
}

View File

@ -0,0 +1,89 @@
// Copyright IBM Corp. 2026
// SPDX-License-Identifier: MPL-2.0
package dnstest
import (
"crypto/rand"
"encoding/base64"
"fmt"
)
// TSIGAlgorithm represents the supported TSIG algorithm types
type TSIGAlgorithm int
const (
HmacSHA1 TSIGAlgorithm = iota
HmacSHA224
HmacSHA256
HmacSHA384
HmacSHA512
)
// String returns the string representation of the algorithm
func (a TSIGAlgorithm) String() string {
switch a {
case HmacSHA1:
return "hmac-sha1"
case HmacSHA224:
return "hmac-sha224"
case HmacSHA256:
return "hmac-sha256"
case HmacSHA384:
return "hmac-sha384"
case HmacSHA512:
return "hmac-sha512"
default:
return "unknown"
}
}
// Bits returns the key size in bits for the algorithm
func (a TSIGAlgorithm) Bits() int {
switch a {
case HmacSHA1:
return 160
case HmacSHA224:
return 224
case HmacSHA256:
return 256
case HmacSHA384:
return 384
case HmacSHA512:
return 512
default:
return 0
}
}
type TSigKey struct {
KeyName string
Algorithm TSIGAlgorithm
Secret string
}
// GenerateTSIGKey generates a base64 std encoded TSIG key for the specified algorithm
func GenerateTSIGKey(keyName string, algorithm TSIGAlgorithm) (TSigKey, error) {
if keyName == "" {
return TSigKey{}, fmt.Errorf("empty key name")
}
bits := algorithm.Bits()
if bits == 0 {
return TSigKey{}, fmt.Errorf("unsupported algorithm: %v", algorithm)
}
// Calculate byte length from bits
byteLength := bits / 8
// Generate random bytes
key := make([]byte, byteLength)
_, err := rand.Read(key)
if err != nil {
return TSigKey{}, fmt.Errorf("failed to generate random key: %w", err)
}
// Encode to base64
encodedKey := base64.StdEncoding.EncodeToString(key)
return TSigKey{KeyName: keyName, Algorithm: algorithm, Secret: encodedKey}, nil
}