diff --git a/sdk/helper/testhelpers/dnstest/server.go b/sdk/helper/testhelpers/dnstest/server.go index 07ae8dd32b..38e4509a64 100644 --- a/sdk/helper/testhelpers/dnstest/server.go +++ b/sdk/helper/testhelpers/dnstest/server.go @@ -19,6 +19,16 @@ import ( "github.com/stretchr/testify/require" ) +type TestServerConfig struct { + DockerNetworkName string + Logger hclog.Logger + Domains []ServerDomain + TSigKeys []TSigKey +} +type ServerDomain struct { + Domain string + TSIGKeyName string +} type TestServer struct { t *testing.T ctx context.Context @@ -28,11 +38,11 @@ type TestServer struct { network string startup *docker.Service - lock sync.Mutex - serial int - forwarders []string - domains []string - records map[string]map[string][]string // domain -> record -> value(s). + lock sync.Mutex + serial int + tsigKeys []TSigKey + domains []ServerDomain + records map[string]map[string][]string // domain -> record -> value(s). cleanup func() } @@ -46,21 +56,62 @@ func SetupResolverOnNetwork(t *testing.T, domain string, network string) *TestSe } func SetupResolverOnNetworkWithLogger(t *testing.T, domain string, network string, logger hclog.Logger) *TestServer { + config := TestServerConfig{ + Logger: logger, + DockerNetworkName: network, + Domains: []ServerDomain{ + {Domain: domain}, + }, + } + return SetupResolverWithConfig(t, config) +} + +func SetupResolverWithConfig(t *testing.T, config TestServerConfig) *TestServer { var ts TestServer ts.t = t ts.ctx = context.Background() - ts.domains = []string{domain} + ts.tsigKeys = config.TSigKeys + ts.domains = config.Domains ts.records = map[string]map[string][]string{} - ts.network = network - ts.log = logger + ts.network = config.DockerNetworkName + ts.log = config.Logger - ts.setupRunner(domain, network) - ts.startContainer(network) + validateReferencedTSIGKeysExist(t, config) + + if len(ts.domains) == 0 { + t.Fatal("no domains configured") + } + + ts.setupRunner(ts.domains[0].Domain, ts.network) + ts.startContainer(ts.network) ts.PushConfig() return &ts } +func validateReferencedTSIGKeysExist(t *testing.T, tsc TestServerConfig) { + var missingTSIGKeys []string + for _, domain := range tsc.Domains { + if len(domain.TSIGKeyName) == 0 { + continue + } + + found := false + for _, tsigKey := range tsc.TSigKeys { + if domain.TSIGKeyName == tsigKey.KeyName { + found = true + break + } + } + if !found { + missingTSIGKeys = append(missingTSIGKeys, fmt.Sprintf("TSIG key name: %q referenced by domain %q", domain.TSIGKeyName, domain.Domain)) + } + } + if len(missingTSIGKeys) > 0 { + t.Fatalf("missing TSIG keys: %s", missingTSIGKeys) + } +} + func (ts *TestServer) setupRunner(domain string, network string) { var err error ts.runner, err = docker.NewServiceRunner(docker.RunOptions{ @@ -102,7 +153,7 @@ func (ts *TestServer) startContainer(network string) { } result, _, err := ts.runner.StartNewService(ts.ctx, true, true, connUpFunc) - require.NoError(ts.t, err, "failed to start dns resolver for "+ts.domains[0]) + require.NoError(ts.t, err, "failed to start dns resolver for "+ts.domains[0].Domain) ts.startup = result if ts.startup.StartResult.RealIP == "" { @@ -124,44 +175,46 @@ func (ts *TestServer) startContainer(network string) { } func (ts *TestServer) buildNamedConf() string { - forwarders := "\n" - if len(ts.forwarders) > 0 { - forwarders = "\tforwarders {\n" - for _, forwarder := range ts.forwarders { - forwarders += "\t\t" + forwarder + ";\n" - } - forwarders += "\t};\n" - } - zones := "\n" - for _, domain := range ts.domains { - zones += fmt.Sprintf("zone \"%s\" {\n", domain) + for _, ds := range ts.domains { + updateType := "none" + if len(ds.TSIGKeyName) > 0 { + updateType = fmt.Sprintf(`key "%s"`, ds.TSIGKeyName) + } + + zones += fmt.Sprintf("zone \"%s\" {\n", ds.Domain) zones += "\ttype primary;\n" - zones += fmt.Sprintf("\tfile \"%s.zone\";\n", domain) - zones += "\tallow-update {\n\t\tnone;\n\t};\n" + zones += fmt.Sprintf("\tfile \"%s.zone\";\n", ds.Domain) + zones += fmt.Sprintf("\tallow-update {\n\t\t%s;\n\t};\n", updateType) zones += "\tnotify no;\n" zones += "};\n\n" } - // Reverse lookups are not handles as they're not presently necessary. + tsigKeys := "\n" + for _, key := range ts.tsigKeys { + tsigKeys += fmt.Sprintf("key \"%s\"{\n", key.KeyName) + tsigKeys += fmt.Sprintf("\talgorithm \"%s\";\n", key.Algorithm.String()) + tsigKeys += fmt.Sprintf("\tsecret \"%s\";\n", key.Secret) + tsigKeys += "};\n\n" + } cfg := `options { directory "/var/cache/bind"; - dnssec-validation no; - - ` + forwarders + ` + querylog yes; }; -` + zones +` + tsigKeys + zones return cfg } -func (ts *TestServer) buildZoneFile(target string) string { +func (ts *TestServer) buildZoneFile(sd ServerDomain) string { // One second TTL by default to allow quick refreshes. zone := "$TTL 1;\n" + target := sd.Domain + ts.serial += 1 zone += fmt.Sprintf("@\tIN\tSOA\tns.%v.\troot.%v.\t(\n", target, target) zone += fmt.Sprintf("\t\t\t%d;\n\t\t\t1;\n\t\t\t1;\n\t\t\t2;\n\t\t\t1;\n\t\t\t)\n\n", ts.serial) @@ -200,10 +253,10 @@ func (ts *TestServer) pushZoneFiles() { contents := docker.NewBuildContext() for _, domain := range ts.domains { - path := "/var/cache/bind/" + domain + ".zone" + path := "/var/cache/bind/" + domain.Domain + ".zone" zoneFile := ts.buildZoneFile(domain) contents[path] = docker.PathContentsFromString(zoneFile) - contents[path].SetOwners(0, 142) // root, bind + contents[path].SetOwners(142, 142) // bind, bind allow updates through RFC2136 ts.log.Info(fmt.Sprintf("Generated bind9 zone file for %v (%s):\n%v\n", domain, path, zoneFile)) } @@ -238,7 +291,7 @@ func (ts *TestServer) PushConfig() { // to make sure it has been updated more recently than when the // last update was written. Then issue a new SIGHUP. for _, domain := range ts.domains { - path := "/var/cache/bind/" + domain + ".zone" + path := "/var/cache/bind/" + domain.Domain + ".zone" touchCmd := []string{"touch", path} _, _, _, err := ts.runner.RunCmdWithOutput(ts.ctx, ts.startup.Container.ID, touchCmd) @@ -264,14 +317,14 @@ func (ts *TestServer) PushConfig() { // last domain has the given serial number, which also appears in the // NS record so we can fetch it via Go. lastDomain := ts.domains[len(ts.domains)-1] - records, err := resolver.LookupNS(ts.ctx, lastDomain) + records, err := resolver.LookupNS(ts.ctx, lastDomain.Domain) if err != nil { - assert.NoError(ct, err, "failed to lookup NS record for %v", lastDomain) + assert.NoError(ct, err, "failed to lookup NS record for %v", lastDomain.Domain) return } - assert.Len(ct, records, 1, "expected only 1 NS record for %v", lastDomain) - assert.Equal(ct, fmt.Sprintf("ns%d.%v.", ts.serial, lastDomain), records[0].Host, "reload hasn't completed") + assert.Len(ct, records, 1, "expected only 1 NS record for %v", lastDomain.Domain) + assert.Equal(ct, fmt.Sprintf("ns%d.%v.", ts.serial, lastDomain.Domain), records[0].Host, "reload hasn't completed") }, 15*time.Second, 100*time.Millisecond) } @@ -288,12 +341,12 @@ func (ts *TestServer) AddDomain(domain string) { defer ts.lock.Unlock() for _, existing := range ts.domains { - if existing == domain { + if existing.Domain == domain { return } } - ts.domains = append(ts.domains, domain) + ts.domains = append(ts.domains, ServerDomain{Domain: domain}) } func (ts *TestServer) AddRecord(domain string, record string, value string) { @@ -302,7 +355,7 @@ func (ts *TestServer) AddRecord(domain string, record string, value string) { foundDomain := false for _, existing := range ts.domains { - if strings.HasSuffix(domain, existing) { + if strings.HasSuffix(domain, existing.Domain) { foundDomain = true break } @@ -334,7 +387,7 @@ func (ts *TestServer) RemoveRecord(domain string, record string, value string) { foundDomain := false for _, existing := range ts.domains { - if strings.HasSuffix(domain, existing) { + if strings.HasSuffix(domain, existing.Domain) { foundDomain = true break } @@ -368,7 +421,7 @@ func (ts *TestServer) RemoveRecordsOfTypeForDomain(domain string, record string) foundDomain := false for _, existing := range ts.domains { - if strings.HasSuffix(domain, existing) { + if strings.HasSuffix(domain, existing.Domain) { foundDomain = true break } @@ -392,7 +445,7 @@ func (ts *TestServer) RemoveRecordsForDomain(domain string) { foundDomain := false for _, existing := range ts.domains { - if strings.HasSuffix(domain, existing) { + if strings.HasSuffix(domain, existing.Domain) { foundDomain = true break } diff --git a/sdk/helper/testhelpers/dnstest/tsigkeys.go b/sdk/helper/testhelpers/dnstest/tsigkeys.go new file mode 100644 index 0000000000..fb95cb8b93 --- /dev/null +++ b/sdk/helper/testhelpers/dnstest/tsigkeys.go @@ -0,0 +1,89 @@ +// Copyright IBM Corp. 2026 +// SPDX-License-Identifier: MPL-2.0 + +package dnstest + +import ( + "crypto/rand" + "encoding/base64" + "fmt" +) + +// TSIGAlgorithm represents the supported TSIG algorithm types +type TSIGAlgorithm int + +const ( + HmacSHA1 TSIGAlgorithm = iota + HmacSHA224 + HmacSHA256 + HmacSHA384 + HmacSHA512 +) + +// String returns the string representation of the algorithm +func (a TSIGAlgorithm) String() string { + switch a { + case HmacSHA1: + return "hmac-sha1" + case HmacSHA224: + return "hmac-sha224" + case HmacSHA256: + return "hmac-sha256" + case HmacSHA384: + return "hmac-sha384" + case HmacSHA512: + return "hmac-sha512" + default: + return "unknown" + } +} + +// Bits returns the key size in bits for the algorithm +func (a TSIGAlgorithm) Bits() int { + switch a { + case HmacSHA1: + return 160 + case HmacSHA224: + return 224 + case HmacSHA256: + return 256 + case HmacSHA384: + return 384 + case HmacSHA512: + return 512 + default: + return 0 + } +} + +type TSigKey struct { + KeyName string + Algorithm TSIGAlgorithm + Secret string +} + +// GenerateTSIGKey generates a base64 std encoded TSIG key for the specified algorithm +func GenerateTSIGKey(keyName string, algorithm TSIGAlgorithm) (TSigKey, error) { + if keyName == "" { + return TSigKey{}, fmt.Errorf("empty key name") + } + + bits := algorithm.Bits() + if bits == 0 { + return TSigKey{}, fmt.Errorf("unsupported algorithm: %v", algorithm) + } + + // Calculate byte length from bits + byteLength := bits / 8 + + // Generate random bytes + key := make([]byte, byteLength) + _, err := rand.Read(key) + if err != nil { + return TSigKey{}, fmt.Errorf("failed to generate random key: %w", err) + } + + // Encode to base64 + encodedKey := base64.StdEncoding.EncodeToString(key) + return TSigKey{KeyName: keyName, Algorithm: algorithm, Secret: encodedKey}, nil +}