mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-31 00:01:40 +01:00 
			
		
		
		
	Updates tailscale/coral#127 Change-Id: I2712c50630d0d1272c30305fa5a1899a19ffacef Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
		
			
				
	
	
		
			220 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			220 lines
		
	
	
		
			5.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| // Copyright (c) Tailscale Inc & AUTHORS
 | |
| // SPDX-License-Identifier: BSD-3-Clause
 | |
| 
 | |
| package main
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"encoding/binary"
 | |
| 	"encoding/json"
 | |
| 	"expvar"
 | |
| 	"log"
 | |
| 	"math/rand/v2"
 | |
| 	"net"
 | |
| 	"net/http"
 | |
| 	"net/netip"
 | |
| 	"strconv"
 | |
| 	"strings"
 | |
| 	"sync/atomic"
 | |
| 	"time"
 | |
| 
 | |
| 	"tailscale.com/syncs"
 | |
| 	"tailscale.com/util/mak"
 | |
| 	"tailscale.com/util/slicesx"
 | |
| )
 | |
| 
 | |
// refreshTimeout bounds how long a single cache refresh may spend resolving
// its whole list of DNS names.
const refreshTimeout = 60 * time.Second
 | |
| 
 | |
// dnsEntryMap is the result of resolving a list of DNS names with resolveList.
type dnsEntryMap struct {
	// IPs maps each successfully resolved name to its addresses.
	IPs map[string][]net.IP

	// Percent holds each name's rollout fraction in [0.0, 1.0],
	// e.g. "foo.com" => 0.5 for 50%.
	Percent map[string]float64
}
 | |
| 
 | |
| var (
 | |
| 	dnsCache            atomic.Pointer[dnsEntryMap]
 | |
| 	dnsCacheBytes       syncs.AtomicValue[[]byte] // of JSON
 | |
| 	unpublishedDNSCache atomic.Pointer[dnsEntryMap]
 | |
| 	bootstrapLookupMap  syncs.Map[string, bool]
 | |
| )
 | |
| 
 | |
// Counters describing bootstrap DNS traffic, exported via expvar.
var (
	// bootstrapDNSRequests counts every bootstrap DNS request.
	bootstrapDNSRequests = expvar.NewInt("counter_bootstrap_dns_requests")

	// publishedDNSHits and publishedDNSMisses count "q" lookups of names in
	// the published cache that did and did not have addresses, respectively.
	publishedDNSHits   = expvar.NewInt("counter_bootstrap_dns_published_hits")
	publishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_published_misses")

	// unpublishedDNSHits counts "q" lookups found in the unpublished cache;
	// unpublishedDNSMisses counts "q" lookups treated as unpublished misses.
	unpublishedDNSHits   = expvar.NewInt("counter_bootstrap_dns_unpublished_hits")
	unpublishedDNSMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_misses")

	// unpublishedDNSPercentMisses counts unpublished hits withheld because
	// the client fell outside the name's rollout percentage.
	unpublishedDNSPercentMisses = expvar.NewInt("counter_bootstrap_dns_unpublished_percent_misses")
)
 | |
| 
 | |
| func init() {
 | |
| 	expvar.Publish("counter_bootstrap_dns_queried_domains", expvar.Func(func() any {
 | |
| 		return bootstrapLookupMap.Len()
 | |
| 	}))
 | |
| }
 | |
| 
 | |
