vault/builtin/logical/pki/dnstest/server.go
Steven Clark 5ce57dbd00
Fix incorrect role ttl parameters in ACME tests (#21585)
- The ACME tests were using ttl_duration and max_ttl_duration instead
   of ttl and max_ttl as input parameters to roles.
 - Add missing copyright headers
2023-07-05 14:17:15 -04:00

419 lines
11 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dnstest
import (
"context"
"fmt"
"net"
"strings"
"sync"
"testing"
"time"
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
"github.com/hashicorp/vault/sdk/helper/docker"
"github.com/stretchr/testify/require"
)
type TestServer struct {
t *testing.T
ctx context.Context
runner *docker.Runner
network string
startup *docker.Service
lock sync.Mutex
serial int
forwarders []string
domains []string
records map[string]map[string][]string // domain -> record -> value(s).
cleanup func()
}
func SetupResolver(t *testing.T, domain string) *TestServer {
return SetupResolverOnNetwork(t, domain, "")
}
func SetupResolverOnNetwork(t *testing.T, domain string, network string) *TestServer {
var ts TestServer
ts.t = t
ts.ctx = context.Background()
ts.domains = []string{domain}
ts.records = map[string]map[string][]string{}
ts.network = network
ts.setupRunner(domain, network)
ts.startContainer(network)
ts.PushConfig()
return &ts
}
func (ts *TestServer) setupRunner(domain string, network string) {
var err error
ts.runner, err = docker.NewServiceRunner(docker.RunOptions{
ImageRepo: "ubuntu/bind9",
ImageTag: "latest",
ContainerName: "bind9-dns-" + strings.ReplaceAll(domain, ".", "-"),
NetworkName: network,
Ports: []string{"53/udp"},
LogConsumer: func(s string) {
ts.t.Logf(s)
},
})
require.NoError(ts.t, err)
}
func (ts *TestServer) startContainer(network string) {
connUpFunc := func(ctx context.Context, host string, port int) (docker.ServiceConfig, error) {
// Perform a simple connection to this resolver, even though the
// default configuration doesn't do anything useful.
peer, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", host, port))
if err != nil {
return nil, fmt.Errorf("failed to resolve peer: %v / %v: %w", host, port, err)
}
conn, err := net.DialUDP("udp", nil, peer)
if err != nil {
return nil, fmt.Errorf("failed to dial peer: %v / %v / %v: %w", host, port, peer, err)
}
defer conn.Close()
_, err = conn.Write([]byte("garbage-in"))
if err != nil {
return nil, fmt.Errorf("failed to write to peer: %v / %v / %v: %w", host, port, peer, err)
}
// Connection worked.
return docker.NewServiceHostPort(host, port), nil
}
result, _, err := ts.runner.StartNewService(ts.ctx, true, true, connUpFunc)
require.NoError(ts.t, err, "failed to start dns resolver for "+ts.domains[0])
ts.startup = result
if ts.startup.StartResult.RealIP == "" {
mapping, err := ts.runner.GetNetworkAndAddresses(ts.startup.Container.ID)
require.NoError(ts.t, err, "failed to fetch network addresses to correct missing real IP address")
if len(network) == 0 {
require.Equal(ts.t, 1, len(mapping), "expected exactly one network address")
for network = range mapping {
// Because mapping is a map of network name->ip, we need
// to use the above range's assignment to get the name,
// as there is no other way of getting the keys of a map.
}
}
require.Contains(ts.t, mapping, network, "expected network to be part of the mapping")
ts.startup.StartResult.RealIP = mapping[network]
}
ts.t.Logf("[dnsserv] Addresses of DNS resolver: local=%v / container=%v", ts.GetLocalAddr(), ts.GetRemoteAddr())
}
func (ts *TestServer) buildNamedConf() string {
forwarders := "\n"
if len(ts.forwarders) > 0 {
forwarders = "\tforwarders {\n"
for _, forwarder := range ts.forwarders {
forwarders += "\t\t" + forwarder + ";\n"
}
forwarders += "\t};\n"
}
zones := "\n"
for _, domain := range ts.domains {
zones += fmt.Sprintf("zone \"%s\" {\n", domain)
zones += "\ttype primary;\n"
zones += fmt.Sprintf("\tfile \"%s.zone\";\n", domain)
zones += "\tallow-update {\n\t\tnone;\n\t};\n"
zones += "\tnotify no;\n"
zones += "};\n\n"
}
// Reverse lookups are not handles as they're not presently necessary.
cfg := `options {
directory "/var/cache/bind";
dnssec-validation no;
` + forwarders + `
};
` + zones
return cfg
}
func (ts *TestServer) buildZoneFile(target string) string {
// One second TTL by default to allow quick refreshes.
zone := "$TTL 1;\n"
ts.serial += 1
zone += fmt.Sprintf("@\tIN\tSOA\tns.%v.\troot.%v.\t(\n", target, target)
zone += fmt.Sprintf("\t\t\t%d;\n\t\t\t1;\n\t\t\t1;\n\t\t\t2;\n\t\t\t1;\n\t\t\t)\n\n", ts.serial)
zone += fmt.Sprintf("@\tIN\tNS\tns%d.%v.\n", ts.serial, target)
zone += fmt.Sprintf("ns%d.%v.\tIN\tA\t%v\n", ts.serial, target, "127.0.0.1")
for domain, records := range ts.records {
if !strings.HasSuffix(domain, target) {
continue
}
for recordType, values := range records {
for _, value := range values {
zone += fmt.Sprintf("%s.\tIN\t%s\t%s\n", domain, recordType, value)
}
}
}
return zone
}
func (ts *TestServer) pushNamedConf() {
contents := docker.NewBuildContext()
cfgPath := "/etc/bind/named.conf.options"
namedCfg := ts.buildNamedConf()
contents[cfgPath] = docker.PathContentsFromString(namedCfg)
contents[cfgPath].SetOwners(0, 142) // root, bind
ts.t.Logf("Generated bind9 config (%s):\n%v\n", cfgPath, namedCfg)
err := ts.runner.CopyTo(ts.startup.Container.ID, "/", contents)
require.NoError(ts.t, err, "failed pushing updated named.conf.options to container")
}
func (ts *TestServer) pushZoneFiles() {
contents := docker.NewBuildContext()
for _, domain := range ts.domains {
path := "/var/cache/bind/" + domain + ".zone"
zoneFile := ts.buildZoneFile(domain)
contents[path] = docker.PathContentsFromString(zoneFile)
contents[path].SetOwners(0, 142) // root, bind
ts.t.Logf("Generated bind9 zone file for %v (%s):\n%v\n", domain, path, zoneFile)
}
err := ts.runner.CopyTo(ts.startup.Container.ID, "/", contents)
require.NoError(ts.t, err, "failed pushing updated named.conf.options to container")
}
func (ts *TestServer) PushConfig() {
ts.lock.Lock()
defer ts.lock.Unlock()
// There's two cases here:
//
// 1. We've added a new top-level domain name. Here, we want to make
// sure the new zone file is pushed before we push the reference
// to it.
// 2. We've just added a new. Here, the order doesn't matter, but
// mostly likely the second push will be a no-op.
ts.pushZoneFiles()
ts.pushNamedConf()
// Wait until our config has taken.
corehelpers.RetryUntil(ts.t, 15*time.Second, func() error {
// bind reloads based on file mtime, touch files before starting
// to make sure it has been updated more recently than when the
// last update was written. Then issue a new SIGHUP.
for _, domain := range ts.domains {
path := "/var/cache/bind/" + domain + ".zone"
touchCmd := []string{"touch", path}
_, _, _, err := ts.runner.RunCmdWithOutput(ts.ctx, ts.startup.Container.ID, touchCmd)
if err != nil {
return fmt.Errorf("failed to update zone mtime: %w", err)
}
}
ts.runner.DockerAPI.ContainerKill(ts.ctx, ts.startup.Container.ID, "SIGHUP")
// Connect to our bind resolver.
resolver := &net.Resolver{
PreferGo: true,
StrictErrors: false,
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
d := net.Dialer{
Timeout: 10 * time.Second,
}
return d.DialContext(ctx, network, ts.GetLocalAddr())
},
}
// last domain has the given serial number, which also appears in the
// NS record so we can fetch it via Go.
lastDomain := ts.domains[len(ts.domains)-1]
records, err := resolver.LookupNS(ts.ctx, lastDomain)
if err != nil {
return fmt.Errorf("failed to lookup NS record for %v: %w", lastDomain, err)
}
if len(records) != 1 {
return fmt.Errorf("expected only 1 NS record for %v, got %v/%v", lastDomain, len(records), records)
}
expectedNS := fmt.Sprintf("ns%d.%v.", ts.serial, lastDomain)
if records[0].Host != expectedNS {
return fmt.Errorf("expected to find NS %v, got %v indicating reload hadn't completed", expectedNS, records[0])
}
return nil
})
}
func (ts *TestServer) GetLocalAddr() string {
return ts.startup.Config.Address()
}
func (ts *TestServer) GetRemoteAddr() string {
return fmt.Sprintf("%s:%d", ts.startup.StartResult.RealIP, 53)
}
func (ts *TestServer) AddDomain(domain string) {
ts.lock.Lock()
defer ts.lock.Unlock()
for _, existing := range ts.domains {
if existing == domain {
return
}
}
ts.domains = append(ts.domains, domain)
}
func (ts *TestServer) AddRecord(domain string, record string, value string) {
ts.lock.Lock()
defer ts.lock.Unlock()
foundDomain := false
for _, existing := range ts.domains {
if strings.HasSuffix(domain, existing) {
foundDomain = true
break
}
}
if !foundDomain {
ts.t.Fatalf("cannot add record %v/%v :: [%v] -- no domain zone matching (%v)", record, domain, value, ts.domains)
}
value = strings.TrimSpace(value)
if _, present := ts.records[domain]; !present {
ts.records[domain] = map[string][]string{}
}
if values, present := ts.records[domain][record]; present {
for _, candidate := range values {
if candidate == value {
// Already present; skip adding.
return
}
}
}
ts.records[domain][record] = append(ts.records[domain][record], value)
}
func (ts *TestServer) RemoveRecord(domain string, record string, value string) {
ts.lock.Lock()
defer ts.lock.Unlock()
foundDomain := false
for _, existing := range ts.domains {
if strings.HasSuffix(domain, existing) {
foundDomain = true
break
}
}
if !foundDomain {
// Not found.
return
}
value = strings.TrimSpace(value)
if _, present := ts.records[domain]; !present {
// Not found.
return
}
var remaining []string
if values, present := ts.records[domain][record]; present {
for _, candidate := range values {
if candidate != value {
remaining = append(remaining, candidate)
}
}
}
ts.records[domain][record] = remaining
}
func (ts *TestServer) RemoveRecordsOfTypeForDomain(domain string, record string) {
ts.lock.Lock()
defer ts.lock.Unlock()
foundDomain := false
for _, existing := range ts.domains {
if strings.HasSuffix(domain, existing) {
foundDomain = true
break
}
}
if !foundDomain {
// Not found.
return
}
if _, present := ts.records[domain]; !present {
// Not found.
return
}
delete(ts.records[domain], record)
}
func (ts *TestServer) RemoveRecordsForDomain(domain string) {
ts.lock.Lock()
defer ts.lock.Unlock()
foundDomain := false
for _, existing := range ts.domains {
if strings.HasSuffix(domain, existing) {
foundDomain = true
break
}
}
if !foundDomain {
// Not found.
return
}
if _, present := ts.records[domain]; !present {
// Not found.
return
}
ts.records[domain] = map[string][]string{}
}
func (ts *TestServer) RemoveAllRecords() {
ts.lock.Lock()
defer ts.lock.Unlock()
ts.records = map[string]map[string][]string{}
}
func (ts *TestServer) Cleanup() {
if ts.cleanup != nil {
ts.cleanup()
}
if ts.startup != nil && ts.startup.Cleanup != nil {
ts.startup.Cleanup()
}
}