mirror of
https://github.com/hashicorp/vault.git
synced 2025-08-17 12:07:02 +02:00
- The ACME tests were using ttl_duration and max_ttl_duration instead of ttl and max_ttl as input parameters to roles. - Add missing copyright headers
419 lines
11 KiB
Go
419 lines
11 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package dnstest
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/vault/helper/testhelpers/corehelpers"
|
|
"github.com/hashicorp/vault/sdk/helper/docker"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type TestServer struct {
|
|
t *testing.T
|
|
ctx context.Context
|
|
|
|
runner *docker.Runner
|
|
network string
|
|
startup *docker.Service
|
|
|
|
lock sync.Mutex
|
|
serial int
|
|
forwarders []string
|
|
domains []string
|
|
records map[string]map[string][]string // domain -> record -> value(s).
|
|
|
|
cleanup func()
|
|
}
|
|
|
|
func SetupResolver(t *testing.T, domain string) *TestServer {
|
|
return SetupResolverOnNetwork(t, domain, "")
|
|
}
|
|
|
|
func SetupResolverOnNetwork(t *testing.T, domain string, network string) *TestServer {
|
|
var ts TestServer
|
|
ts.t = t
|
|
ts.ctx = context.Background()
|
|
ts.domains = []string{domain}
|
|
ts.records = map[string]map[string][]string{}
|
|
ts.network = network
|
|
|
|
ts.setupRunner(domain, network)
|
|
ts.startContainer(network)
|
|
ts.PushConfig()
|
|
|
|
return &ts
|
|
}
|
|
|
|
func (ts *TestServer) setupRunner(domain string, network string) {
|
|
var err error
|
|
ts.runner, err = docker.NewServiceRunner(docker.RunOptions{
|
|
ImageRepo: "ubuntu/bind9",
|
|
ImageTag: "latest",
|
|
ContainerName: "bind9-dns-" + strings.ReplaceAll(domain, ".", "-"),
|
|
NetworkName: network,
|
|
Ports: []string{"53/udp"},
|
|
LogConsumer: func(s string) {
|
|
ts.t.Logf(s)
|
|
},
|
|
})
|
|
require.NoError(ts.t, err)
|
|
}
|
|
|
|
func (ts *TestServer) startContainer(network string) {
|
|
connUpFunc := func(ctx context.Context, host string, port int) (docker.ServiceConfig, error) {
|
|
// Perform a simple connection to this resolver, even though the
|
|
// default configuration doesn't do anything useful.
|
|
peer, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", host, port))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to resolve peer: %v / %v: %w", host, port, err)
|
|
}
|
|
|
|
conn, err := net.DialUDP("udp", nil, peer)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to dial peer: %v / %v / %v: %w", host, port, peer, err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
_, err = conn.Write([]byte("garbage-in"))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to write to peer: %v / %v / %v: %w", host, port, peer, err)
|
|
}
|
|
|
|
// Connection worked.
|
|
return docker.NewServiceHostPort(host, port), nil
|
|
}
|
|
|
|
result, _, err := ts.runner.StartNewService(ts.ctx, true, true, connUpFunc)
|
|
require.NoError(ts.t, err, "failed to start dns resolver for "+ts.domains[0])
|
|
ts.startup = result
|
|
|
|
if ts.startup.StartResult.RealIP == "" {
|
|
mapping, err := ts.runner.GetNetworkAndAddresses(ts.startup.Container.ID)
|
|
require.NoError(ts.t, err, "failed to fetch network addresses to correct missing real IP address")
|
|
if len(network) == 0 {
|
|
require.Equal(ts.t, 1, len(mapping), "expected exactly one network address")
|
|
for network = range mapping {
|
|
// Because mapping is a map of network name->ip, we need
|
|
// to use the above range's assignment to get the name,
|
|
// as there is no other way of getting the keys of a map.
|
|
}
|
|
}
|
|
require.Contains(ts.t, mapping, network, "expected network to be part of the mapping")
|
|
ts.startup.StartResult.RealIP = mapping[network]
|
|
}
|
|
|
|
ts.t.Logf("[dnsserv] Addresses of DNS resolver: local=%v / container=%v", ts.GetLocalAddr(), ts.GetRemoteAddr())
|
|
}
|
|
|
|
func (ts *TestServer) buildNamedConf() string {
|
|
forwarders := "\n"
|
|
if len(ts.forwarders) > 0 {
|
|
forwarders = "\tforwarders {\n"
|
|
for _, forwarder := range ts.forwarders {
|
|
forwarders += "\t\t" + forwarder + ";\n"
|
|
}
|
|
forwarders += "\t};\n"
|
|
}
|
|
|
|
zones := "\n"
|
|
for _, domain := range ts.domains {
|
|
zones += fmt.Sprintf("zone \"%s\" {\n", domain)
|
|
zones += "\ttype primary;\n"
|
|
zones += fmt.Sprintf("\tfile \"%s.zone\";\n", domain)
|
|
zones += "\tallow-update {\n\t\tnone;\n\t};\n"
|
|
zones += "\tnotify no;\n"
|
|
zones += "};\n\n"
|
|
}
|
|
|
|
// Reverse lookups are not handles as they're not presently necessary.
|
|
|
|
cfg := `options {
|
|
directory "/var/cache/bind";
|
|
|
|
dnssec-validation no;
|
|
|
|
` + forwarders + `
|
|
};
|
|
|
|
` + zones
|
|
|
|
return cfg
|
|
}
|
|
|
|
func (ts *TestServer) buildZoneFile(target string) string {
|
|
// One second TTL by default to allow quick refreshes.
|
|
zone := "$TTL 1;\n"
|
|
|
|
ts.serial += 1
|
|
zone += fmt.Sprintf("@\tIN\tSOA\tns.%v.\troot.%v.\t(\n", target, target)
|
|
zone += fmt.Sprintf("\t\t\t%d;\n\t\t\t1;\n\t\t\t1;\n\t\t\t2;\n\t\t\t1;\n\t\t\t)\n\n", ts.serial)
|
|
zone += fmt.Sprintf("@\tIN\tNS\tns%d.%v.\n", ts.serial, target)
|
|
zone += fmt.Sprintf("ns%d.%v.\tIN\tA\t%v\n", ts.serial, target, "127.0.0.1")
|
|
|
|
for domain, records := range ts.records {
|
|
if !strings.HasSuffix(domain, target) {
|
|
continue
|
|
}
|
|
|
|
for recordType, values := range records {
|
|
for _, value := range values {
|
|
zone += fmt.Sprintf("%s.\tIN\t%s\t%s\n", domain, recordType, value)
|
|
}
|
|
}
|
|
}
|
|
|
|
return zone
|
|
}
|
|
|
|
func (ts *TestServer) pushNamedConf() {
|
|
contents := docker.NewBuildContext()
|
|
cfgPath := "/etc/bind/named.conf.options"
|
|
namedCfg := ts.buildNamedConf()
|
|
contents[cfgPath] = docker.PathContentsFromString(namedCfg)
|
|
contents[cfgPath].SetOwners(0, 142) // root, bind
|
|
|
|
ts.t.Logf("Generated bind9 config (%s):\n%v\n", cfgPath, namedCfg)
|
|
|
|
err := ts.runner.CopyTo(ts.startup.Container.ID, "/", contents)
|
|
require.NoError(ts.t, err, "failed pushing updated named.conf.options to container")
|
|
}
|
|
|
|
func (ts *TestServer) pushZoneFiles() {
|
|
contents := docker.NewBuildContext()
|
|
|
|
for _, domain := range ts.domains {
|
|
path := "/var/cache/bind/" + domain + ".zone"
|
|
zoneFile := ts.buildZoneFile(domain)
|
|
contents[path] = docker.PathContentsFromString(zoneFile)
|
|
contents[path].SetOwners(0, 142) // root, bind
|
|
|
|
ts.t.Logf("Generated bind9 zone file for %v (%s):\n%v\n", domain, path, zoneFile)
|
|
}
|
|
|
|
err := ts.runner.CopyTo(ts.startup.Container.ID, "/", contents)
|
|
require.NoError(ts.t, err, "failed pushing updated named.conf.options to container")
|
|
}
|
|
|
|
func (ts *TestServer) PushConfig() {
|
|
ts.lock.Lock()
|
|
defer ts.lock.Unlock()
|
|
|
|
// There's two cases here:
|
|
//
|
|
// 1. We've added a new top-level domain name. Here, we want to make
|
|
// sure the new zone file is pushed before we push the reference
|
|
// to it.
|
|
// 2. We've just added a new. Here, the order doesn't matter, but
|
|
// mostly likely the second push will be a no-op.
|
|
ts.pushZoneFiles()
|
|
ts.pushNamedConf()
|
|
|
|
// Wait until our config has taken.
|
|
corehelpers.RetryUntil(ts.t, 15*time.Second, func() error {
|
|
// bind reloads based on file mtime, touch files before starting
|
|
// to make sure it has been updated more recently than when the
|
|
// last update was written. Then issue a new SIGHUP.
|
|
for _, domain := range ts.domains {
|
|
path := "/var/cache/bind/" + domain + ".zone"
|
|
touchCmd := []string{"touch", path}
|
|
|
|
_, _, _, err := ts.runner.RunCmdWithOutput(ts.ctx, ts.startup.Container.ID, touchCmd)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update zone mtime: %w", err)
|
|
}
|
|
}
|
|
ts.runner.DockerAPI.ContainerKill(ts.ctx, ts.startup.Container.ID, "SIGHUP")
|
|
|
|
// Connect to our bind resolver.
|
|
resolver := &net.Resolver{
|
|
PreferGo: true,
|
|
StrictErrors: false,
|
|
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
d := net.Dialer{
|
|
Timeout: 10 * time.Second,
|
|
}
|
|
return d.DialContext(ctx, network, ts.GetLocalAddr())
|
|
},
|
|
}
|
|
|
|
// last domain has the given serial number, which also appears in the
|
|
// NS record so we can fetch it via Go.
|
|
lastDomain := ts.domains[len(ts.domains)-1]
|
|
records, err := resolver.LookupNS(ts.ctx, lastDomain)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to lookup NS record for %v: %w", lastDomain, err)
|
|
}
|
|
|
|
if len(records) != 1 {
|
|
return fmt.Errorf("expected only 1 NS record for %v, got %v/%v", lastDomain, len(records), records)
|
|
}
|
|
|
|
expectedNS := fmt.Sprintf("ns%d.%v.", ts.serial, lastDomain)
|
|
if records[0].Host != expectedNS {
|
|
return fmt.Errorf("expected to find NS %v, got %v indicating reload hadn't completed", expectedNS, records[0])
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (ts *TestServer) GetLocalAddr() string {
|
|
return ts.startup.Config.Address()
|
|
}
|
|
|
|
func (ts *TestServer) GetRemoteAddr() string {
|
|
return fmt.Sprintf("%s:%d", ts.startup.StartResult.RealIP, 53)
|
|
}
|
|
|
|
func (ts *TestServer) AddDomain(domain string) {
|
|
ts.lock.Lock()
|
|
defer ts.lock.Unlock()
|
|
|
|
for _, existing := range ts.domains {
|
|
if existing == domain {
|
|
return
|
|
}
|
|
}
|
|
|
|
ts.domains = append(ts.domains, domain)
|
|
}
|
|
|
|
func (ts *TestServer) AddRecord(domain string, record string, value string) {
|
|
ts.lock.Lock()
|
|
defer ts.lock.Unlock()
|
|
|
|
foundDomain := false
|
|
for _, existing := range ts.domains {
|
|
if strings.HasSuffix(domain, existing) {
|
|
foundDomain = true
|
|
break
|
|
}
|
|
}
|
|
if !foundDomain {
|
|
ts.t.Fatalf("cannot add record %v/%v :: [%v] -- no domain zone matching (%v)", record, domain, value, ts.domains)
|
|
}
|
|
|
|
value = strings.TrimSpace(value)
|
|
if _, present := ts.records[domain]; !present {
|
|
ts.records[domain] = map[string][]string{}
|
|
}
|
|
|
|
if values, present := ts.records[domain][record]; present {
|
|
for _, candidate := range values {
|
|
if candidate == value {
|
|
// Already present; skip adding.
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
ts.records[domain][record] = append(ts.records[domain][record], value)
|
|
}
|
|
|
|
func (ts *TestServer) RemoveRecord(domain string, record string, value string) {
|
|
ts.lock.Lock()
|
|
defer ts.lock.Unlock()
|
|
|
|
foundDomain := false
|
|
for _, existing := range ts.domains {
|
|
if strings.HasSuffix(domain, existing) {
|
|
foundDomain = true
|
|
break
|
|
}
|
|
}
|
|
if !foundDomain {
|
|
// Not found.
|
|
return
|
|
}
|
|
|
|
value = strings.TrimSpace(value)
|
|
if _, present := ts.records[domain]; !present {
|
|
// Not found.
|
|
return
|
|
}
|
|
|
|
var remaining []string
|
|
if values, present := ts.records[domain][record]; present {
|
|
for _, candidate := range values {
|
|
if candidate != value {
|
|
remaining = append(remaining, candidate)
|
|
}
|
|
}
|
|
}
|
|
|
|
ts.records[domain][record] = remaining
|
|
}
|
|
|
|
func (ts *TestServer) RemoveRecordsOfTypeForDomain(domain string, record string) {
|
|
ts.lock.Lock()
|
|
defer ts.lock.Unlock()
|
|
|
|
foundDomain := false
|
|
for _, existing := range ts.domains {
|
|
if strings.HasSuffix(domain, existing) {
|
|
foundDomain = true
|
|
break
|
|
}
|
|
}
|
|
if !foundDomain {
|
|
// Not found.
|
|
return
|
|
}
|
|
|
|
if _, present := ts.records[domain]; !present {
|
|
// Not found.
|
|
return
|
|
}
|
|
|
|
delete(ts.records[domain], record)
|
|
}
|
|
|
|
func (ts *TestServer) RemoveRecordsForDomain(domain string) {
|
|
ts.lock.Lock()
|
|
defer ts.lock.Unlock()
|
|
|
|
foundDomain := false
|
|
for _, existing := range ts.domains {
|
|
if strings.HasSuffix(domain, existing) {
|
|
foundDomain = true
|
|
break
|
|
}
|
|
}
|
|
if !foundDomain {
|
|
// Not found.
|
|
return
|
|
}
|
|
|
|
if _, present := ts.records[domain]; !present {
|
|
// Not found.
|
|
return
|
|
}
|
|
|
|
ts.records[domain] = map[string][]string{}
|
|
}
|
|
|
|
func (ts *TestServer) RemoveAllRecords() {
|
|
ts.lock.Lock()
|
|
defer ts.lock.Unlock()
|
|
|
|
ts.records = map[string]map[string][]string{}
|
|
}
|
|
|
|
func (ts *TestServer) Cleanup() {
|
|
if ts.cleanup != nil {
|
|
ts.cleanup()
|
|
}
|
|
if ts.startup != nil && ts.startup.Cleanup != nil {
|
|
ts.startup.Cleanup()
|
|
}
|
|
}
|