net/dns: add custom scheme resolvers

If another part of the client code registers a custom scheme with the
forwarder, the forwarder will check resolver addresses to see if they
match the scheme. If they do, the corresponding custom scheme handler
will be called to find the actual address for the resolver at this
moment. If the handler returns the empty string then that resolver will
be ignored.

This is useful if you want to dynamically determine where to send
certain DNS requests. It is being added to support new app connector
(conn25) work that would like to make sure it sends DNS requests to the
current connector peer in a high availability configuration.

Updates tailscale/corp#39858

Signed-off-by: Fran Bull <fran@tailscale.com>
This commit is contained in:
Fran Bull 2026-04-29 13:46:22 -07:00 committed by franbull
parent 78126c5d9f
commit bdf3419e7d
3 changed files with 254 additions and 2 deletions

View File

@ -13,6 +13,7 @@ import (
"errors"
"fmt"
"io"
"maps"
"net"
"net/http"
"net/netip"
@ -41,8 +42,10 @@ import (
"tailscale.com/types/dnstype"
"tailscale.com/types/logger"
"tailscale.com/types/nettype"
"tailscale.com/types/views"
"tailscale.com/util/cloudenv"
"tailscale.com/util/dnsname"
"tailscale.com/util/mak"
"tailscale.com/util/race"
"tailscale.com/version"
)
@ -324,6 +327,19 @@ type forwarder struct {
// resolver lookup.
cloudHostFallback []resolverAndDelay
// schemes are the collection of registered URI scheme names that
// dynamically decide which resolver to use at the time of each query. The
// key is the scheme (the portion before the first `:`) and the value is a
// handler that determines where the current query should be sent.
// Use schemeCacheLocked() to get the current contents that can continue to
// be accessed once mu is released. This allows the (much more common)
// resolver code path to avoid repeated locking and unlocking.
// When modified, call invalidateSchemeCacheLocked() before unlocking mu.
schemes map[string]CustomSchemeHandler
// schemeCache is an immutable copy of schemes. Do not read directly,
// use schemeCacheLocked() which will regenerate its contents as needed.
schemeCache views.Map[string, CustomSchemeHandler]
// acceptDNS tracks the CorpDNS pref (--accept-dns)
// This lets us skip health warnings if the forwarder receives inbound
// queries directly - but we didn't configure it with any upstream resolvers.
@ -996,15 +1012,66 @@ func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAn
return out, nil
}
// applySchemes resolves any custom-scheme entries in rrs using the provided
// scheme handlers, returning the resulting slice. Entries whose handler returns
// an error or empty string are dropped. Entries with no registered scheme pass
// through unchanged. If schemes is nil, rrs is returned as-is.
func applySchemes(logf logger.Logf, rrs []resolverAndDelay, schemes views.Map[string, CustomSchemeHandler]) []resolverAndDelay {
if schemes.IsNil() {
return rrs
}
var result []resolverAndDelay
for i, rr := range rrs {
scheme, _, hasColon := strings.Cut(rr.name.Addr, ":")
handler, isCustom := schemes.GetOk(scheme)
if !hasColon || !isCustom {
if result != nil {
result = append(result, rr)
}
continue
}
// Avoid making a results slice in the common case where there
// are no custom scheme resolvers.
if result == nil {
result = make([]resolverAndDelay, i, len(rrs))
copy(result, rrs)
}
newAddr, err := handler(rr.name.Addr)
if err != nil {
logf("error from custom scheme handler, skipping resolver : %v", err)
}
if err != nil || newAddr == "" {
continue
}
newResolver := *rr.name
newResolver.Addr = newAddr
result = append(result, resolverAndDelay{name: &newResolver, startDelay: rr.startDelay})
}
// If we didn't have any custom schemes, return the original rrs.
if result == nil {
return rrs
}
return result
}
// resolvers returns the resolvers to use for domain.
func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay {
f.mu.Lock()
routes := f.routes
cloudHostFallback := f.cloudHostFallback
schemes := f.schemeCacheLocked()
f.mu.Unlock()
for _, route := range routes {
if route.Suffix == "." || route.Suffix.Contains(domain) {
return route.Resolvers
if route.Suffix != "." && !route.Suffix.Contains(domain) {
continue
}
resolved := applySchemes(f.logf, route.Resolvers, schemes)
// If scheme resolution filtered out all resolvers from a non-empty
// route, fall through to the next matching route. If the resolvers
// were configured to be empty allow resolved to be empty.
if len(resolved) > 0 || len(route.Resolvers) == 0 {
return resolved
}
}
return cloudHostFallback // or nil if no fallback
@ -1021,6 +1088,39 @@ func (f *forwarder) GetUpstreamResolvers(name dnsname.FQDN) []*dnstype.Resolver
return upstreamResolvers
}
// RegisterCustomScheme adds a [CustomSchemeHandler] that is called to provide
// an updated address when a [dnstype.Resolver.Addr] uses that scheme.
func (f *forwarder) RegisterCustomScheme(scheme string, h CustomSchemeHandler) error {
f.mu.Lock()
defer f.mu.Unlock()
if _, ok := f.schemes[scheme]; ok {
return fmt.Errorf("scheme %q already registered", scheme)
}
f.invalidateSchemeCacheLocked()
mak.Set(&f.schemes, scheme, h)
return nil
}
// invalidateSchemeCacheLocked clears f.schemeCache so that it will be rebuilt
// on the next call to f.schemeCacheLocked().
func (f *forwarder) invalidateSchemeCacheLocked() {
f.schemeCache = views.Map[string, CustomSchemeHandler]{}
}
// schemeCacheLocked returns an immutable copy of f.schemes that can be used
// after mu is unlocked.
func (f *forwarder) schemeCacheLocked() views.Map[string, CustomSchemeHandler] {
if !f.schemeCache.IsNil() {
return f.schemeCache
}
if f.schemes == nil {
return f.schemeCache // returns a nil view
}
// Regenerate the cache
f.schemeCache = views.MapOf(maps.Clone(f.schemes))
return f.schemeCache
}
// forwardQuery is information and state about a forwarded DNS query that's
// being sent to 1 or more upstreams.
//

View File

@ -27,6 +27,7 @@ import (
"tailscale.com/net/tsdial"
"tailscale.com/tstest"
"tailscale.com/types/dnstype"
"tailscale.com/util/dnsname"
"tailscale.com/util/eventbus/eventbustest"
)
@ -1385,3 +1386,142 @@ func TestForwarderHealthOnContextExpiry(t *testing.T) {
})
}
}
func TestResolversCustomScheme(t *testing.T) {
t.Parallel()
tests := []struct {
name string
domain dnsname.FQDN
schemes map[string]CustomSchemeHandler
routes map[dnsname.FQDN][]*dnstype.Resolver
wantAddrs []string
}{
{
name: "no-custom-scheme",
domain: "example.com.",
schemes: map[string]CustomSchemeHandler{},
routes: map[dnsname.FQDN][]*dnstype.Resolver{
"example.com.": {
{Addr: "192.168.1.1:53"},
{Addr: "192.168.1.2:53"},
},
},
wantAddrs: []string{"192.168.1.1:53", "192.168.1.2:53"},
},
{
name: "single-custom-scheme",
domain: "example.com.",
schemes: map[string]CustomSchemeHandler{
"myscheme": func(string) (string, error) { return "1.2.3.4:53", nil },
},
routes: map[dnsname.FQDN][]*dnstype.Resolver{
"example.com.": {{Addr: "myscheme:customKey"}},
},
wantAddrs: []string{"1.2.3.4:53"},
},
{
name: "with-other-resolvers",
domain: "example.com.",
schemes: map[string]CustomSchemeHandler{
"myscheme": func(key string) (string, error) { return "1.2.3.4:53", nil },
},
routes: map[dnsname.FQDN][]*dnstype.Resolver{
"example.com.": {
{Addr: "192.168.1.1:53"},
{Addr: "myscheme:customKey"},
{Addr: "192.168.1.2:53"},
},
},
wantAddrs: []string{"192.168.1.1:53", "1.2.3.4:53", "192.168.1.2:53"},
},
{
name: "multiple-custom-schemes",
domain: "example.com.",
schemes: map[string]CustomSchemeHandler{
"schemeOne": func(string) (string, error) { return "1.2.3.4:53", nil },
"schemeTwo": func(string) (string, error) { return "5.6.7.8:53", nil },
},
routes: map[dnsname.FQDN][]*dnstype.Resolver{
"example.com.": {
{Addr: "schemeOne:customKey"},
{Addr: "schemeTwo:customKey"},
},
},
wantAddrs: []string{"1.2.3.4:53", "5.6.7.8:53"},
},
{
name: "empty-string-means-no-resolver",
domain: "example.com.",
schemes: map[string]CustomSchemeHandler{
"myscheme": func(string) (string, error) { return "", nil },
},
routes: map[dnsname.FQDN][]*dnstype.Resolver{
"example.com.": {
{Addr: "192.168.1.1:53"},
{Addr: "myscheme:customKey"},
},
},
wantAddrs: []string{"192.168.1.1:53"},
},
{
name: "error-means-no-resolver",
domain: "example.com.",
schemes: map[string]CustomSchemeHandler{
"myscheme": func(string) (string, error) { return "", fmt.Errorf("handler error") },
},
routes: map[dnsname.FQDN][]*dnstype.Resolver{
"example.com.": {
{Addr: "192.168.1.1:53"},
{Addr: "myscheme:customKey"},
},
},
wantAddrs: []string{"192.168.1.1:53"},
},
{
// If the best-matching route yields no resolvers after scheme
// resolution, fall through to the next matching route.
name: "empty-scheme-result-falls-through-to-next-matching-route",
domain: "example.com.",
schemes: map[string]CustomSchemeHandler{
"myscheme": func(string) (string, error) { return "", nil },
},
routes: map[dnsname.FQDN][]*dnstype.Resolver{
"example.com.": {{Addr: "myscheme:customKey"}},
".": {{Addr: "192.168.1.1:53"}},
},
wantAddrs: []string{"192.168.1.1:53"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logf := tstest.WhileTestRunningLogger(t)
bus := eventbustest.NewBus(t)
netMon, err := netmon.New(bus, logf)
if err != nil {
t.Fatal(err)
}
var dialer tsdial.Dialer
dialer.SetNetMon(netMon)
dialer.SetBus(bus)
fwd := newForwarder(logf, netMon, nil, &dialer, health.NewTracker(bus), nil)
for scheme, handler := range tt.schemes {
if err := fwd.RegisterCustomScheme(scheme, handler); err != nil {
t.Fatal(err)
}
}
fwd.setRoutes(tt.routes, false)
got := fwd.resolvers(tt.domain)
var gotAddrs []string
for _, r := range got {
gotAddrs = append(gotAddrs, r.name.Addr)
}
if !slices.Equal(gotAddrs, tt.wantAddrs) {
t.Errorf("got %v, want %v", gotAddrs, tt.wantAddrs)
}
})
}
}

View File

@ -293,6 +293,18 @@ func (r *Resolver) SetConfig(cfg Config) error {
return nil
}
// CustomSchemeHandler takes a URI (retrieved from [dnstype.Resolver.Addr]) and
// returns an updated URI to use for the current query. The result is only valid
// for right now and may change over time.
type CustomSchemeHandler func(addr string) (newAddr string, err error)
// RegisterCustomScheme adds a [CustomSchemaHandler] that is called to provide
// an updated address to the forwarder when a [dnstype.Resolver.Addr] uses that
// scheme.
func (r *Resolver) RegisterCustomScheme(scheme string, h CustomSchemeHandler) error {
return r.forwarder.RegisterCustomScheme(scheme, h)
}
// Close shuts down the resolver and ensures poll goroutines have exited.
// The Resolver cannot be used again after Close is called.
func (r *Resolver) Close() {