| func refreshBootstrapDNSLoop() {
 | |
| 	if *bootstrapDNS == "" && *unpublishedDNS == "" {
 | |
| 		return
 | |
| 	}
 | |
| 	for {
 | |
| 		refreshBootstrapDNS()
 | |
| 		refreshUnpublishedDNS()
 | |
| 		time.Sleep(10 * time.Minute)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func refreshBootstrapDNS() {
 | |
| 	if *bootstrapDNS == "" {
 | |
| 		return
 | |
| 	}
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
 | |
| 	defer cancel()
 | |
| 	dnsEntries := resolveList(ctx, *bootstrapDNS)
 | |
| 	// Randomize the order of the IPs for each name to avoid the client biasing
 | |
| 	// to IPv6
 | |
| 	for _, vv := range dnsEntries.IPs {
 | |
| 		slicesx.Shuffle(vv)
 | |
| 	}
 | |
| 	j, err := json.MarshalIndent(dnsEntries.IPs, "", "\t")
 | |
| 	if err != nil {
 | |
| 		// leave the old values in place
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	dnsCache.Store(dnsEntries)
 | |
| 	dnsCacheBytes.Store(j)
 | |
| }
 | |
| 
 | |
| func refreshUnpublishedDNS() {
 | |
| 	if *unpublishedDNS == "" {
 | |
| 		return
 | |
| 	}
 | |
| 	ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
 | |
| 	defer cancel()
 | |
| 	dnsEntries := resolveList(ctx, *unpublishedDNS)
 | |
| 	unpublishedDNSCache.Store(dnsEntries)
 | |
| }
 | |
| 
 | |
| // resolveList takes a comma-separated list of DNS names to resolve.
 | |
| //
 | |
| // If an entry contains a slash, it's two DNS names: the first is the one to
 | |
| // resolve and the second is that of a TXT recording containing the rollout
 | |
| // percentage in range "0".."100". If the TXT record doesn't exist or is
 | |
| // malformed, the percentage is 0. If the TXT record is not provided (there's no
 | |
| // slash), then the percentage is 100.
 | |
| func resolveList(ctx context.Context, list string) *dnsEntryMap {
 | |
| 	ents := strings.Split(list, ",")
 | |
| 
 | |
| 	ret := &dnsEntryMap{}
 | |
| 
 | |
| 	var r net.Resolver
 | |
| 	for _, ent := range ents {
 | |
| 		name, txtName, _ := strings.Cut(ent, "/")
 | |
| 		addrs, err := r.LookupIP(ctx, "ip", name)
 | |
| 		if err != nil {
 | |
| 			log.Printf("bootstrap DNS lookup %q: %v", name, err)
 | |
| 			continue
 | |
| 		}
 | |
| 		mak.Set(&ret.IPs, name, addrs)
 | |
| 
 | |
| 		if txtName == "" {
 | |
| 			mak.Set(&ret.Percent, name, 1.0)
 | |
| 			continue
 | |
| 		}
 | |
| 		vals, err := r.LookupTXT(ctx, txtName)
 | |
| 		if err != nil {
 | |
| 			log.Printf("bootstrap DNS lookup %q: %v", txtName, err)
 | |
| 			continue
 | |
| 		}
 | |
| 		for _, v := range vals {
 | |
| 			if v, err := strconv.Atoi(v); err == nil && v >= 0 && v <= 100 {
 | |
| 				mak.Set(&ret.Percent, name, float64(v)/100)
 | |
| 			}
 | |
| 		}
 | |
| 	}
 | |
| 	return ret
 | |
| }
 | |
| 
 | |
| func handleBootstrapDNS(w http.ResponseWriter, r *http.Request) {
 | |
| 	bootstrapDNSRequests.Add(1)
 | |
| 
 | |
| 	w.Header().Set("Content-Type", "application/json")
 | |
| 	// Bootstrap DNS requests occur cross-regions, and are randomized per
 | |
| 	// request, so keeping a connection open is pointlessly expensive.
 | |
| 	w.Header().Set("Connection", "close")
 | |
| 
 | |
| 	// Try answering a query from our hidden map first
 | |
| 	if q := r.URL.Query().Get("q"); q != "" {
 | |
| 		bootstrapLookupMap.Store(q, true)
 | |
| 		if bootstrapLookupMap.Len() > 500 { // defensive
 | |
| 			bootstrapLookupMap.Clear()
 | |
| 		}
 | |
| 		if m := unpublishedDNSCache.Load(); m != nil && len(m.IPs[q]) > 0 {
 | |
| 			unpublishedDNSHits.Add(1)
 | |
| 
 | |
| 			percent := m.Percent[q]
 | |
| 			if remoteAddrMatchesPercent(r.RemoteAddr, percent) {
 | |
| 				// Only return the specific query, not everything.
 | |
| 				m := map[string][]net.IP{q: m.IPs[q]}
 | |
| 				j, err := json.MarshalIndent(m, "", "\t")
 | |
| 				if err == nil {
 | |
| 					w.Write(j)
 | |
| 					return
 | |
| 				}
 | |
| 			} else {
 | |
| 				unpublishedDNSPercentMisses.Add(1)
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		// If we have a "q" query for a name in the published cache
 | |
| 		// list, then track whether that's a hit/miss.
 | |
| 		m := dnsCache.Load()
 | |
| 		var inPub bool
 | |
| 		var ips []net.IP
 | |
| 		if m != nil {
 | |
| 			ips, inPub = m.IPs[q]
 | |
| 		}
 | |
| 		if inPub {
 | |
| 			if len(ips) > 0 {
 | |
| 				publishedDNSHits.Add(1)
 | |
| 			} else {
 | |
| 				publishedDNSMisses.Add(1)
 | |
| 			}
 | |
| 		} else {
 | |
| 			// If it wasn't in either cache, treat this as a query
 | |
| 			// for the unpublished cache, and thus a cache miss.
 | |
| 			unpublishedDNSMisses.Add(1)
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// Fall back to returning the public set of cached DNS names
 | |
| 	j := dnsCacheBytes.Load()
 | |
| 	w.Write(j)
 | |
| }
 | |
| 
 | |
| // percent is [0.0, 1.0].
 | |
| func remoteAddrMatchesPercent(remoteAddr string, percent float64) bool {
 | |
| 	if percent == 0 {
 | |
| 		return false
 | |
| 	}
 | |
| 	if percent == 1 {
 | |
| 		return true
 | |
| 	}
 | |
| 	reqIPStr, _, err := net.SplitHostPort(remoteAddr)
 | |
| 	if err != nil {
 | |
| 		return false
 | |
| 	}
 | |
| 	reqIP, err := netip.ParseAddr(reqIPStr)
 | |
| 	if err != nil {
 | |
| 		return false
 | |
| 	}
 | |
| 	if reqIP.IsLoopback() {
 | |
| 		// For local testing.
 | |
| 		return rand.Float64() < 0.5
 | |
| 	}
 | |
| 	reqIP16 := reqIP.As16()
 | |
| 	rndSrc := rand.NewPCG(binary.LittleEndian.Uint64(reqIP16[:8]), binary.LittleEndian.Uint64(reqIP16[8:]))
 | |
| 	rnd := rand.New(rndSrc)
 | |
| 	return percent > rnd.Float64()
 | |
| }
 |