diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index ed7ff78f7..3f586b60f 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -13,6 +13,7 @@ import ( "errors" "fmt" "io" + "maps" "net" "net/http" "net/netip" @@ -41,8 +42,10 @@ import ( "tailscale.com/types/dnstype" "tailscale.com/types/logger" "tailscale.com/types/nettype" + "tailscale.com/types/views" "tailscale.com/util/cloudenv" "tailscale.com/util/dnsname" + "tailscale.com/util/mak" "tailscale.com/util/race" "tailscale.com/version" ) @@ -324,6 +327,19 @@ type forwarder struct { // resolver lookup. cloudHostFallback []resolverAndDelay + // schemes are the collection of registered URI scheme names that + // dynamically decide which resolver to use at the time of each query. The + // key is the scheme (the portion before the first `:`) and the value is a + // handler that determines where the current query should be sent. + // Use schemeCacheLocked() to get the current contents that can continue to + // be accessed once mu is released. This allows the (much more common) + // resolver code path to avoid repeated locking and unlocking. + // When modified, call invalidateSchemeCacheLocked() before unlocking mu. + schemes map[string]CustomSchemeHandler + // schemeCache is an immutable copy of schemes. Do not read directly, + // use schemeCacheLocked() which will regenerate its contents as needed. + schemeCache views.Map[string, CustomSchemeHandler] + // acceptDNS tracks the CorpDNS pref (--accept-dns) // This lets us skip health warnings if the forwarder receives inbound // queries directly - but we didn't configure it with any upstream resolvers. @@ -996,15 +1012,66 @@ func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAn return out, nil } +// applySchemes resolves any custom-scheme entries in rrs using the provided +// scheme handlers, returning the resulting slice. Entries whose handler returns +// an error or empty string are dropped. Entries with no registered scheme pass +// through unchanged. If schemes is nil, rrs is returned as-is. +func applySchemes(logf logger.Logf, rrs []resolverAndDelay, schemes views.Map[string, CustomSchemeHandler]) []resolverAndDelay { + if schemes.IsNil() { + return rrs + } + var result []resolverAndDelay + for i, rr := range rrs { + scheme, _, hasColon := strings.Cut(rr.name.Addr, ":") + handler, isCustom := schemes.GetOk(scheme) + if !hasColon || !isCustom { + if result != nil { + result = append(result, rr) + } + continue + } + // Avoid making a results slice in the common case where there + // are no custom scheme resolvers. + if result == nil { + result = make([]resolverAndDelay, i, len(rrs)) + copy(result, rrs) + } + newAddr, err := handler(rr.name.Addr) + if err != nil { + logf("error from custom scheme handler, skipping resolver : %v", err) + } + if err != nil || newAddr == "" { + continue + } + newResolver := *rr.name + newResolver.Addr = newAddr + result = append(result, resolverAndDelay{name: &newResolver, startDelay: rr.startDelay}) + } + // If we didn't have any custom schemes, return the original rrs. + if result == nil { + return rrs + } + return result +} + // resolvers returns the resolvers to use for domain. func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay { f.mu.Lock() routes := f.routes cloudHostFallback := f.cloudHostFallback + schemes := f.schemeCacheLocked() f.mu.Unlock() + for _, route := range routes { - if route.Suffix == "." || route.Suffix.Contains(domain) { - return route.Resolvers + if route.Suffix != "." && !route.Suffix.Contains(domain) { + continue + } + resolved := applySchemes(f.logf, route.Resolvers, schemes) + // If scheme resolution filtered out all resolvers from a non-empty + // route, fall through to the next matching route. If the resolvers + // were configured to be empty allow resolved to be empty. + if len(resolved) > 0 || len(route.Resolvers) == 0 { + return resolved } } return cloudHostFallback // or nil if no fallback @@ -1021,6 +1088,39 @@ func (f *forwarder) GetUpstreamResolvers(name dnsname.FQDN) []*dnstype.Resolver return upstreamResolvers } +// RegisterCustomScheme adds a [CustomSchemeHandler] that is called to provide +// an updated address when a [dnstype.Resolver.Addr] uses that scheme. +func (f *forwarder) RegisterCustomScheme(scheme string, h CustomSchemeHandler) error { + f.mu.Lock() + defer f.mu.Unlock() + if _, ok := f.schemes[scheme]; ok { + return fmt.Errorf("scheme %q already registered", scheme) + } + f.invalidateSchemeCacheLocked() + mak.Set(&f.schemes, scheme, h) + return nil +} + +// invalidateSchemeCacheLocked clears f.schemeCache so that it will be rebuilt +// on the next call to f.schemeCacheLocked(). +func (f *forwarder) invalidateSchemeCacheLocked() { + f.schemeCache = views.Map[string, CustomSchemeHandler]{} +} + +// schemeCacheLocked returns an immutable copy of f.schemes that can be used +// after mu is unlocked. +func (f *forwarder) schemeCacheLocked() views.Map[string, CustomSchemeHandler] { + if !f.schemeCache.IsNil() { + return f.schemeCache + } + if f.schemes == nil { + return f.schemeCache // returns a nil view + } + // Regenerate the cache + f.schemeCache = views.MapOf(maps.Clone(f.schemes)) + return f.schemeCache +} + // forwardQuery is information and state about a forwarded DNS query that's // being sent to 1 or more upstreams. // diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index faaaa9f3c..ebe4041a6 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -27,6 +27,7 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/tstest" "tailscale.com/types/dnstype" + "tailscale.com/util/dnsname" "tailscale.com/util/eventbus/eventbustest" ) @@ -1385,3 +1386,142 @@ func TestForwarderHealthOnContextExpiry(t *testing.T) { }) } } + +func TestResolversCustomScheme(t *testing.T) { + t.Parallel() + tests := []struct { + name string + domain dnsname.FQDN + schemes map[string]CustomSchemeHandler + routes map[dnsname.FQDN][]*dnstype.Resolver + wantAddrs []string + }{ + { + name: "no-custom-scheme", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{}, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "192.168.1.1:53"}, + {Addr: "192.168.1.2:53"}, + }, + }, + wantAddrs: []string{"192.168.1.1:53", "192.168.1.2:53"}, + }, + { + name: "single-custom-scheme", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(string) (string, error) { return "1.2.3.4:53", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": {{Addr: "myscheme:customKey"}}, + }, + wantAddrs: []string{"1.2.3.4:53"}, + }, + { + name: "with-other-resolvers", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(key string) (string, error) { return "1.2.3.4:53", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "192.168.1.1:53"}, + {Addr: "myscheme:customKey"}, + {Addr: "192.168.1.2:53"}, + }, + }, + wantAddrs: []string{"192.168.1.1:53", "1.2.3.4:53", "192.168.1.2:53"}, + }, + { + name: "multiple-custom-schemes", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "schemeOne": func(string) (string, error) { return "1.2.3.4:53", nil }, + "schemeTwo": func(string) (string, error) { return "5.6.7.8:53", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "schemeOne:customKey"}, + {Addr: "schemeTwo:customKey"}, + }, + }, + wantAddrs: []string{"1.2.3.4:53", "5.6.7.8:53"}, + }, + { + name: "empty-string-means-no-resolver", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(string) (string, error) { return "", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "192.168.1.1:53"}, + {Addr: "myscheme:customKey"}, + }, + }, + wantAddrs: []string{"192.168.1.1:53"}, + }, + { + name: "error-means-no-resolver", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(string) (string, error) { return "", fmt.Errorf("handler error") }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "192.168.1.1:53"}, + {Addr: "myscheme:customKey"}, + }, + }, + wantAddrs: []string{"192.168.1.1:53"}, + }, + { + // If the best-matching route yields no resolvers after scheme + // resolution, fall through to the next matching route. + name: "empty-scheme-result-falls-through-to-next-matching-route", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(string) (string, error) { return "", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": {{Addr: "myscheme:customKey"}}, + ".": {{Addr: "192.168.1.1:53"}}, + }, + wantAddrs: []string{"192.168.1.1:53"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logf := tstest.WhileTestRunningLogger(t) + bus := eventbustest.NewBus(t) + netMon, err := netmon.New(bus, logf) + if err != nil { + t.Fatal(err) + } + var dialer tsdial.Dialer + dialer.SetNetMon(netMon) + dialer.SetBus(bus) + + fwd := newForwarder(logf, netMon, nil, &dialer, health.NewTracker(bus), nil) + for scheme, handler := range tt.schemes { + if err := fwd.RegisterCustomScheme(scheme, handler); err != nil { + t.Fatal(err) + } + } + + fwd.setRoutes(tt.routes, false) + + got := fwd.resolvers(tt.domain) + var gotAddrs []string + for _, r := range got { + gotAddrs = append(gotAddrs, r.name.Addr) + } + if !slices.Equal(gotAddrs, tt.wantAddrs) { + t.Errorf("got %v, want %v", gotAddrs, tt.wantAddrs) + } + }) + } +} diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 01f0c8a63..4b2db5705 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -293,6 +293,18 @@ func (r *Resolver) SetConfig(cfg Config) error { return nil } +// CustomSchemeHandler takes a URI (retrieved from [dnstype.Resolver.Addr]) and +// returns an updated URI to use for the current query. The result is only valid +// for right now and may change over time. +type CustomSchemeHandler func(addr string) (newAddr string, err error) + +// RegisterCustomScheme adds a [CustomSchemaHandler] that is called to provide +// an updated address to the forwarder when a [dnstype.Resolver.Addr] uses that +// scheme. +func (r *Resolver) RegisterCustomScheme(scheme string, h CustomSchemeHandler) error { + return r.forwarder.RegisterCustomScheme(scheme, h) +} + // Close shuts down the resolver and ensures poll goroutines have exited. // The Resolver cannot be used again after Close is called. func (r *Resolver) Close() {