diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt
index 07f5958ca..e60c1cb04 100644
--- a/cmd/tailscaled/depaware.txt
+++ b/cmd/tailscaled/depaware.txt
@@ -311,7 +311,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
 tailscale.com/logtail/filch from tailscale.com/log/sockstatlog+
 tailscale.com/metrics from tailscale.com/derp+
 tailscale.com/net/bakedroots from tailscale.com/net/tlsdial+
-💣 tailscale.com/net/batching from tailscale.com/wgengine/magicsock
+💣 tailscale.com/net/batching from tailscale.com/wgengine/magicsock+
 tailscale.com/net/captivedetection from tailscale.com/ipn/ipnlocal+
 tailscale.com/net/connstats from tailscale.com/net/tstun+
 tailscale.com/net/dns from tailscale.com/cmd/tailscaled+
diff --git a/net/batching/conn.go b/net/batching/conn.go
index 2c6100258..77cdf8c84 100644
--- a/net/batching/conn.go
+++ b/net/batching/conn.go
@@ -32,7 +32,6 @@ type Conn interface {
 	// message may fall on either side of a nonzero.
 	//
 	// Each [ipv6.Message.OOB] must be sized to at least MinControlMessageSize().
-	// len(msgs) must be at least MinReadBatchMsgsLen().
 	ReadBatch(msgs []ipv6.Message, flags int) (n int, err error)
 	// WriteBatchTo writes buffs to addr.
 	//
diff --git a/net/batching/conn_default.go b/net/batching/conn_default.go
index ed5c494f3..37d644f50 100644
--- a/net/batching/conn_default.go
+++ b/net/batching/conn_default.go
@@ -19,3 +19,7 @@ var controlMessageSize = 0
 func MinControlMessageSize() int {
 	return controlMessageSize
 }
+
+// IdealBatchSize is the suggested batch size for [Conn] I/O operations; it is
+// 1 on platforms without batch I/O support.
+const IdealBatchSize = 1
diff --git a/net/batching/conn_linux.go b/net/batching/conn_linux.go
index 09a80ed9f..7f6c4ed42 100644
--- a/net/batching/conn_linux.go
+++ b/net/batching/conn_linux.go
@@ -384,7 +384,7 @@ func setGSOSizeInControl(control *[]byte, gsoSize uint16) {
 }
 
 // TryUpgradeToConn probes the capabilities of the OS and pconn, and upgrades
-// pconn to a [Conn] if appropriate. A batch size of MinReadBatchMsgsLen() is
+// pconn to a [Conn] if appropriate. A batch size of [IdealBatchSize] is
 // suggested for the best performance.
 func TryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn {
 	if runtime.GOOS != "linux" {
@@ -457,6 +457,6 @@ func MinControlMessageSize() int {
 	return controlMessageSize
 }
 
-func MinReadBatchMsgsLen() int {
-	return 128
-}
+// IdealBatchSize is the batch size suggested for the best performance; see
+// also [TryUpgradeToConn].
+const IdealBatchSize = 128
diff --git a/net/batching/conn_linux_test.go b/net/batching/conn_linux_test.go
index e33ad6d7a..e518c3f9f 100644
--- a/net/batching/conn_linux_test.go
+++ b/net/batching/conn_linux_test.go
@@ -310,7 +310,7 @@ func TestMinReadBatchMsgsLen(t *testing.T) {
 	// So long as magicsock uses [Conn], and [wireguard-go/conn.Bind] API is
 	// shaped for wireguard-go to control packet memory, these values should be
 	// aligned.
-	if MinReadBatchMsgsLen() != conn.IdealBatchSize {
-		t.Fatalf("MinReadBatchMsgsLen():%d != conn.IdealBatchSize(): %d", MinReadBatchMsgsLen(), conn.IdealBatchSize)
+	if IdealBatchSize != conn.IdealBatchSize {
+		t.Fatalf("IdealBatchSize: %d != conn.IdealBatchSize: %d", IdealBatchSize, conn.IdealBatchSize)
 	}
 }
diff --git a/net/udprelay/server.go b/net/udprelay/server.go
index e138c33f2..a039c9930 100644
--- a/net/udprelay/server.go
+++ b/net/udprelay/server.go
@@ -20,8 +20,11 @@ import (
 	"time"
 
 	"go4.org/mem"
+	"golang.org/x/net/ipv6"
 	"tailscale.com/client/local"
 	"tailscale.com/disco"
+	"tailscale.com/net/batching"
+	"tailscale.com/net/netaddr"
 	"tailscale.com/net/netcheck"
 	"tailscale.com/net/netmon"
 	"tailscale.com/net/packet"
@@ -57,10 +60,10 @@ type Server struct {
 	bindLifetime        time.Duration
 	steadyStateLifetime time.Duration
 	bus                 *eventbus.Bus
-	uc4                 *net.UDPConn // always non-nil
-	uc4Port             uint16       // always nonzero
-	uc6                 *net.UDPConn // may be nil if IPv6 bind fails during initialization
-	uc6Port             uint16       // may be zero if IPv6 bind fails during initialization
+	uc4                 batching.Conn // always non-nil
+	uc4Port             uint16        // always nonzero
+	uc6                 batching.Conn // may be nil if IPv6 bind fails during initialization
+	uc6Port             uint16        // may be zero if IPv6 bind fails during initialization
 	closeOnce           sync.Once
 	wg                  sync.WaitGroup
 	closeCh             chan struct{}
@@ -96,9 +99,9 @@ type serverEndpoint struct {
 	allocatedAt time.Time
 }
 
-func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, conn *net.UDPConn, serverDisco key.DiscoPublic) {
+func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) {
 	if senderIndex != 0 && senderIndex != 1 {
-		return
+		return nil, netip.AddrPort{}
 	}
 
 	otherSender := 0
@@ -121,15 +124,15 @@
 		err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon)
 		if err != nil {
 			// silently drop
-			return
+			return nil, netip.AddrPort{}
 		}
 		if discoMsg.Generation == 0 {
 			// Generation must be nonzero, silently drop
-			return
+			return nil, netip.AddrPort{}
 		}
 		if e.handshakeGeneration[senderIndex] == discoMsg.Generation {
 			// we've seen this generation before, silently drop
-			return
+			return nil, netip.AddrPort{}
 		}
 		e.handshakeGeneration[senderIndex] = discoMsg.Generation
 		e.handshakeAddrPorts[senderIndex] = from
@@ -144,19 +147,18 @@
 		gh.VNI.Set(e.vni)
 		err = gh.Encode(reply)
 		if err != nil {
-			return
+			return nil, netip.AddrPort{}
 		}
 		reply = append(reply, disco.Magic...)
 		reply = serverDisco.AppendTo(reply)
 		box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil))
 		reply = append(reply, box...)
-		conn.WriteMsgUDPAddrPort(reply, nil, from)
-		return
+		return reply, from
 	case *disco.BindUDPRelayEndpointAnswer:
 		err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon)
 		if err != nil {
 			// silently drop
-			return
+			return nil, netip.AddrPort{}
 		}
 		generation := e.handshakeGeneration[senderIndex]
 		if generation == 0 || // we have no active handshake
@@ -164,23 +166,23 @@
 			e.handshakeAddrPorts[senderIndex] != from || // mismatching source for the active handshake
 			!bytes.Equal(e.challenge[senderIndex][:], discoMsg.Challenge[:]) { // mismatching answer for the active handshake
 			// silently drop
-			return
+			return nil, netip.AddrPort{}
 		}
 		// Handshake complete. Update the binding for this sender.
 		e.boundAddrPorts[senderIndex] = from
 		e.lastSeen[senderIndex] = time.Now() // record last seen as bound time
-		return
+		return nil, netip.AddrPort{}
 	default:
 		// unexpected message types, silently drop
-		return
+		return nil, netip.AddrPort{}
 	}
 }
 
-func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, conn *net.UDPConn, serverDisco key.DiscoPublic) {
+func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) {
 	senderRaw, isDiscoMsg := disco.Source(b)
 	if !isDiscoMsg {
 		// Not a Disco message
-		return
+		return nil, netip.AddrPort{}
 	}
 	sender := key.DiscoPublicFromRaw32(mem.B(senderRaw))
 	senderIndex := -1
@@ -191,63 +193,51 @@
 		senderIndex = 1
 	default:
 		// unknown Disco public key
-		return
+		return nil, netip.AddrPort{}
 	}
 
 	const headerLen = len(disco.Magic) + key.DiscoPublicRawLen
 	discoPayload, ok := e.discoSharedSecrets[senderIndex].Open(b[headerLen:])
 	if !ok {
 		// unable to decrypt the Disco payload
-		return
+		return nil, netip.AddrPort{}
 	}
 	discoMsg, err := disco.Parse(discoPayload)
 	if err != nil {
 		// unable to parse the Disco payload
-		return
+		return nil, netip.AddrPort{}
 	}
 
-	e.handleDiscoControlMsg(from, senderIndex, discoMsg, conn, serverDisco)
+	return e.handleDiscoControlMsg(from, senderIndex, discoMsg, serverDisco)
 }
 
-func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, rxSocket, otherAFSocket *net.UDPConn, serverDisco key.DiscoPublic) {
+func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) {
 	if !gh.Control {
 		if !e.isBound() {
 			// not a control packet, but serverEndpoint isn't bound
-			return
+			return nil, netip.AddrPort{}
 		}
-		var to netip.AddrPort
 		switch {
 		case from == e.boundAddrPorts[0]:
 			e.lastSeen[0] = time.Now()
-			to = e.boundAddrPorts[1]
+			return b, e.boundAddrPorts[1]
 		case from == e.boundAddrPorts[1]:
 			e.lastSeen[1] = time.Now()
-			to = e.boundAddrPorts[0]
+			return b, e.boundAddrPorts[0]
 		default:
 			// unrecognized source
-			return
+			return nil, netip.AddrPort{}
 		}
-		// Relay the packet towards the other party via the socket associated
-		// with the destination's address family. If source and destination
-		// address families are matching we tx on the same socket the packet
-		// was received (rxSocket), otherwise we use the "other" socket
-		// (otherAFSocket). [Server] makes no use of dual-stack sockets.
-		if from.Addr().Is4() == to.Addr().Is4() {
-			rxSocket.WriteMsgUDPAddrPort(b, nil, to)
-		} else if otherAFSocket != nil {
-			otherAFSocket.WriteMsgUDPAddrPort(b, nil, to)
-		}
-		return
 	}
 
 	if gh.Protocol != packet.GeneveProtocolDisco {
 		// control packet, but not Disco
-		return
+		return nil, netip.AddrPort{}
 	}
 	msg := b[packet.GeneveFixedHeaderLength:]
-	e.handleSealedDiscoControlMsg(from, msg, rxSocket, serverDisco)
+	return e.handleSealedDiscoControlMsg(from, msg, serverDisco)
 }
 
 func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool {
@@ -338,10 +328,10 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve
 	}
 
 	s.wg.Add(1)
-	go s.packetReadLoop(s.uc4, s.uc6)
+	go s.packetReadLoop(s.uc4, s.uc6, true)
 	if s.uc6 != nil {
 		s.wg.Add(1)
-		go s.packetReadLoop(s.uc6, s.uc4)
+		go s.packetReadLoop(s.uc6, s.uc4, false)
 	}
 	s.wg.Add(1)
 	go s.endpointGCLoop()
@@ -425,6 +415,41 @@ func (s *Server) addrDiscoveryLoop() {
 	}
 }
 
+// This is a compile-time assertion that [singlePacketConn] implements the
+// [batching.Conn] interface.
+var _ batching.Conn = (*singlePacketConn)(nil)
+
+// singlePacketConn implements [batching.Conn] with single packet syscall
+// operations.
+type singlePacketConn struct {
+	*net.UDPConn
+}
+
+func (c *singlePacketConn) ReadBatch(msgs []ipv6.Message, _ int) (int, error) {
+	n, ap, err := c.UDPConn.ReadFromUDPAddrPort(msgs[0].Buffers[0])
+	if err != nil {
+		return 0, err
+	}
+	msgs[0].N = n
+	msgs[0].Addr = net.UDPAddrFromAddrPort(netaddr.Unmap(ap))
+	return 1, nil
+}
+
+func (c *singlePacketConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error {
+	for _, buff := range buffs {
+		if geneve.VNI.IsSet() {
+			geneve.Encode(buff)
+		} else {
+			buff = buff[offset:]
+		}
+		_, err := c.UDPConn.WriteToUDPAddrPort(buff, addr)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 // listenOn binds an IPv4 and IPv6 socket to port. We consider it successful if
 // we manage to bind the IPv4 socket.
 //
@@ -433,7 +458,10 @@ func (s *Server) addrDiscoveryLoop() {
 // across IPv4 and IPv6 if the requested port is zero.
 //
 // TODO: make these "re-bindable" in similar fashion to magicsock as a means to
-// deal with EDR software closing them. http://go/corp/30118
+// deal with EDR software closing them. http://go/corp/30118. We could reuse
+// [magicsock.RebindingConn], which would also remove the need for
+// [singlePacketConn], as it also handles fallback to single packet syscall
+// operations.
 func (s *Server) listenOn(port int) error {
 	for _, network := range []string{"udp4", "udp6"} {
 		uc, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
@@ -462,11 +490,16 @@ func (s *Server) listenOn(port int) error {
 			}
 			return err
 		}
+		pc := batching.TryUpgradeToConn(uc, network, batching.IdealBatchSize)
+		bc, ok := pc.(batching.Conn)
+		if !ok {
+			bc = &singlePacketConn{uc}
+		}
 		if network == "udp4" {
-			s.uc4 = uc
+			s.uc4 = bc
 			s.uc4Port = uint16(portUint)
 		} else {
-			s.uc6 = uc
+			s.uc6 = bc
 			s.uc6Port = uint16(portUint)
 		}
 	}
@@ -526,18 +559,18 @@
 	}
 }
 
-func (s *Server) handlePacket(from netip.AddrPort, b []byte, rxSocket, otherAFSocket *net.UDPConn) {
+func (s *Server) handlePacket(from netip.AddrPort, b []byte) (write []byte, to netip.AddrPort) {
 	if stun.Is(b) && b[1] == 0x01 {
 		// A b[1] value of 0x01 (STUN method binding) is sufficiently
 		// non-overlapping with the Geneve header where the LSB is always 0
 		// (part of 6 "reserved" bits).
 		s.netChecker.ReceiveSTUNPacket(b, from)
-		return
+		return nil, netip.AddrPort{}
 	}
 	gh := packet.GeneveHeader{}
 	err := gh.Decode(b)
 	if err != nil {
-		return
+		return nil, netip.AddrPort{}
 	}
 	// TODO: consider performance implications of holding s.mu for the remainder
 	// of this method, which does a bunch of disco/crypto work depending. Keep
@@ -547,13 +580,13 @@
 	e, ok := s.byVNI[gh.VNI.Get()]
 	if !ok {
 		// unknown VNI
-		return
+		return nil, netip.AddrPort{}
 	}
 
-	e.handlePacket(from, gh, b, rxSocket, otherAFSocket, s.discoPublic)
+	return e.handlePacket(from, gh, b, s.discoPublic)
 }
 
-func (s *Server) packetReadLoop(readFromSocket, otherSocket *net.UDPConn) {
+func (s *Server) packetReadLoop(readFromSocket, otherSocket batching.Conn, readFromSocketIsIPv4 bool) {
 	defer func() {
 		// We intentionally close the [Server] if we encounter a socket read
 		// error below, at least until socket "re-binding" is implemented as
@@ -564,15 +597,73 @@
 		s.wg.Done()
 		s.Close()
 	}()
-	b := make([]byte, 1<<16-1)
+
+	msgs := make([]ipv6.Message, batching.IdealBatchSize)
+	for i := range msgs {
+		msgs[i].OOB = make([]byte, batching.MinControlMessageSize())
+		msgs[i].Buffers = make([][]byte, 1)
+		msgs[i].Buffers[0] = make([]byte, 1<<16-1)
+	}
+	writeBuffsByDest := make(map[netip.AddrPort][][]byte, batching.IdealBatchSize)
 
 	for {
+		for i := range msgs {
+			msgs[i] = ipv6.Message{Buffers: msgs[i].Buffers, OOB: msgs[i].OOB[:cap(msgs[i].OOB)]}
+		}
+
 		// TODO: extract laddr from IP_PKTINFO for use in reply
-		n, from, err := readFromSocket.ReadFromUDPAddrPort(b)
+
+		// ReadBatch will split coalesced datagrams before returning, which
+		// WriteBatchTo will re-coalesce further down. We _could_ be more
+		// efficient and not split datagrams that belong to the same VNI if they
+		// are non-control/handshake packets. We pay the memmove/memcpy
+		// performance penalty for now in the interest of simple single packet
+		// handlers.
+		n, err := readFromSocket.ReadBatch(msgs, 0)
 		if err != nil {
 			s.logf("error reading from socket(%v): %v", readFromSocket.LocalAddr(), err)
 			return
 		}
-		s.handlePacket(from, b[:n], readFromSocket, otherSocket)
+
+		for _, msg := range msgs[:n] {
+			if msg.N == 0 {
+				continue
+			}
+			buf := msg.Buffers[0][:msg.N]
+			from := msg.Addr.(*net.UDPAddr).AddrPort()
+			write, to := s.handlePacket(from, buf)
+			if !to.IsValid() {
+				continue
+			}
+			if from.Addr().Is4() == to.Addr().Is4() || otherSocket != nil {
+				buffs, ok := writeBuffsByDest[to]
+				if !ok {
+					buffs = make([][]byte, 0, batching.IdealBatchSize)
+				}
+				buffs = append(buffs, write)
+				writeBuffsByDest[to] = buffs
+			} else {
+				// This is unexpected. We should never produce a packet to write
+				// to the "other" socket if the other socket is nil/unbound.
+				// [Server.handlePacket] has to see a packet from a particular
+				// address family at least once in order for it to return a
+				// packet to write towards a dest for the same address family.
+				s.logf("[unexpected] packet from: %v produced packet to: %v while otherSocket is nil", from, to)
+			}
+		}
+
+		for dest, buffs := range writeBuffsByDest {
+			// Write the packet batches via the socket associated with the
+			// destination's address family. If source and destination address
+			// families are matching we tx on the same socket the packet was
+			// received on, otherwise we use the "other" socket. [Server] makes
+			// no use of dual-stack sockets.
+			if dest.Addr().Is4() == readFromSocketIsIPv4 {
+				readFromSocket.WriteBatchTo(buffs, dest, packet.GeneveHeader{}, 0)
+			} else {
+				otherSocket.WriteBatchTo(buffs, dest, packet.GeneveHeader{}, 0)
+			}
+			delete(writeBuffsByDest, dest)
+		}
 	}
 }
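
Below is a minimal, standalone sketch (not part of the patch) of how a caller
is expected to use the reworked batching API, mirroring Server.listenOn and
Server.packetReadLoop above. It assumes the post-patch
tailscale.com/net/batching package; everything else is standard library or
golang.org/x/net.

package main

import (
	"log"
	"net"

	"golang.org/x/net/ipv6"
	"tailscale.com/net/batching"
)

func main() {
	uc, err := net.ListenUDP("udp4", &net.UDPAddr{Port: 0})
	if err != nil {
		log.Fatal(err)
	}
	defer uc.Close()

	// On Linux, TryUpgradeToConn returns a batching-capable conn when the OS
	// and socket support it; otherwise (and on all other platforms) it returns
	// the conn unchanged, the type assertion fails, and the caller must fall
	// back to single packet I/O, as udprelay does with its unexported
	// singlePacketConn wrapper.
	pc := batching.TryUpgradeToConn(uc, "udp4", batching.IdealBatchSize)
	bc, ok := pc.(batching.Conn)
	if !ok {
		log.Println("no batching support; fall back to single packet I/O")
		return
	}

	// Preallocate IdealBatchSize messages, each with one 64KiB buffer and an
	// OOB buffer sized for control messages (e.g. the GRO segment size).
	msgs := make([]ipv6.Message, batching.IdealBatchSize)
	for i := range msgs {
		msgs[i].Buffers = [][]byte{make([]byte, 1<<16-1)}
		msgs[i].OOB = make([]byte, batching.MinControlMessageSize())
	}

	// ReadBatch blocks until at least one datagram arrives and may return up
	// to len(msgs) datagrams from a single poll.
	n, err := bc.ReadBatch(msgs, 0)
	if err != nil {
		log.Fatal(err)
	}
	for _, msg := range msgs[:n] {
		log.Printf("%d bytes from %v", msg.N, msg.Addr)
	}
}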
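
A second sketch (also not part of the patch): the per-destination write
coalescing used by packetReadLoop, in isolation. Grouping reply buffers by
destination lets each flush be a single WriteBatchTo call on the socket whose
address family matches the destination. A zero packet.GeneveHeader (VNI unset)
with offset 0 writes the buffers as-is, matching how packetReadLoop relays
already-encapsulated packets; the names below are illustrative only.

package main

import (
	"log"
	"net/netip"

	"tailscale.com/net/batching"
	"tailscale.com/net/packet"
)

// flushByDest writes each destination's accumulated buffers with one
// WriteBatchTo call and clears the map for reuse by the next read batch.
func flushByDest(byDest map[netip.AddrPort][][]byte, uc4, uc6 batching.Conn) {
	for dest, buffs := range byDest {
		c := uc4
		if !dest.Addr().Is4() {
			// Assumed non-nil for any IPv6 dest, as packetReadLoop guarantees
			// by never queueing cross-family writes when otherSocket is nil.
			c = uc6
		}
		if err := c.WriteBatchTo(buffs, dest, packet.GeneveHeader{}, 0); err != nil {
			log.Printf("write batch to %v: %v", dest, err)
		}
		delete(byDest, dest)
	}
}

func main() {} // compile-only sketch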