mirror of
https://github.com/tailscale/tailscale.git
synced 2025-12-16 14:52:18 +01:00
net/dns/resolver: fix data race in test
Fixes #17339 Change-Id: I486d2a0e0931d701923c1e0f8efbda99510ab19b Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
1aaa1648c4
commit
bdb69d1b1f
@ -217,11 +217,12 @@ type resolverAndDelay struct {
|
|||||||
|
|
||||||
// forwarder forwards DNS packets to a number of upstream nameservers.
|
// forwarder forwards DNS packets to a number of upstream nameservers.
|
||||||
type forwarder struct {
|
type forwarder struct {
|
||||||
logf logger.Logf
|
logf logger.Logf
|
||||||
netMon *netmon.Monitor // always non-nil
|
netMon *netmon.Monitor // always non-nil
|
||||||
linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it
|
linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it
|
||||||
dialer *tsdial.Dialer
|
dialer *tsdial.Dialer
|
||||||
health *health.Tracker // always non-nil
|
health *health.Tracker // always non-nil
|
||||||
|
verboseFwd bool // if true, log all DNS forwarding
|
||||||
|
|
||||||
controlKnobs *controlknobs.Knobs // or nil
|
controlKnobs *controlknobs.Knobs // or nil
|
||||||
|
|
||||||
@ -258,6 +259,7 @@ func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkS
|
|||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
health: health,
|
health: health,
|
||||||
controlKnobs: knobs,
|
controlKnobs: knobs,
|
||||||
|
verboseFwd: verboseDNSForward(),
|
||||||
}
|
}
|
||||||
f.ctx, f.ctxCancel = context.WithCancel(context.Background())
|
f.ctx, f.ctxCancel = context.WithCancel(context.Background())
|
||||||
return f
|
return f
|
||||||
@ -515,7 +517,7 @@ var (
|
|||||||
//
|
//
|
||||||
// send expects the reply to have the same txid as txidOut.
|
// send expects the reply to have the same txid as txidOut.
|
||||||
func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
|
func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) {
|
||||||
if verboseDNSForward() {
|
if f.verboseFwd {
|
||||||
id := forwarderCount.Add(1)
|
id := forwarderCount.Add(1)
|
||||||
domain, typ, _ := nameFromQuery(fq.packet)
|
domain, typ, _ := nameFromQuery(fq.packet)
|
||||||
f.logf("forwarder.send(%q, %d, %v, %d) [%d] ...", rr.name.Addr, fq.txid, typ, len(domain), id)
|
f.logf("forwarder.send(%q, %d, %v, %d) [%d] ...", rr.name.Addr, fq.txid, typ, len(domain), id)
|
||||||
@ -978,7 +980,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
|
|||||||
}
|
}
|
||||||
defer fq.closeOnCtxDone.Close()
|
defer fq.closeOnCtxDone.Close()
|
||||||
|
|
||||||
if verboseDNSForward() {
|
if f.verboseFwd {
|
||||||
domainSha256 := sha256.Sum256([]byte(domain))
|
domainSha256 := sha256.Sum256([]byte(domain))
|
||||||
domainSig := base64.RawStdEncoding.EncodeToString(domainSha256[:3])
|
domainSig := base64.RawStdEncoding.EncodeToString(domainSha256[:3])
|
||||||
f.logf("request(%d, %v, %d, %s) %d...", fq.txid, typ, len(domain), domainSig, len(fq.packet))
|
f.logf("request(%d, %v, %d, %s) %d...", fq.txid, typ, len(domain), domainSig, len(fq.packet))
|
||||||
@ -1023,7 +1025,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
|
|||||||
metricDNSFwdErrorContext.Add(1)
|
metricDNSFwdErrorContext.Add(1)
|
||||||
return fmt.Errorf("waiting to send response: %w", ctx.Err())
|
return fmt.Errorf("waiting to send response: %w", ctx.Err())
|
||||||
case responseChan <- packet{v, query.family, query.addr}:
|
case responseChan <- packet{v, query.family, query.addr}:
|
||||||
if verboseDNSForward() {
|
if f.verboseFwd {
|
||||||
f.logf("response(%d, %v, %d) = %d, nil", fq.txid, typ, len(domain), len(v))
|
f.logf("response(%d, %v, %d) = %d, nil", fq.txid, typ, len(domain), len(v))
|
||||||
}
|
}
|
||||||
metricDNSFwdSuccess.Add(1)
|
metricDNSFwdSuccess.Add(1)
|
||||||
@ -1053,7 +1055,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
|
|||||||
}
|
}
|
||||||
f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")})
|
f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")})
|
||||||
case responseChan <- res:
|
case responseChan <- res:
|
||||||
if verboseDNSForward() {
|
if f.verboseFwd {
|
||||||
f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr)
|
f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@ -12,7 +12,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
@ -23,7 +22,6 @@ import (
|
|||||||
|
|
||||||
dns "golang.org/x/net/dns/dnsmessage"
|
dns "golang.org/x/net/dns/dnsmessage"
|
||||||
"tailscale.com/control/controlknobs"
|
"tailscale.com/control/controlknobs"
|
||||||
"tailscale.com/envknob"
|
|
||||||
"tailscale.com/health"
|
"tailscale.com/health"
|
||||||
"tailscale.com/net/netmon"
|
"tailscale.com/net/netmon"
|
||||||
"tailscale.com/net/tsdial"
|
"tailscale.com/net/tsdial"
|
||||||
@ -400,13 +398,6 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func enableDebug(tb testing.TB) {
|
|
||||||
const debugKnob = "TS_DEBUG_DNS_FORWARD_SEND"
|
|
||||||
oldVal := os.Getenv(debugKnob)
|
|
||||||
envknob.Setenv(debugKnob, "true")
|
|
||||||
tb.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) })
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) {
|
func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) {
|
||||||
name := dns.MustNewName(domain)
|
name := dns.MustNewName(domain)
|
||||||
|
|
||||||
@ -554,9 +545,11 @@ func mustRunTestQuery(tb testing.TB, request []byte, modify func(*forwarder), po
|
|||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestForwarderTCPFallback(t *testing.T) {
|
func beVerbose(f *forwarder) {
|
||||||
enableDebug(t)
|
f.verboseFwd = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwarderTCPFallback(t *testing.T) {
|
||||||
const domain = "large-dns-response.tailscale.com."
|
const domain = "large-dns-response.tailscale.com."
|
||||||
|
|
||||||
// Make a response that's very large, containing a bunch of localhost addresses.
|
// Make a response that's very large, containing a bunch of localhost addresses.
|
||||||
@ -576,7 +569,7 @@ func TestForwarderTCPFallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := mustRunTestQuery(t, request, nil, port)
|
resp := mustRunTestQuery(t, request, beVerbose, port)
|
||||||
if !bytes.Equal(resp, largeResponse) {
|
if !bytes.Equal(resp, largeResponse) {
|
||||||
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
|
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
|
||||||
}
|
}
|
||||||
@ -592,8 +585,6 @@ func TestForwarderTCPFallback(t *testing.T) {
|
|||||||
// Test to ensure that if the UDP listener is unresponsive, we always make a
|
// Test to ensure that if the UDP listener is unresponsive, we always make a
|
||||||
// TCP request even if we never get a response.
|
// TCP request even if we never get a response.
|
||||||
func TestForwarderTCPFallbackTimeout(t *testing.T) {
|
func TestForwarderTCPFallbackTimeout(t *testing.T) {
|
||||||
enableDebug(t)
|
|
||||||
|
|
||||||
const domain = "large-dns-response.tailscale.com."
|
const domain = "large-dns-response.tailscale.com."
|
||||||
|
|
||||||
// Make a response that's very large, containing a bunch of localhost addresses.
|
// Make a response that's very large, containing a bunch of localhost addresses.
|
||||||
@ -614,7 +605,7 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := mustRunTestQuery(t, request, nil, port)
|
resp := mustRunTestQuery(t, request, beVerbose, port)
|
||||||
if !bytes.Equal(resp, largeResponse) {
|
if !bytes.Equal(resp, largeResponse) {
|
||||||
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
|
t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse)
|
||||||
}
|
}
|
||||||
@ -624,8 +615,6 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestForwarderTCPFallbackDisabled(t *testing.T) {
|
func TestForwarderTCPFallbackDisabled(t *testing.T) {
|
||||||
enableDebug(t)
|
|
||||||
|
|
||||||
const domain = "large-dns-response.tailscale.com."
|
const domain = "large-dns-response.tailscale.com."
|
||||||
|
|
||||||
// Make a response that's very large, containing a bunch of localhost addresses.
|
// Make a response that's very large, containing a bunch of localhost addresses.
|
||||||
@ -646,6 +635,7 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
resp := mustRunTestQuery(t, request, func(fwd *forwarder) {
|
resp := mustRunTestQuery(t, request, func(fwd *forwarder) {
|
||||||
|
fwd.verboseFwd = true
|
||||||
// Disable retries for this test.
|
// Disable retries for this test.
|
||||||
fwd.controlKnobs = &controlknobs.Knobs{}
|
fwd.controlKnobs = &controlknobs.Knobs{}
|
||||||
fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true)
|
fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true)
|
||||||
@ -668,8 +658,6 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) {
|
|||||||
|
|
||||||
// Test to ensure that we propagate DNS errors
|
// Test to ensure that we propagate DNS errors
|
||||||
func TestForwarderTCPFallbackError(t *testing.T) {
|
func TestForwarderTCPFallbackError(t *testing.T) {
|
||||||
enableDebug(t)
|
|
||||||
|
|
||||||
const domain = "error-response.tailscale.com."
|
const domain = "error-response.tailscale.com."
|
||||||
|
|
||||||
// Our response is a SERVFAIL
|
// Our response is a SERVFAIL
|
||||||
@ -686,7 +674,7 @@ func TestForwarderTCPFallbackError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
resp, err := runTestQuery(t, request, nil, port)
|
resp, err := runTestQuery(t, request, beVerbose, port)
|
||||||
if !sawRequest.Load() {
|
if !sawRequest.Load() {
|
||||||
t.Error("did not see DNS request")
|
t.Error("did not see DNS request")
|
||||||
}
|
}
|
||||||
@ -706,8 +694,6 @@ func TestForwarderTCPFallbackError(t *testing.T) {
|
|||||||
// Test to ensure that if we have more than one resolver, and at least one of them
|
// Test to ensure that if we have more than one resolver, and at least one of them
|
||||||
// returns a successful response, we propagate it.
|
// returns a successful response, we propagate it.
|
||||||
func TestForwarderWithManyResolvers(t *testing.T) {
|
func TestForwarderWithManyResolvers(t *testing.T) {
|
||||||
enableDebug(t)
|
|
||||||
|
|
||||||
const domain = "example.com."
|
const domain = "example.com."
|
||||||
request := makeTestRequest(t, domain)
|
request := makeTestRequest(t, domain)
|
||||||
|
|
||||||
@ -810,7 +796,7 @@ func TestForwarderWithManyResolvers(t *testing.T) {
|
|||||||
for i := range tt.responses {
|
for i := range tt.responses {
|
||||||
ports[i] = runDNSServer(t, nil, tt.responses[i], func(isTCP bool, gotRequest []byte) {})
|
ports[i] = runDNSServer(t, nil, tt.responses[i], func(isTCP bool, gotRequest []byte) {})
|
||||||
}
|
}
|
||||||
gotResponse, err := runTestQuery(t, request, nil, ports...)
|
gotResponse, err := runTestQuery(t, request, beVerbose, ports...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("wanted nil, got %v", err)
|
t.Fatalf("wanted nil, got %v", err)
|
||||||
}
|
}
|
||||||
@ -869,7 +855,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) {
|
|||||||
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
|
port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) {
|
||||||
})
|
})
|
||||||
|
|
||||||
res, err := runTestQuery(t, request, nil, port)
|
res, err := runTestQuery(t, request, beVerbose, port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user