diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go index ffdbec234..1edc21aae 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -40,6 +40,7 @@ type Dialer struct { // NetstackDialTCP dials the provided IPPort using netstack. // If nil, it's not used. NetstackDialTCP func(context.Context, netip.AddrPort) (net.Conn, error) + NetstackDialUDP func(context.Context, netip.AddrPort) (net.Conn, error) peerClientOnce sync.Once peerClient *http.Client @@ -306,10 +307,19 @@ func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn, return nil, err } if d.UseNetstackForIP != nil && d.UseNetstackForIP(ipp.Addr()) { - if d.NetstackDialTCP == nil { - return nil, errors.New("Dialer not initialized correctly") + switch network { + case "udp", "udp4", "udp6": + if d.NetstackDialUDP == nil { + return nil, errors.New("Dialer not initialized correctly") + } + return d.NetstackDialUDP(ctx, ipp) + + default: + if d.NetstackDialTCP == nil { + return nil, errors.New("Dialer not initialized correctly") + } + return d.NetstackDialTCP(ctx, ipp) } - return d.NetstackDialTCP(ctx, ipp) } // TODO(bradfitz): netns, etc var stdDialer net.Dialer diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 1ff6b6ecf..678cec539 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -307,7 +307,7 @@ func (s *Server) Close() error { s.mu.Lock() defer s.mu.Unlock() for _, ln := range s.listeners { - ln.Close() + ln.closeUnlocked() } s.listeners = nil @@ -322,6 +322,17 @@ func (s *Server) doInit() { } } +func (s *Server) TailscaleIPs() []netip.Addr { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + st, err := s.localClient.Status(ctx) + if err != nil { + return []netip.Addr{} + } + return st.TailscaleIPs +} + func (s *Server) getAuthKey() string { if v := s.AuthKey; v != "" { return v @@ -440,6 +451,7 @@ func (s *Server) start() (reterr error) { } ns.ProcessLocalIPs = true ns.ForwardTCPIn = s.forwardTCP + ns.ForwardUDPIn = s.forwardUDP s.netstack = ns s.dialer.UseNetstackForIP = func(ip netip.Addr) bool { _, ok := eng.PeerForIP(ip) @@ -448,6 +460,9 @@ func (s *Server) start() (reterr error) { s.dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { return ns.DialContextTCP(ctx, dst) } + s.dialer.NetstackDialUDP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { + return ns.DialContextUDP(ctx, dst) + } if s.Store == nil { stateFile := filepath.Join(s.rootPath, "tailscaled.state") @@ -579,6 +594,24 @@ func (s *Server) forwardTCP(c net.Conn, port uint16) { } } +func (s *Server) forwardUDP(p net.PacketConn, port uint16) { + s.mu.Lock() + ln, ok := s.listeners[listenKey{"udp", "", port}] + s.mu.Unlock() + if !ok { + p.Close() + return + } + + t := time.NewTimer(time.Second) + defer t.Stop() + select { + case ln.pkt <- p: + case <-t.C: + p.Close() + } +} + // getTSNetDir usually just returns filepath.Join(confDir, "tsnet-"+prog) // with no error. // @@ -640,9 +673,12 @@ func (s *Server) APIClient() (*tailscale.Client, error) { // Listen announces only on the Tailscale network. // It will start the server if it has not been started yet. -func (s *Server) Listen(network, addr string) (net.Listener, error) { +func (s *Server) listen(network, addr string) (*listener, error) { + isPacket := false switch network { case "", "tcp", "tcp4", "tcp6": + case "udp", "udp4", "udp6": + isPacket = true default: return nil, errors.New("unsupported network type") } @@ -660,11 +696,13 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) { key := listenKey{network, host, uint16(port)} ln := &listener{ - s: s, - key: key, - addr: addr, + s: s, + key: key, + addr: addr, + isPacket: isPacket, conn: make(chan net.Conn), + pkt: make(chan net.PacketConn, 1), } s.mu.Lock() if _, ok := s.listeners[key]; ok { @@ -676,6 +714,19 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) { return ln, nil } +func (s *Server) Listen(network, addr string) (net.Listener, error) { + return s.listen(network, addr) +} + +func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) { + ln, err := s.listen(network, addr) + if err != nil { + return nil, err + } + + return ln.GetPacketConn() +} + type listenKey struct { network string host string @@ -683,13 +734,18 @@ type listenKey struct { } type listener struct { - s *Server - key listenKey - addr string - conn chan net.Conn + s *Server + key listenKey + addr string + isPacket bool + conn chan net.Conn + pkt chan net.PacketConn } func (ln *listener) Accept() (net.Conn, error) { + if ln.isPacket { + return nil, fmt.Errorf("tsnet: listener is for packets (UDP, not TCP)") + } c, ok := <-ln.conn if !ok { return nil, fmt.Errorf("tsnet: %w", net.ErrClosed) @@ -697,13 +753,29 @@ func (ln *listener) Accept() (net.Conn, error) { return c, nil } +func (ln *listener) GetPacketConn() (net.PacketConn, error) { + if !ln.isPacket { + return nil, fmt.Errorf("tsnet: listener is for connections (TCP, not UDP)") + } + + p, ok := <-ln.pkt + if !ok { + return nil, fmt.Errorf("tsnet: %w", net.ErrClosed) + } + return p, nil +} + func (ln *listener) Addr() net.Addr { return addr{ln} } func (ln *listener) Close() error { ln.s.mu.Lock() defer ln.s.mu.Unlock() + return ln.closeUnlocked() +} +func (ln *listener) closeUnlocked() error { if v, ok := ln.s.listeners[ln.key]; ok && v == ln { delete(ln.s.listeners, ln.key) close(ln.conn) + close(ln.pkt) } return nil } diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 3847b4358..7207852fb 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -13,6 +13,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "reflect" "testing" "time" @@ -87,31 +88,31 @@ func startControl(t *testing.T) (controlURL string) { return controlURL } -func TestConn(t *testing.T) { +func setupTwoNodes(t *testing.T) (s1, s2 *Server, ctx context.Context) { controlURL := startControl(t) tmp := t.TempDir() tmps1 := filepath.Join(tmp, "s1") os.MkdirAll(tmps1, 0755) - s1 := &Server{ + s1 = &Server{ Dir: tmps1, ControlURL: controlURL, Hostname: "s1", Store: new(mem.Store), Ephemeral: true, } - defer s1.Close() + t.Cleanup(func() { s1.Close() }) tmps2 := filepath.Join(tmp, "s1") os.MkdirAll(tmps2, 0755) - s2 := &Server{ + s2 = &Server{ Dir: tmps2, ControlURL: controlURL, Hostname: "s2", Store: new(mem.Store), Ephemeral: true, } - defer s2.Close() + t.Cleanup(func() { s2.Close() }) if !*verboseNodes { s1.Logf = logger.Discard @@ -119,7 +120,7 @@ func TestConn(t *testing.T) { } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() + t.Cleanup(cancel) s1status, err := s1.Up(ctx) if err != nil { @@ -142,6 +143,12 @@ func TestConn(t *testing.T) { } t.Logf("ping success: %#+v", res) + return +} + +func TestConn(t *testing.T) { + s1, s2, ctx := setupTwoNodes(t) + // pass some data through TCP. ln, err := s1.Listen("tcp", ":8081") if err != nil { @@ -149,6 +156,7 @@ func TestConn(t *testing.T) { } defer ln.Close() + s1ip := s1.TailscaleIPs()[0] w, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)) if err != nil { t.Fatal(err) @@ -174,6 +182,42 @@ func TestConn(t *testing.T) { } } +func TestPackets(t *testing.T) { + s1, s2, ctx := setupTwoNodes(t) + + want := "PACKET!!" + received := make(chan []byte) + go func() { + p, err := s1.ListenPacket("udp", ":5000") + if err != nil { + t.Fatal(err) + } + defer p.Close() + + buf := make([]byte, len(want)) + if _, _, err := p.ReadFrom(buf); err != nil { + t.Fatal(err) + } + received <- buf + }() + + s1ip := s1.TailscaleIPs()[0] + w, err := s2.Dial(ctx, "udp", fmt.Sprintf("%s:5000", s1ip)) + if err != nil { + t.Fatal(err) + } + + if _, err := io.WriteString(w, want); err != nil { + t.Fatal(err) + } + + got := <-received + t.Logf("got: %q", got) + if string(got) != want { + t.Errorf("got %q, want %q", got, want) + } +} + func TestLoopbackLocalAPI(t *testing.T) { controlURL := startControl(t) @@ -258,3 +302,32 @@ func TestLoopbackLocalAPI(t *testing.T) { t.Errorf("GET /status returned %d, want 200", res.StatusCode) } } + +func TestTailscaleIPs(t *testing.T) { + controlURL := startControl(t) + + tmp := t.TempDir() + tmps1 := filepath.Join(tmp, "s1") + os.MkdirAll(tmps1, 0755) + s1 := &Server{ + Dir: tmps1, + ControlURL: controlURL, + Hostname: "s1", + Store: new(mem.Store), + Ephemeral: true, + } + defer s1.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + s1status, err := s1.Up(ctx) + if err != nil { + t.Fatal(err) + } + + ips := s1.TailscaleIPs() + if !reflect.DeepEqual(ips, s1status.TailscaleIPs) { + t.Errorf("s1.TailscaleIPs returned a different result than S1.Up, %v != %v", ips, s1status.TailscaleIPs) + } +}