mirror of
https://github.com/tailscale/tailscale.git
synced 2025-09-24 07:01:19 +02:00
control/controlhttp: simplify, fix race dialing, remove priority concept
controlhttp has the responsibility of dialing a set of candidate control endpoints in a way that minimizes user-facing latency. If one control endpoint is unavailable, we promptly dial another, racing across the dimensions of: IPv6, IPv4, port 80, and port 443, over multiple server endpoints. In the case that the top-priority endpoint was not available, the prior implementation would hang waiting for other results, so as to try to return the highest-priority successful connection to the rest of the client code. With a large dial plan and enough client-to-endpoint latency, this hang lasted long enough for the server to time out the connection due to inactivity in the intermediate state. Instead of trying to prioritize non-ideal candidate connections, the first successful connection is now used unconditionally, improving user-facing latency and avoiding any delays that would encroach on the server-side timeout. The tests are converted to memnet and synctest, running on all platforms. Fixes #8442 Fixes tailscale/corp#32534 Co-authored-by: James Tucker <james@tailscale.com> Change-Id: I4eb57f046d8b40403220e40eb67a31c41adb3a38 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com> Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
parent
1b6bc37f28
commit
db048e905d
@ -186,7 +186,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep
|
|||||||
tailscale.com/util/lineiter from tailscale.com/hostinfo+
|
tailscale.com/util/lineiter from tailscale.com/hostinfo+
|
||||||
L tailscale.com/util/linuxfw from tailscale.com/net/netns
|
L tailscale.com/util/linuxfw from tailscale.com/net/netns
|
||||||
tailscale.com/util/mak from tailscale.com/cmd/tailscale/cli+
|
tailscale.com/util/mak from tailscale.com/cmd/tailscale/cli+
|
||||||
tailscale.com/util/multierr from tailscale.com/control/controlhttp+
|
tailscale.com/util/multierr from tailscale.com/health+
|
||||||
tailscale.com/util/must from tailscale.com/clientupdate/distsign+
|
tailscale.com/util/must from tailscale.com/clientupdate/distsign+
|
||||||
tailscale.com/util/nocasemaps from tailscale.com/types/ipproto
|
tailscale.com/util/nocasemaps from tailscale.com/types/ipproto
|
||||||
tailscale.com/util/prompt from tailscale.com/cmd/tailscale/cli
|
tailscale.com/util/prompt from tailscale.com/cmd/tailscale/cli
|
||||||
|
@ -27,14 +27,12 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptrace"
|
"net/http/httptrace"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sort"
|
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -53,7 +51,6 @@ import (
|
|||||||
"tailscale.com/syncs"
|
"tailscale.com/syncs"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/tstime"
|
"tailscale.com/tstime"
|
||||||
"tailscale.com/util/multierr"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var stdDialer net.Dialer
|
var stdDialer net.Dialer
|
||||||
@ -110,18 +107,8 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
|
|||||||
}
|
}
|
||||||
candidates := a.DialPlan.Candidates
|
candidates := a.DialPlan.Candidates
|
||||||
|
|
||||||
// Otherwise, we try dialing per the plan. Store the highest priority
|
// Create a context to be canceled as we return, so once we get a good connection,
|
||||||
// in the list, so that if we get a connection to one of those
|
// we can drop all the other ones.
|
||||||
// candidates we can return quickly.
|
|
||||||
var highestPriority int = math.MinInt
|
|
||||||
for _, c := range candidates {
|
|
||||||
if c.Priority > highestPriority {
|
|
||||||
highestPriority = c.Priority
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This context allows us to cancel in-flight connections if we get a
|
|
||||||
// highest-priority connection before we're all done.
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@ -129,142 +116,58 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) {
|
|||||||
type dialResult struct {
|
type dialResult struct {
|
||||||
conn *ClientConn
|
conn *ClientConn
|
||||||
err error
|
err error
|
||||||
cand tailcfg.ControlIPCandidate
|
|
||||||
}
|
}
|
||||||
resultsCh := make(chan dialResult, len(candidates))
|
resultsCh := make(chan dialResult) // unbuffered, never closed
|
||||||
|
|
||||||
var pending atomic.Int32
|
dialCand := func(cand tailcfg.ControlIPCandidate) (*ClientConn, error) {
|
||||||
pending.Store(int32(len(candidates)))
|
if cand.ACEHost != "" {
|
||||||
for _, c := range candidates {
|
a.logf("[v2] controlhttp: waited %.2f seconds, dialing %q via ACE %s (%s)", cand.DialStartDelaySec, a.Hostname, cand.ACEHost, cmp.Or(cand.IP.String(), "dns"))
|
||||||
go func(ctx context.Context, c tailcfg.ControlIPCandidate) {
|
} else {
|
||||||
var (
|
a.logf("[v2] controlhttp: waited %.2f seconds, dialing %q @ %s", cand.DialStartDelaySec, a.Hostname, cand.IP.String())
|
||||||
conn *ClientConn
|
}
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
// Always send results back to our channel.
|
ctx, cancel := context.WithTimeout(ctx, time.Duration(cand.DialTimeoutSec*float64(time.Second)))
|
||||||
defer func() {
|
defer cancel()
|
||||||
resultsCh <- dialResult{conn, err, c}
|
return a.dialHostOpt(ctx, cand.IP, cand.ACEHost)
|
||||||
if pending.Add(-1) == 0 {
|
|
||||||
close(resultsCh)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// If non-zero, wait the configured start timeout
|
|
||||||
// before we do anything.
|
|
||||||
if c.DialStartDelaySec > 0 {
|
|
||||||
a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP)
|
|
||||||
tmr, tmrChannel := a.clock().NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second)))
|
|
||||||
defer tmr.Stop()
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
err = ctx.Err()
|
|
||||||
return
|
|
||||||
case <-tmrChannel:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now, create a sub-context with the given timeout and
|
|
||||||
// try dialing the provided host.
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second)))
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
if c.IP.IsValid() {
|
|
||||||
a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP)
|
|
||||||
} else if c.ACEHost != "" {
|
|
||||||
a.logf("[v2] controlhttp: trying to dial %q via ACE %q", a.Hostname, c.ACEHost)
|
|
||||||
}
|
|
||||||
// This will dial, and the defer above sends it back to our parent.
|
|
||||||
conn, err = a.dialHostOpt(ctx, c.IP, c.ACEHost)
|
|
||||||
}(ctx, c)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var results []dialResult
|
for _, cand := range candidates {
|
||||||
for res := range resultsCh {
|
timer := time.AfterFunc(time.Duration(cand.DialStartDelaySec*float64(time.Second)), func() {
|
||||||
// If we get a response that has the highest priority, we don't
|
|
||||||
// need to wait for any of the other connections to finish; we
|
|
||||||
// can just return this connection.
|
|
||||||
//
|
|
||||||
// TODO(andrew): we could make this better by keeping track of
|
|
||||||
// the highest remaining priority dynamically, instead of just
|
|
||||||
// checking for the highest total
|
|
||||||
if res.cand.Priority == highestPriority && res.conn != nil {
|
|
||||||
a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, cmp.Or(res.cand.ACEHost, res.cand.IP.String()))
|
|
||||||
|
|
||||||
// Drain the channel and any existing connections in
|
|
||||||
// the background.
|
|
||||||
go func() {
|
go func() {
|
||||||
for _, res := range results {
|
conn, err := dialCand(cand)
|
||||||
if res.conn != nil {
|
select {
|
||||||
res.conn.Close()
|
case resultsCh <- dialResult{conn, err}:
|
||||||
|
if err == nil {
|
||||||
|
a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(cand.ACEHost, cand.IP.String()))
|
||||||
}
|
}
|
||||||
}
|
case <-ctx.Done():
|
||||||
for res := range resultsCh {
|
if conn != nil {
|
||||||
if res.conn != nil {
|
conn.Close()
|
||||||
res.conn.Close()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if a.drainFinished != nil {
|
|
||||||
close(a.drainFinished)
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
return res.conn, nil
|
})
|
||||||
}
|
defer timer.Stop()
|
||||||
|
|
||||||
// This isn't a highest-priority result, so just store it until
|
|
||||||
// we're done.
|
|
||||||
results = append(results, res)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// After we finish this function, close any remaining open connections.
|
var errs []error
|
||||||
defer func() {
|
for {
|
||||||
for _, result := range results {
|
select {
|
||||||
// Note: below, we nil out the returned connection (if
|
case res := <-resultsCh:
|
||||||
// any) in the slice so we don't close it.
|
if res.err == nil {
|
||||||
if result.conn != nil {
|
return res.conn, nil
|
||||||
result.conn.Close()
|
|
||||||
}
|
}
|
||||||
|
errs = append(errs, res.err)
|
||||||
|
if len(errs) == len(candidates) {
|
||||||
|
// If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
|
||||||
|
a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", errors.Join(errs...))
|
||||||
|
return a.dialHost(ctx)
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
a.logf("controlhttp: context aborted dialing")
|
||||||
|
return nil, ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// We don't drain asynchronously after this point, so notify our
|
|
||||||
// channel when we return.
|
|
||||||
if a.drainFinished != nil {
|
|
||||||
close(a.drainFinished)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Sort by priority, then take the first non-error response.
|
|
||||||
sort.Slice(results, func(i, j int) bool {
|
|
||||||
// NOTE: intentionally inverted so that the highest priority
|
|
||||||
// item comes first
|
|
||||||
return results[i].cand.Priority > results[j].cand.Priority
|
|
||||||
})
|
|
||||||
|
|
||||||
var (
|
|
||||||
conn *ClientConn
|
|
||||||
errs []error
|
|
||||||
)
|
|
||||||
for i, result := range results {
|
|
||||||
if result.err != nil {
|
|
||||||
errs = append(errs, result.err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(result.cand.ACEHost, result.cand.IP.String()))
|
|
||||||
conn = result.conn
|
|
||||||
results[i].conn = nil // so we don't close it in the defer
|
|
||||||
return conn, nil
|
|
||||||
}
|
}
|
||||||
if ctx.Err() != nil {
|
|
||||||
a.logf("controlhttp: context aborted dialing")
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
merr := multierr.New(errs...)
|
|
||||||
|
|
||||||
// If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS.
|
|
||||||
a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error())
|
|
||||||
return a.dialHost(ctx)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// The TS_FORCE_NOISE_443 envknob forces the controlclient noise dialer to
|
// The TS_FORCE_NOISE_443 envknob forces the controlclient noise dialer to
|
||||||
@ -402,6 +305,9 @@ func (a *Dialer) dialHostOpt(ctx context.Context, optAddr netip.Addr, optACEHost
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err80, err443 error
|
var err80, err443 error
|
||||||
|
if forceTLS {
|
||||||
|
err80 = errors.New("TLS forced: no port 80 dialed")
|
||||||
|
}
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
@ -98,7 +98,6 @@ type Dialer struct {
|
|||||||
logPort80Failure atomic.Bool
|
logPort80Failure atomic.Bool
|
||||||
|
|
||||||
// For tests only
|
// For tests only
|
||||||
drainFinished chan struct{}
|
|
||||||
omitCertErrorLogging bool
|
omitCertErrorLogging bool
|
||||||
testFallbackDelay time.Duration
|
testFallbackDelay time.Duration
|
||||||
|
|
||||||
|
@ -15,19 +15,20 @@ import (
|
|||||||
"net/http/httputil"
|
"net/http/httputil"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"runtime"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"testing/synctest"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"tailscale.com/control/controlbase"
|
"tailscale.com/control/controlbase"
|
||||||
"tailscale.com/control/controlhttp/controlhttpcommon"
|
"tailscale.com/control/controlhttp/controlhttpcommon"
|
||||||
"tailscale.com/control/controlhttp/controlhttpserver"
|
"tailscale.com/control/controlhttp/controlhttpserver"
|
||||||
"tailscale.com/health"
|
"tailscale.com/health"
|
||||||
|
"tailscale.com/net/memnet"
|
||||||
"tailscale.com/net/netmon"
|
"tailscale.com/net/netmon"
|
||||||
"tailscale.com/net/netx"
|
|
||||||
"tailscale.com/net/socks5"
|
"tailscale.com/net/socks5"
|
||||||
"tailscale.com/net/tsdial"
|
"tailscale.com/net/tsdial"
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
@ -36,6 +37,7 @@ import (
|
|||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
"tailscale.com/types/logger"
|
"tailscale.com/types/logger"
|
||||||
"tailscale.com/util/eventbus/eventbustest"
|
"tailscale.com/util/eventbus/eventbustest"
|
||||||
|
"tailscale.com/util/must"
|
||||||
)
|
)
|
||||||
|
|
||||||
type httpTestParam struct {
|
type httpTestParam struct {
|
||||||
@ -532,6 +534,28 @@ EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// slowListener wraps a memnet listener to delay accept operations
|
||||||
|
type slowListener struct {
|
||||||
|
net.Listener
|
||||||
|
delay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sl *slowListener) Accept() (net.Conn, error) {
|
||||||
|
// Add delay before accepting connections
|
||||||
|
timer := time.NewTimer(sl.delay)
|
||||||
|
defer timer.Stop()
|
||||||
|
<-timer.C
|
||||||
|
|
||||||
|
return sl.Listener.Accept()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSlowListener(inner net.Listener, delay time.Duration) net.Listener {
|
||||||
|
return &slowListener{
|
||||||
|
Listener: inner,
|
||||||
|
delay: delay,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc {
|
func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Upgrade", controlhttpcommon.UpgradeHeaderValue)
|
w.Header().Set("Upgrade", controlhttpcommon.UpgradeHeaderValue)
|
||||||
@ -545,33 +569,102 @@ func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestDialPlan(t *testing.T) {
|
func TestDialPlan(t *testing.T) {
|
||||||
if runtime.GOOS != "linux" {
|
testCases := []struct {
|
||||||
t.Skip("only works on Linux due to multiple localhost addresses")
|
name string
|
||||||
|
plan *tailcfg.ControlDialPlan
|
||||||
|
want []netip.Addr
|
||||||
|
allowFallback bool
|
||||||
|
maxDuration time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single",
|
||||||
|
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||||
|
{IP: netip.MustParseAddr("10.0.0.2"), DialTimeoutSec: 10},
|
||||||
|
}},
|
||||||
|
want: []netip.Addr{netip.MustParseAddr("10.0.0.2")},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "broken-then-good",
|
||||||
|
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||||
|
{IP: netip.MustParseAddr("10.0.0.10"), DialTimeoutSec: 10},
|
||||||
|
{IP: netip.MustParseAddr("10.0.0.2"), DialTimeoutSec: 10, DialStartDelaySec: 1},
|
||||||
|
}},
|
||||||
|
want: []netip.Addr{netip.MustParseAddr("10.0.0.2")},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple-candidates-with-broken",
|
||||||
|
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||||
|
// Multiple good IPs plus a broken one
|
||||||
|
// Should succeed with any of the good ones
|
||||||
|
{IP: netip.MustParseAddr("10.0.0.10"), DialTimeoutSec: 10},
|
||||||
|
{IP: netip.MustParseAddr("10.0.0.2"), DialTimeoutSec: 10},
|
||||||
|
{IP: netip.MustParseAddr("10.0.0.4"), DialTimeoutSec: 10},
|
||||||
|
{IP: netip.MustParseAddr("10.0.0.3"), DialTimeoutSec: 10},
|
||||||
|
}},
|
||||||
|
want: []netip.Addr{netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.4"), netip.MustParseAddr("10.0.0.3")},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple-candidates-race",
|
||||||
|
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||||
|
{IP: netip.MustParseAddr("10.0.0.10"), DialTimeoutSec: 10},
|
||||||
|
{IP: netip.MustParseAddr("10.0.0.3"), DialTimeoutSec: 10},
|
||||||
|
{IP: netip.MustParseAddr("10.0.0.2"), DialTimeoutSec: 10},
|
||||||
|
}},
|
||||||
|
want: []netip.Addr{netip.MustParseAddr("10.0.0.3"), netip.MustParseAddr("10.0.0.2")},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "fallback",
|
||||||
|
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||||
|
{IP: netip.MustParseAddr("10.0.0.10"), DialTimeoutSec: 1},
|
||||||
|
}},
|
||||||
|
want: []netip.Addr{netip.MustParseAddr("10.0.0.1")},
|
||||||
|
allowFallback: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// In tailscale/corp#32534 we discovered that a prior implementation
|
||||||
|
// of the dial race was waiting for all dials to complete when the
|
||||||
|
// top priority dial was failing. This delay was long enough that in
|
||||||
|
// real scenarios the server will close the connection due to
|
||||||
|
// inactivity, because the client does not send the first inside of
|
||||||
|
// noise request soon enough. This test is a regression guard
|
||||||
|
// against that behavior - proving that the dial returns promptly
|
||||||
|
// even if there is some cause of a slow race.
|
||||||
|
name: "slow-endpoint-doesnt-block",
|
||||||
|
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
||||||
|
{IP: netip.MustParseAddr("10.0.0.12"), Priority: 5, DialTimeoutSec: 10},
|
||||||
|
{IP: netip.MustParseAddr("10.0.0.2"), Priority: 1, DialTimeoutSec: 10},
|
||||||
|
}},
|
||||||
|
want: []netip.Addr{netip.MustParseAddr("10.0.0.2")},
|
||||||
|
maxDuration: 2 * time.Second, // Must complete quickly, not wait for slow endpoint
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, tt := range testCases {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
synctest.Test(t, func(t *testing.T) {
|
||||||
|
runDialPlanTest(t, tt.plan, tt.want, tt.allowFallback, tt.maxDuration)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runDialPlanTest(t *testing.T, plan *tailcfg.ControlDialPlan, want []netip.Addr, allowFallback bool, maxDuration time.Duration) {
|
||||||
client, server := key.NewMachine(), key.NewMachine()
|
client, server := key.NewMachine(), key.NewMachine()
|
||||||
|
|
||||||
const (
|
const (
|
||||||
testProtocolVersion = 1
|
testProtocolVersion = 1
|
||||||
|
httpPort = "80"
|
||||||
|
httpsPort = "443"
|
||||||
)
|
)
|
||||||
|
|
||||||
getRandomPort := func() string {
|
memNetwork := &memnet.Network{}
|
||||||
ln, err := net.Listen("tcp", ":0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("net.Listen: %v", err)
|
|
||||||
}
|
|
||||||
defer ln.Close()
|
|
||||||
_, port, err := net.SplitHostPort(ln.Addr().String())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
return port
|
|
||||||
}
|
|
||||||
|
|
||||||
// We need consistent ports for each address; these are chosen
|
fallbackAddr := netip.MustParseAddr("10.0.0.1")
|
||||||
// randomly and we hope that they won't conflict during this test.
|
goodAddr := netip.MustParseAddr("10.0.0.2")
|
||||||
httpPort := getRandomPort()
|
otherAddr := netip.MustParseAddr("10.0.0.3")
|
||||||
httpsPort := getRandomPort()
|
other2Addr := netip.MustParseAddr("10.0.0.4")
|
||||||
|
brokenAddr := netip.MustParseAddr("10.0.0.10")
|
||||||
|
slowAddr := netip.MustParseAddr("10.0.0.12")
|
||||||
|
|
||||||
makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
|
makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) {
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
@ -592,14 +685,8 @@ func TestDialPlan(t *testing.T) {
|
|||||||
handler = wrap(handler)
|
handler = wrap(handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
httpLn, err := net.Listen("tcp", host.String()+":"+httpPort)
|
httpLn := must.Get(memNetwork.Listen("tcp", host.String()+":"+httpPort))
|
||||||
if err != nil {
|
httpsLn := must.Get(memNetwork.Listen("tcp", host.String()+":"+httpsPort))
|
||||||
t.Fatalf("HTTP listen: %v", err)
|
|
||||||
}
|
|
||||||
httpsLn, err := net.Listen("tcp", host.String()+":"+httpsPort)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("HTTPS listen: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
httpServer := &http.Server{Handler: handler}
|
httpServer := &http.Server{Handler: handler}
|
||||||
go httpServer.Serve(httpLn)
|
go httpServer.Serve(httpLn)
|
||||||
@ -616,209 +703,199 @@ func TestDialPlan(t *testing.T) {
|
|||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
httpsServer.Close()
|
httpsServer.Close()
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fallbackAddr := netip.MustParseAddr("127.0.0.1")
|
// Use synctest's controlled time
|
||||||
goodAddr := netip.MustParseAddr("127.0.0.2")
|
clock := tstime.StdClock{}
|
||||||
otherAddr := netip.MustParseAddr("127.0.0.3")
|
makeHandler(t, "fallback", fallbackAddr, nil)
|
||||||
other2Addr := netip.MustParseAddr("127.0.0.4")
|
makeHandler(t, "good", goodAddr, nil)
|
||||||
brokenAddr := netip.MustParseAddr("127.0.0.10")
|
makeHandler(t, "other", otherAddr, nil)
|
||||||
|
makeHandler(t, "other2", other2Addr, nil)
|
||||||
testCases := []struct {
|
makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler {
|
||||||
name string
|
return brokenMITMHandler(clock)
|
||||||
plan *tailcfg.ControlDialPlan
|
})
|
||||||
wrap func(http.Handler) http.Handler
|
// makeSlowHandler serves a handler on listeners that delay each Accept by the given duration
|
||||||
want netip.Addr
|
makeSlowHandler := func(t *testing.T, name string, host netip.Addr, delay time.Duration) {
|
||||||
|
done := make(chan struct{})
|
||||||
allowFallback bool
|
t.Cleanup(func() {
|
||||||
}{
|
close(done)
|
||||||
{
|
})
|
||||||
name: "single",
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
conn, err := controlhttpserver.AcceptHTTP(context.Background(), w, r, server, nil)
|
||||||
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
|
|
||||||
}},
|
|
||||||
want: goodAddr,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "broken-then-good",
|
|
||||||
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
|
||||||
// Dials the broken one, which fails, and then
|
|
||||||
// eventually dials the good one and succeeds
|
|
||||||
{IP: brokenAddr, Priority: 2, DialTimeoutSec: 10},
|
|
||||||
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10, DialStartDelaySec: 1},
|
|
||||||
}},
|
|
||||||
want: goodAddr,
|
|
||||||
},
|
|
||||||
// TODO(#8442): fix this test
|
|
||||||
// {
|
|
||||||
// name: "multiple-priority-fast-path",
|
|
||||||
// plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
|
||||||
// // Dials some good IPs and our bad one (which
|
|
||||||
// // hangs forever), which then hits the fast
|
|
||||||
// // path where we bail without waiting.
|
|
||||||
// {IP: brokenAddr, Priority: 1, DialTimeoutSec: 10},
|
|
||||||
// {IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
|
|
||||||
// {IP: other2Addr, Priority: 1, DialTimeoutSec: 10},
|
|
||||||
// {IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
|
|
||||||
// }},
|
|
||||||
// want: otherAddr,
|
|
||||||
// },
|
|
||||||
{
|
|
||||||
name: "multiple-priority-slow-path",
|
|
||||||
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
|
||||||
// Our broken address is the highest priority,
|
|
||||||
// so we don't hit our fast path.
|
|
||||||
{IP: brokenAddr, Priority: 10, DialTimeoutSec: 10},
|
|
||||||
{IP: otherAddr, Priority: 2, DialTimeoutSec: 10},
|
|
||||||
{IP: goodAddr, Priority: 1, DialTimeoutSec: 10},
|
|
||||||
}},
|
|
||||||
want: otherAddr,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "fallback",
|
|
||||||
plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{
|
|
||||||
{IP: brokenAddr, Priority: 1, DialTimeoutSec: 1},
|
|
||||||
}},
|
|
||||||
want: fallbackAddr,
|
|
||||||
allowFallback: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range testCases {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// TODO(awly): replace this with tstest.NewClock and update the
|
|
||||||
// test to advance the clock correctly.
|
|
||||||
clock := tstime.StdClock{}
|
|
||||||
makeHandler(t, "fallback", fallbackAddr, nil)
|
|
||||||
makeHandler(t, "good", goodAddr, nil)
|
|
||||||
makeHandler(t, "other", otherAddr, nil)
|
|
||||||
makeHandler(t, "other2", other2Addr, nil)
|
|
||||||
makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler {
|
|
||||||
return brokenMITMHandler(clock)
|
|
||||||
})
|
|
||||||
|
|
||||||
dialer := closeTrackDialer{
|
|
||||||
t: t,
|
|
||||||
inner: tsdial.NewDialer(netmon.NewStatic()).SystemDial,
|
|
||||||
conns: make(map[*closeTrackConn]bool),
|
|
||||||
}
|
|
||||||
defer dialer.Done()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// By default, we intentionally point to something that
|
|
||||||
// we know won't connect, since we want a fallback to
|
|
||||||
// DNS to be an error.
|
|
||||||
host := "example.com"
|
|
||||||
if tt.allowFallback {
|
|
||||||
host = "localhost"
|
|
||||||
}
|
|
||||||
|
|
||||||
drained := make(chan struct{})
|
|
||||||
a := &Dialer{
|
|
||||||
Hostname: host,
|
|
||||||
HTTPPort: httpPort,
|
|
||||||
HTTPSPort: httpsPort,
|
|
||||||
MachineKey: client,
|
|
||||||
ControlKey: server.Public(),
|
|
||||||
ProtocolVersion: testProtocolVersion,
|
|
||||||
Dialer: dialer.Dial,
|
|
||||||
Logf: t.Logf,
|
|
||||||
DialPlan: tt.plan,
|
|
||||||
proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil },
|
|
||||||
drainFinished: drained,
|
|
||||||
omitCertErrorLogging: true,
|
|
||||||
testFallbackDelay: 50 * time.Millisecond,
|
|
||||||
Clock: clock,
|
|
||||||
HealthTracker: health.NewTracker(eventbustest.NewBus(t)),
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := a.dial(ctx)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dialing controlhttp: %v", err)
|
log.Print(err)
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
raddr := conn.RemoteAddr().(*net.TCPAddr)
|
|
||||||
|
|
||||||
got, ok := netip.AddrFromSlice(raddr.IP)
|
|
||||||
if !ok {
|
|
||||||
t.Errorf("invalid remote IP: %v", raddr.IP)
|
|
||||||
} else if got != tt.want {
|
|
||||||
t.Errorf("got connection from %q; want %q", got, tt.want)
|
|
||||||
} else {
|
} else {
|
||||||
t.Logf("successfully connected to %q", raddr.String())
|
defer conn.Close()
|
||||||
}
|
}
|
||||||
|
w.Header().Set("X-Handler-Name", name)
|
||||||
|
<-done
|
||||||
|
})
|
||||||
|
|
||||||
// Wait until our dialer drains so we can verify that
|
httpLn, err := memNetwork.Listen("tcp", host.String()+":"+httpPort)
|
||||||
// all connections are closed.
|
if err != nil {
|
||||||
<-drained
|
t.Fatalf("HTTP listen: %v", err)
|
||||||
|
}
|
||||||
|
httpsLn, err := memNetwork.Listen("tcp", host.String()+":"+httpsPort)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("HTTPS listen: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
slowHttpLn := newSlowListener(httpLn, delay)
|
||||||
|
slowHttpsLn := newSlowListener(httpsLn, delay)
|
||||||
|
|
||||||
|
httpServer := &http.Server{Handler: handler}
|
||||||
|
go httpServer.Serve(slowHttpLn)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
httpServer.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
httpsServer := &http.Server{
|
||||||
|
Handler: handler,
|
||||||
|
TLSConfig: tlsConfig(t),
|
||||||
|
ErrorLog: logger.StdLogger(logger.WithPrefix(t.Logf, "http.Server.ErrorLog: ")),
|
||||||
|
}
|
||||||
|
go httpsServer.ServeTLS(slowHttpsLn, "", "")
|
||||||
|
t.Cleanup(func() {
|
||||||
|
httpsServer.Close()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
makeSlowHandler(t, "slow", slowAddr, 5*time.Second)
|
||||||
|
|
||||||
|
// memnetDialer with connection tracking, so we can catch connection leaks.
|
||||||
|
dialer := &memnetDialer{
|
||||||
|
inner: memNetwork.Dial,
|
||||||
|
t: t,
|
||||||
|
}
|
||||||
|
defer dialer.waitForAllClosedSynctest()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
host := "example.com"
|
||||||
|
if allowFallback {
|
||||||
|
host = fallbackAddr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
a := &Dialer{
|
||||||
|
Hostname: host,
|
||||||
|
HTTPPort: httpPort,
|
||||||
|
HTTPSPort: httpsPort,
|
||||||
|
MachineKey: client,
|
||||||
|
ControlKey: server.Public(),
|
||||||
|
ProtocolVersion: testProtocolVersion,
|
||||||
|
Dialer: dialer.Dial,
|
||||||
|
Logf: t.Logf,
|
||||||
|
DialPlan: plan,
|
||||||
|
proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil },
|
||||||
|
omitCertErrorLogging: true,
|
||||||
|
testFallbackDelay: 50 * time.Millisecond,
|
||||||
|
Clock: clock,
|
||||||
|
HealthTracker: health.NewTracker(eventbustest.NewBus(t)),
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
conn, err := a.dial(ctx)
|
||||||
|
duration := time.Since(start)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dialing controlhttp: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if maxDuration > 0 && duration > maxDuration {
|
||||||
|
t.Errorf("dial took %v, expected < %v (should not wait for slow endpoints)", duration, maxDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
raddr := conn.RemoteAddr()
|
||||||
|
raddrStr := raddr.String()
|
||||||
|
|
||||||
|
// split on "|" first to remove memnet pipe suffix
|
||||||
|
addrPart := raddrStr
|
||||||
|
if idx := strings.Index(raddrStr, "|"); idx >= 0 {
|
||||||
|
addrPart = raddrStr[:idx]
|
||||||
|
}
|
||||||
|
|
||||||
|
host, _, err2 := net.SplitHostPort(addrPart)
|
||||||
|
if err2 != nil {
|
||||||
|
t.Fatalf("failed to parse remote address %q: %v", addrPart, err2)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err3 := netip.ParseAddr(host)
|
||||||
|
if err3 != nil {
|
||||||
|
t.Errorf("invalid remote IP: %v", host)
|
||||||
|
} else {
|
||||||
|
found := slices.Contains(want, got)
|
||||||
|
if !found {
|
||||||
|
t.Errorf("got connection from %q; want one of %v", got, want)
|
||||||
|
} else {
|
||||||
|
t.Logf("successfully connected to %q", raddr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type closeTrackDialer struct {
|
// memnetDialer wraps memnet.Network.Dial to track connections for testing
|
||||||
t testing.TB
|
type memnetDialer struct {
|
||||||
inner netx.DialFunc
|
inner func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||||
|
t *testing.T
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
conns map[*closeTrackConn]bool
|
conns map[net.Conn]string // conn -> remote address for debugging
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *closeTrackDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
|
func (d *memnetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
c, err := d.inner(ctx, network, addr)
|
conn, err := d.inner(ctx, network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ct := &closeTrackConn{Conn: c, d: d}
|
|
||||||
|
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
d.conns[ct] = true
|
if d.conns == nil {
|
||||||
|
d.conns = make(map[net.Conn]string)
|
||||||
|
}
|
||||||
|
d.conns[conn] = conn.RemoteAddr().String()
|
||||||
|
d.t.Logf("tracked connection opened to %s", conn.RemoteAddr())
|
||||||
d.mu.Unlock()
|
d.mu.Unlock()
|
||||||
return ct, nil
|
|
||||||
|
return &memnetTrackedConn{Conn: conn, dialer: d}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *closeTrackDialer) Done() {
|
func (d *memnetDialer) waitForAllClosedSynctest() {
|
||||||
// Unfortunately, tsdial.Dialer.SystemDial closes connections
|
const maxWait = 15 * time.Second
|
||||||
// asynchronously in a goroutine, so we can't assume that everything is
|
const checkInterval = 100 * time.Millisecond
|
||||||
// closed by the time we get here.
|
|
||||||
//
|
for range int(maxWait / checkInterval) {
|
||||||
// Sleep/wait a few times on the assumption that things will close
|
|
||||||
// "eventually".
|
|
||||||
const iters = 100
|
|
||||||
for i := range iters {
|
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
if len(d.conns) == 0 {
|
remaining := len(d.conns)
|
||||||
|
if remaining == 0 {
|
||||||
d.mu.Unlock()
|
d.mu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only error on last iteration
|
|
||||||
if i != iters-1 {
|
|
||||||
d.mu.Unlock()
|
|
||||||
time.Sleep(100 * time.Millisecond)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for conn := range d.conns {
|
|
||||||
d.t.Errorf("expected close of conn %p; RemoteAddr=%q", conn, conn.RemoteAddr().String())
|
|
||||||
}
|
|
||||||
d.mu.Unlock()
|
d.mu.Unlock()
|
||||||
|
|
||||||
|
time.Sleep(checkInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
d.mu.Lock()
|
||||||
|
defer d.mu.Unlock()
|
||||||
|
for _, addr := range d.conns {
|
||||||
|
d.t.Errorf("connection to %s was not closed after %v", addr, maxWait)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *closeTrackDialer) noteClose(c *closeTrackConn) {
|
func (d *memnetDialer) noteClose(conn net.Conn) {
|
||||||
d.mu.Lock()
|
d.mu.Lock()
|
||||||
delete(d.conns, c) // safe if already deleted
|
if addr, exists := d.conns[conn]; exists {
|
||||||
|
d.t.Logf("tracked connection closed to %s", addr)
|
||||||
|
delete(d.conns, conn)
|
||||||
|
}
|
||||||
d.mu.Unlock()
|
d.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
type closeTrackConn struct {
|
type memnetTrackedConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
d *closeTrackDialer
|
dialer *memnetDialer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *closeTrackConn) Close() error {
|
func (c *memnetTrackedConn) Close() error {
|
||||||
c.d.noteClose(c)
|
c.dialer.noteClose(c.Conn)
|
||||||
return c.Conn.Close()
|
return c.Conn.Close()
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user