diff --git a/core/dnsserver/config_test.go b/core/dnsserver/config_test.go new file mode 100644 index 000000000..a33545565 --- /dev/null +++ b/core/dnsserver/config_test.go @@ -0,0 +1,67 @@ +package dnsserver + +import ( + "testing" + + "github.com/coredns/caddy" +) + +func TestKeyForConfig(t *testing.T) { + tests := []struct { + name string + blockIndex int + blockKeyIndex int + expected string + }{ + {"zero_indices", 0, 0, "0:0"}, + {"positive_indices", 1, 2, "1:2"}, + {"larger_indices", 10, 5, "10:5"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := keyForConfig(tc.blockIndex, tc.blockKeyIndex) + if result != tc.expected { + t.Errorf("Expected %s, got %s for blockIndex %d and blockKeyIndex %d", + tc.expected, result, tc.blockIndex, tc.blockKeyIndex) + } + }) + } +} + +func TestGetConfig(t *testing.T) { + controller := caddy.NewTestController("dns", "") + initialCtx := controller.Context() + dnsCtx, ok := initialCtx.(*dnsContext) + if !ok { + t.Fatalf("controller.Context() did not return a *dnsContext, got %T", initialCtx) + } + if dnsCtx.keysToConfigs == nil { + t.Fatal("dnsCtx.keysToConfigs is nil; it should have been initialized by newContext") + } + + t.Run("returns and saves default config when config missing", func(t *testing.T) { + controller.ServerBlockIndex = 0 + controller.ServerBlockKeyIndex = 0 + key := keyForConfig(controller.ServerBlockIndex, controller.ServerBlockKeyIndex) + + // Ensure config doesn't exist initially for this specific key + delete(dnsCtx.keysToConfigs, key) + + cfg := GetConfig(controller) + if cfg == nil { + t.Fatal("GetConfig returned nil (should create and return a default)") + } + if len(cfg.ListenHosts) != 1 || cfg.ListenHosts[0] != "" { + t.Errorf("Expected default ListenHosts [\"\"] for auto-created config, got %v", cfg.ListenHosts) + } + + savedCfg, found := dnsCtx.keysToConfigs[key] + if !found { + t.Fatal("fallback did not save the default config into the context") + } + if savedCfg != cfg { + t.Fatal("config is not the same instance as the one saved in the context") + } + }) +} diff --git a/core/dnsserver/https_test.go b/core/dnsserver/https_test.go index 00ed366d7..3d50cda64 100644 --- a/core/dnsserver/https_test.go +++ b/core/dnsserver/https_test.go @@ -88,3 +88,72 @@ func TestDoHWriter_Request(t *testing.T) { }) } } + +func TestDoHWriter_Write(t *testing.T) { + tests := []struct { + name string + input []byte + wantErr bool + }{ + { + name: "valid DNS message", + // A minimal valid DNS query message + input: []byte{ + 0x00, 0x01, /* ID */ + 0x01, 0x00, /* Flags: query, recursion desired */ + 0x00, 0x01, /* Questions: 1 */ + 0x00, 0x00, /* Answer RRs: 0 */ + 0x00, 0x00, /* Authority RRs: 0 */ + 0x00, 0x00, /* Additional RRs: 0 */ + 0x03, 'w', 'w', 'w', + 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00, /* Null terminator for domain name */ + 0x00, 0x01, /* Type: A */ + 0x00, 0x01, /* Class: IN */ + }, + wantErr: false, + }, + { + name: "empty message", + input: []byte{}, + wantErr: true, // Expect an error because unpacking an empty message will fail + }, + { + name: "invalid DNS message", + input: []byte{0x00, 0x01, 0x02}, // Truncated message + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &DoHWriter{} + n, err := d.Write(tt.input) + + if (err != nil) != tt.wantErr { + t.Errorf("Write() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && n != len(tt.input) { + t.Errorf("Write() bytes written = %v, want %v", n, len(tt.input)) + } + if !tt.wantErr && d.Msg == nil { + t.Errorf("Write() d.Msg is nil, expected a parsed message") + } + }) + } +} + +func TestDoHWriter_Close(t *testing.T) { + d := &DoHWriter{} + if err := d.Close(); err != nil { + t.Errorf("Close() error = %v, want nil", err) + } +} + +func TestDoHWriter_TsigStatus(t *testing.T) { + d := &DoHWriter{} + if err := d.TsigStatus(); err != nil { + t.Errorf("TsigStatus() error = %v, want nil", err) + } +} diff --git a/core/dnsserver/onstartup_test.go b/core/dnsserver/onstartup_test.go index 41d4d82f8..4af376e01 100644 --- a/core/dnsserver/onstartup_test.go +++ b/core/dnsserver/onstartup_test.go @@ -37,3 +37,110 @@ func TestRegex1035PrefSyntax(t *testing.T) { } } } + +func TestStartUpZones(t *testing.T) { + tests := []struct { + name string + protocol string + addr string + zones map[string][]*Config + expectedOutput string + }{ + { + name: "no zones", + protocol: "dns://", + addr: "127.0.0.1:53", + zones: map[string][]*Config{}, + expectedOutput: "", + }, + { + name: "single zone valid syntax ip and port", + protocol: "dns://", + addr: "127.0.0.1:53", + zones: map[string][]*Config{"example.com.": nil}, + expectedOutput: "dns://example.com.:53 on 127.0.0.1\n", + }, + { + name: "single zone valid syntax port only", + protocol: "http://", + addr: ":8080", + zones: map[string][]*Config{"example.org.": nil}, + expectedOutput: "http://example.org.:8080\n", + }, + { + name: "single zone invalid syntax", + protocol: "tls://", + addr: "10.0.0.1:853", + zones: map[string][]*Config{"invalid-zone": nil}, + expectedOutput: "Warning: Domain \"invalid-zone\" does not follow RFC1035 preferred syntax\n" + + "tls://invalid-zone:853 on 10.0.0.1\n", + }, + { + name: "multiple zones sorted order", + protocol: "dns://", + addr: "localhost:5353", + zones: map[string][]*Config{ + "c-zone.com.": nil, + "a-zone.org.": nil, + "b-zone.net.": nil, + }, + expectedOutput: "dns://a-zone.org.:5353 on localhost\n" + + "dns://b-zone.net.:5353 on localhost\n" + + "dns://c-zone.com.:5353 on localhost\n", + }, + { + name: "addr parse error", + protocol: "grpc://", + addr: "[::1]:8080:extra", // Malformed, should cause SplitProtocolHostPort to error + zones: map[string][]*Config{"error.example.": nil}, + expectedOutput: "grpc://error.example.:[::1]:8080:extra\n", + }, + { + name: "root zone", + protocol: "dns://", + addr: "192.168.1.1:53", + zones: map[string][]*Config{".": nil}, + expectedOutput: "dns://.:53 on 192.168.1.1\n", + }, + { + name: "reverse zone", + protocol: "dns://", + addr: ":53", + zones: map[string][]*Config{"1.0.168.192.in-addr.arpa.": nil}, + expectedOutput: "dns://1.0.168.192.in-addr.arpa.:53\n", + }, + { + name: "multiple zones mixed syntax and addr handling", + protocol: "quic://", + addr: "coolserver.local:784", + zones: map[string][]*Config{ + "valid.net.": nil, + "_tcp.service.": nil, // Invalid syntax + "another.valid.com.": nil, + }, + expectedOutput: "Warning: Domain \"_tcp.service.\" does not follow RFC1035 preferred syntax\n" + + "quic://_tcp.service.:784 on coolserver.local\n" + + "quic://another.valid.com.:784 on coolserver.local\n" + + "quic://valid.net.:784 on coolserver.local\n", + }, + { + name: "zone with leading dash invalid", + protocol: "dns://", + addr: "127.0.0.1:53", + zones: map[string][]*Config{"-leadingdash.com.": nil}, + expectedOutput: "Warning: Domain \"-leadingdash.com.\" does not follow RFC1035 preferred syntax\n" + + "dns://-leadingdash.com.:53 on 127.0.0.1\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := startUpZones(tc.protocol, tc.addr, tc.zones) + if got != tc.expectedOutput { + // Use %q for expected and got to make differences in whitespace/newlines visible. + t.Errorf("startUpZones(%q, %q, ...) mismatch for test '%s':\nGot:\n%q\nExpected:\n%q", + tc.protocol, tc.addr, tc.name, got, tc.expectedOutput) + } + }) + } +} diff --git a/core/dnsserver/quic_test.go b/core/dnsserver/quic_test.go index 98d658a9a..4a4f408cd 100644 --- a/core/dnsserver/quic_test.go +++ b/core/dnsserver/quic_test.go @@ -1,7 +1,15 @@ package dnsserver import ( + "bytes" + "context" + "errors" + "net" "testing" + "time" + + "github.com/miekg/dns" + "github.com/quic-go/quic-go" ) func TestDoQWriterAddPrefix(t *testing.T) { @@ -18,3 +26,239 @@ func TestDoQWriterAddPrefix(t *testing.T) { t.Errorf("Expected prefixed size to be 3, got: %d", size) } } + +func TestDoQWriter_ResponseWriterMethods(t *testing.T) { + localAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + remoteAddr := &net.UDPAddr{IP: net.ParseIP("8.8.8.8"), Port: 53} + + writer := &DoQWriter{ + localAddr: localAddr, + remoteAddr: remoteAddr, + } + + if err := writer.TsigStatus(); err != nil { + t.Errorf("TsigStatus() returned an error: %v", err) + } + + // this is a no-op, just call it + writer.TsigTimersOnly(true) + writer.TsigTimersOnly(false) + + // this is a no-op, just call it + writer.Hijack() + + if addr := writer.LocalAddr(); addr != localAddr { + t.Errorf("LocalAddr() = %v, want %v", addr, localAddr) + } + + if addr := writer.RemoteAddr(); addr != remoteAddr { + t.Errorf("RemoteAddr() = %v, want %v", addr, remoteAddr) + } +} + +// mockQuicStream is a mock implementation of quic.Stream for testing. +type mockQuicStream struct { + writer func(p []byte) (n int, err error) + closer func() error + closed bool + data []byte +} + +func (m *mockQuicStream) Write(p []byte) (n int, err error) { + m.data = append(m.data, p...) + if m.writer != nil { + return m.writer(p) + } + return len(p), nil +} + +func (m *mockQuicStream) Close() error { + m.closed = true + if m.closer != nil { + return m.closer() + } + return nil +} + +// Required by quic.Stream interface, but not used in these tests +func (m *mockQuicStream) Read(p []byte) (n int, err error) { return 0, nil } +func (m *mockQuicStream) CancelRead(code quic.StreamErrorCode) {} +func (m *mockQuicStream) CancelWrite(code quic.StreamErrorCode) {} +func (m *mockQuicStream) SetReadDeadline(t time.Time) error { return nil } +func (m *mockQuicStream) SetWriteDeadline(t time.Time) error { return nil } +func (m *mockQuicStream) SetDeadline(t time.Time) error { return nil } +func (m *mockQuicStream) StreamID() quic.StreamID { return 0 } +func (m *mockQuicStream) Context() context.Context { return nil } + +func TestDoQWriter_Write(t *testing.T) { + tests := []struct { + name string + input []byte + streamWriter func(p []byte) (n int, err error) + expectErr bool + expectedData []byte + expectedN int + }{ + { + name: "successful write", + input: []byte{0x1, 0x2, 0x3}, + streamWriter: func(p []byte) (n int, err error) { + return len(p), nil + }, + expectErr: false, + expectedData: []byte{0x0, 0x3, 0x1, 0x2, 0x3}, // 3-byte length prefix + expectedN: 5, + }, + { + name: "stream write error", + input: []byte{0x4, 0x5}, + streamWriter: func(p []byte) (n int, err error) { + return 0, errors.New("stream error") + }, + expectErr: true, + expectedData: []byte{0x0, 0x2, 0x4, 0x5}, // 2-byte length prefix + expectedN: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockStream := &mockQuicStream{writer: tt.streamWriter} + writer := &DoQWriter{stream: mockStream} + + n, err := writer.Write(tt.input) + + if (err != nil) != tt.expectErr { + t.Errorf("Write() error = %v, expectErr %v", err, tt.expectErr) + return + } + if n != tt.expectedN { + t.Errorf("Write() n = %v, want %v", n, tt.expectedN) + } + + if !bytes.Equal(mockStream.data, tt.expectedData) { + t.Errorf("Write() data written to stream = %X, want %X", mockStream.data, tt.expectedData) + } + }) + } +} + +func TestDoQWriter_WriteMsg(t *testing.T) { + newMsg := func() *dns.Msg { + m := new(dns.Msg) + m.SetQuestion("example.com.", dns.TypeA) + return m + } + + tests := []struct { + name string + msg *dns.Msg + mockStream *mockQuicStream + expectErr bool + expectClosed bool + expectedData []byte // Expected data written to stream (packed msg with prefix) + packErr bool // Simulate error during msg.Pack() + }{ + { + name: "successful write and close", + msg: newMsg(), + mockStream: &mockQuicStream{}, + expectErr: false, + expectClosed: true, + }, + { + name: "msg.Pack() error", + msg: new(dns.Msg), + mockStream: &mockQuicStream{}, + expectErr: true, + packErr: true, // We'll make msg.Pack() fail by corrupting the msg or using a mock + expectClosed: false, // Close should not be called if Pack fails + }, + { + name: "stream write error", + msg: newMsg(), + mockStream: &mockQuicStream{ + writer: func(p []byte) (n int, err error) { + return 0, errors.New("stream write failed") + }, + }, + expectErr: true, + expectClosed: false, // Close should not be called if Write fails + }, + { + name: "stream close error", + msg: newMsg(), + mockStream: &mockQuicStream{ + closer: func() error { + return errors.New("stream close failed") + }, + }, + expectErr: true, + expectClosed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.packErr { + // Intentionally make the message invalid to cause a pack error. + // Invalid Rcode to ensure Pack fails. + tt.msg.Rcode = 1337 + } + + writer := &DoQWriter{stream: tt.mockStream, Msg: tt.msg} + err := writer.WriteMsg(tt.msg) + + if (err != nil) != tt.expectErr { + t.Errorf("WriteMsg() error = %v, expectErr %v", err, tt.expectErr) + } + + if tt.mockStream.closed != tt.expectClosed { + t.Errorf("WriteMsg() stream closed = %v, want %v", tt.mockStream.closed, tt.expectClosed) + } + + if tt.packErr { + if len(tt.mockStream.data) != 0 { + t.Errorf("WriteMsg() data written to stream on pack error = %X, want empty", tt.mockStream.data) + } + } + }) + } +} + +func TestDoQWriter_Close(t *testing.T) { + tests := []struct { + name string + mockStream *mockQuicStream + expectErr bool + }{ + { + name: "successful close", + mockStream: &mockQuicStream{}, + expectErr: false, + }, + { + name: "stream close error", + mockStream: &mockQuicStream{ + closer: func() error { + return errors.New("stream close error") + }, + }, + expectErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + writer := &DoQWriter{stream: tt.mockStream} + err := writer.Close() + + if (err != nil) != tt.expectErr { + t.Errorf("Close() error = %v, expectErr %v", err, tt.expectErr) + } + if !tt.mockStream.closed { + t.Errorf("Close() stream not marked as closed") + } + }) + } +} diff --git a/core/dnsserver/register_test.go b/core/dnsserver/register_test.go index 9a4c7263f..b8d594f80 100644 --- a/core/dnsserver/register_test.go +++ b/core/dnsserver/register_test.go @@ -2,6 +2,8 @@ package dnsserver import ( "testing" + + "github.com/coredns/caddy/caddyfile" ) func TestHandler(t *testing.T) { @@ -118,3 +120,215 @@ func TestGroupingServers(t *testing.T) { } } } + +func TestInspectServerBlocks(t *testing.T) { + tests := []struct { + name string + serverBlocks []caddyfile.ServerBlock + expectedServerBlocks []caddyfile.ServerBlock + expectedConfigsLen int + expectedZoneAddrs map[string]zoneAddr + wantErr bool + }{ + { + name: "simple dns", + serverBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"example.org"}}, + }, + expectedServerBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"dns://example.org.:53"}}, + }, + expectedConfigsLen: 1, + expectedZoneAddrs: map[string]zoneAddr{ + "dns://example.org.:53": {Zone: "example.org.", Port: "53", Transport: "dns"}, + }, + }, + { + name: "dns with port", + serverBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"example.org:1053"}}, + }, + expectedServerBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"dns://example.org.:1053"}}, + }, + expectedConfigsLen: 1, + expectedZoneAddrs: map[string]zoneAddr{ + "dns://example.org.:1053": {Zone: "example.org.", Port: "1053", Transport: "dns"}, + }, + }, + { + name: "tls", + serverBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"tls://example.org"}}, + }, + expectedServerBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"tls://example.org.:853"}}, + }, + expectedConfigsLen: 1, + expectedZoneAddrs: map[string]zoneAddr{ + "tls://example.org.:853": {Zone: "example.org.", Port: "853", Transport: "tls"}, + }, + }, + { + name: "quic", + serverBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"quic://example.org"}}, + }, + expectedServerBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"quic://example.org.:853"}}, + }, + expectedConfigsLen: 1, + expectedZoneAddrs: map[string]zoneAddr{ + "quic://example.org.:853": {Zone: "example.org.", Port: "853", Transport: "quic"}, + }, + }, + { + name: "grpc", + serverBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"grpc://example.org"}}, + }, + expectedServerBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"grpc://example.org.:443"}}, + }, + expectedConfigsLen: 1, + expectedZoneAddrs: map[string]zoneAddr{ + "grpc://example.org.:443": {Zone: "example.org.", Port: "443", Transport: "grpc"}, + }, + }, + { + name: "https", + serverBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"https://example.org."}}, + }, + expectedServerBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"https://example.org.:443"}}, + }, + expectedConfigsLen: 1, + expectedZoneAddrs: map[string]zoneAddr{ + "https://example.org.:443": {Zone: "example.org.", Port: "443", Transport: "https"}, + }, + }, + { + name: "multiple hosts same key", + serverBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"example.org,example.com:1053"}}, + }, + expectedServerBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"dns://example.org,example.com.:1053"}}, + }, + expectedConfigsLen: 1, + expectedZoneAddrs: map[string]zoneAddr{ + "dns://example.org,example.com.:1053": {Zone: "example.org,example.com.", Port: "1053", Transport: "dns"}, + }, + }, + { + name: "multiple keys", + serverBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"example.org", "example.com:1053"}}, + }, + expectedServerBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"dns://example.org.:53", "dns://example.com.:1053"}}, + }, + expectedConfigsLen: 2, + expectedZoneAddrs: map[string]zoneAddr{ + "dns://example.org.:53": {Zone: "example.org.", Port: "53", Transport: "dns"}, + "dns://example.com.:1053": {Zone: "example.com.", Port: "1053", Transport: "dns"}, + }, + }, + { + name: "fqdn input", + serverBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"example.org."}}, + }, + expectedServerBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"dns://example.org.:53"}}, + }, + expectedConfigsLen: 1, + expectedZoneAddrs: map[string]zoneAddr{ + "dns://example.org.:53": {Zone: "example.org.", Port: "53", Transport: "dns"}, + }, + }, + { + name: "multiple server blocks", + serverBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"example.org"}}, + {Keys: []string{"sub.example.org:1054"}}, + }, + expectedServerBlocks: []caddyfile.ServerBlock{ + {Keys: []string{"dns://example.org.:53"}}, + {Keys: []string{"dns://sub.example.org.:1054"}}, + }, + expectedConfigsLen: 2, + expectedZoneAddrs: map[string]zoneAddr{ + "dns://example.org.:53": {Zone: "example.org.", Port: "53", Transport: "dns"}, + "dns://sub.example.org.:1054": {Zone: "sub.example.org.", Port: "1054", Transport: "dns"}, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ctx := newContext(nil).(*dnsContext) + processedBlocks, err := ctx.InspectServerBlocks("TestInspectServerBlocks", tc.serverBlocks) + + if (err != nil) != tc.wantErr { + t.Fatalf("InspectServerBlocks() error = %v, wantErr %v", err, tc.wantErr) + } + if tc.wantErr { + return + } + + if len(processedBlocks) != len(tc.expectedServerBlocks) { + t.Fatalf("Expected %d processed blocks, got %d", len(tc.expectedServerBlocks), len(processedBlocks)) + } + + for i, block := range processedBlocks { + expectedBlock := tc.expectedServerBlocks[i] + if len(block.Keys) != len(expectedBlock.Keys) { + t.Errorf("Block %d: expected %d keys, got %d. Expected: %v, Got: %v", i, len(expectedBlock.Keys), len(block.Keys), expectedBlock.Keys, block.Keys) + continue + } + for j, key := range block.Keys { + if key != expectedBlock.Keys[j] { + t.Errorf("Block %d, Key %d: expected key '%s', got '%s'", i, j, expectedBlock.Keys[j], key) + } + } + } + + if len(ctx.configs) != tc.expectedConfigsLen { + t.Errorf("Expected %d configs to be created, got %d", tc.expectedConfigsLen, len(ctx.configs)) + } + + if tc.expectedZoneAddrs != nil { + configIndex := 0 + for ib := range processedBlocks { + for ik, key := range processedBlocks[ib].Keys { + if configIndex >= len(ctx.configs) { + t.Fatalf("Not enough configs stored, expected at least %d, processed block %d key %d", configIndex+1, ib, ik) + } + cfg := ctx.configs[configIndex] + expectedZa, ok := tc.expectedZoneAddrs[key] + if !ok { + t.Errorf("No expected zoneAddr for processed key '%s'", key) + continue + } + + if cfg.Zone != expectedZa.Zone { + t.Errorf("Config for key '%s': expected Zone '%s', got '%s'", key, expectedZa.Zone, cfg.Zone) + } + if cfg.Port != expectedZa.Port { + t.Errorf("Config for key '%s': expected Port '%s', got '%s'", key, expectedZa.Port, cfg.Port) + } + if cfg.Transport != expectedZa.Transport { + t.Errorf("Config for key '%s': expected Transport '%s', got '%s'", key, expectedZa.Transport, cfg.Transport) + } + if len(cfg.ListenHosts) != 1 || cfg.ListenHosts[0] != "" { + t.Errorf("Config for key '%s': expected ListenHosts [''], got %v", key, cfg.ListenHosts) + } + configIndex++ + } + } + } + }) + } +}