diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 2d5e1a6cf..a5ca42afa 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -57,6 +57,13 @@ var parsedPacketPool = sync.Pool{New: func() interface{} { return new(packet.Par type FilterFunc func(*packet.Parsed, *Wrapper) filter.Response // Wrapper augments a tun.Device with packet filtering and injection. +// +// Wrapper is on the hot path for packets flowing through Tailscale: +// Wrapper.Read and Wrapper.Write execute for every single packet. +// Wrapper is thus written with an eye towards performance. +// In particular, it could be made simpler and cleaner by using +// multi-channel selects to handle shutdown, but those are +// unfortuately not suitable for performance reasons. type Wrapper struct { logf logger.Logf // tdev is the underlying Wrapper device. @@ -71,13 +78,22 @@ type Wrapper struct { // buffer stores the oldest unconsumed packet from tdev. // It is made a static buffer in order to avoid allocations. buffer [maxBufferSize]byte - // bufferConsumed synchronizes access to buffer (shared by Read and poll). - bufferConsumed chan struct{} - - // closed signals poll (by closing) when the device is closed. - closed chan struct{} - // errors is the error queue populated by poll. - errors chan error + // bufferC coordinates access to buffer between Read and poll. + // Read and poll take turns using buffer: poll fills it and Read empties it. + // A nil buffer passed over bufferC tells poll to exit. + // + // bufferC must be buffered. See the comment in Wrap. + // + // bufferC is a chan []byte instead of a chan struct{} for two reasons. + // (1) []byte can be nil, which provides a useful sentinel value. + // (2) It is a step towards slightly decoupling Read and poll, + // thereby allowing us to use multiple buffers. + // Letting poll use multiple buffers will let us issue syscalls sooner, + // which is important to performance, as syscalls are the + // single slowest thing that Tailscale does, by a large margin. + bufferC chan []byte + // done signals (by closing) when the device is closed. + done chan struct{} // outbound is the queue by which packets leave the TUN device. // // The directions are relative to the network, not the device: @@ -88,7 +104,11 @@ type Wrapper struct { // // Empty reads are skipped by Wireguard, so it is always legal // to discard an empty packet instead of sending it through t.outbound. - outbound chan []byte + outbound chan tunReadResult + // injectOutboundMu serializes attempts to inject a packet. + // This ensures that there is at most one outstanding attempt to send on outbound, + // which is required to guarantee that Close does not block. + injectOutboundMu sync.Mutex // eventsUpDown yields up and down tun.Events that arrive on a Wrapper's events channel. eventsUpDown chan tun.Event @@ -125,26 +145,34 @@ type Wrapper struct { disableTSMPRejected bool } +// A tunReadResult is the result of a read from a TUN device. +type tunReadResult struct { + data []byte + err error + injected bool +} + func Wrap(logf logger.Logf, tdev tun.Device) *Wrapper { tun := &Wrapper{ logf: logger.WithPrefix(logf, "tstun: "), tdev: tdev, - // bufferConsumed is conceptually a condition variable: - // a goroutine should not block when setting it, even with no listeners. - bufferConsumed: make(chan struct{}, 1), - closed: make(chan struct{}), - errors: make(chan error), - outbound: make(chan []byte), - eventsUpDown: make(chan tun.Event), - eventsOther: make(chan tun.Event), + // bufferC needs to be able to accomodate three writes without blocking: + // one from Wrap (a few lines down), one from Read, and one from Close. + bufferC: make(chan []byte, 3), + done: make(chan struct{}), + // outbound needs to be able to accomodate three writes without blocking: + // one from poll, one from Close, and one from InjectOutbound. + outbound: make(chan tunReadResult, 3), + eventsUpDown: make(chan tun.Event), + eventsOther: make(chan tun.Event), // TODO(dmytro): (highly rate-limited) hexdumps should happen on unknown packets. filterFlags: filter.LogAccepts | filter.LogDrops, } go tun.poll() go tun.pumpEvents() - // The buffer starts out consumed. - tun.bufferConsumed <- struct{}{} + // Provide the initial buffer to poll. + tun.bufferC <- tun.buffer[:] return tun } @@ -160,14 +188,28 @@ func (t *Wrapper) SetDestIPActivityFuncs(m map[netaddr.IP]func()) { func (t *Wrapper) Close() error { var err error t.closeOnce.Do(func() { - // Other channels need not be closed: poll will exit gracefully after this. - close(t.closed) + close(t.done) + // Each channel is buffered enough to guarantee that sends will not block. + // Signal poll to stop. + t.bufferC <- nil + // Signal Read to stop. + t.outbound <- tunReadResult{err: io.EOF} err = t.tdev.Close() }) return err } +// closed reports whether t is closed. +func (t *Wrapper) closed() bool { + select { + case <-t.done: + return true + default: + return false + } +} + // pumpEvents copies events from t.tdev to t.eventsUpDown and t.eventsOther. // pumpEvents exits when t.tdev.events or t.closed is closed. // pumpEvents closes t.eventsUpDown and t.eventsOther when it exits. @@ -180,7 +222,7 @@ func (t *Wrapper) pumpEvents() { var event tun.Event var ok bool select { - case <-t.closed: + case <-t.done: return case event, ok = <-src: if !ok { @@ -195,7 +237,7 @@ func (t *Wrapper) pumpEvents() { dst = t.eventsUpDown } select { - case <-t.closed: + case <-t.done: return case dst <- event: } @@ -235,40 +277,29 @@ func (t *Wrapper) Name() (string, error) { // so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly. func (t *Wrapper) poll() { for { - select { - case <-t.closed: + if t.closed() { + return + } + buf := <-t.bufferC + // nil buffer means t is closed. + if buf == nil { return - case <-t.bufferConsumed: - // continue } // Read may use memory in t.buffer before PacketStartOffset for mandatory headers. // This is the rationale behind the tun.Wrapper.{Read,Write} interfaces // and the reason t.buffer has size MaxMessageSize and not MaxContentSize. - n, err := t.tdev.Read(t.buffer[:], PacketStartOffset) - if err != nil { - select { - case <-t.closed: - return - case t.errors <- err: - // In principle, read errors are not fatal (but wireguard-go disagrees). - t.bufferConsumed <- struct{}{} - } - continue - } - + // In principle, read errors are not fatal (but wireguard-go disagrees). + n, err := t.tdev.Read(buf, PacketStartOffset) // Wireguard will skip an empty read, // so we might as well do it here to avoid the send through t.outbound. - if n == 0 { - t.bufferConsumed <- struct{}{} + if n == 0 && err == nil { + t.bufferC <- buf continue } - - select { - case <-t.closed: - return - case t.outbound <- t.buffer[PacketStartOffset : PacketStartOffset+n]: - // continue + t.outbound <- tunReadResult{ + data: buf[PacketStartOffset : PacketStartOffset+n], + err: err, } } } @@ -325,26 +356,17 @@ func (t *Wrapper) IdleDuration() time.Duration { } func (t *Wrapper) Read(buf []byte, offset int) (int, error) { - var n int - - wasInjectedPacket := false - - select { - case <-t.closed: + if t.closed() { return 0, io.EOF - case err := <-t.errors: - return 0, err - case pkt := <-t.outbound: - n = copy(buf[offset:], pkt) - // t.buffer has a fixed location in memory, - // so this is the easiest way to tell when it has been consumed. - // &pkt[0] can be used because empty packets do not reach t.outbound. - if &pkt[0] == &t.buffer[PacketStartOffset] { - t.bufferConsumed <- struct{}{} - } else { - // If the packet is not from t.buffer, then it is an injected packet. - wasInjectedPacket = true - } + } + res := <-t.outbound + if res.err != nil { + return 0, res.err + } + n := copy(buf[offset:], res.data) + if !res.injected { + // Return the buffer to poll to re-fill. + t.bufferC <- t.buffer[:] } p := parsedPacketPool.Get().(*packet.Parsed) @@ -358,7 +380,7 @@ func (t *Wrapper) Read(buf []byte, offset int) (int, error) { } // Do not filter injected packets. - if !wasInjectedPacket && !t.disableFilter { + if !res.injected && !t.disableFilter { response := t.filterOut(p) if response != filter.Accept { // Wireguard considers read errors fatal; pretend nothing was read @@ -555,18 +577,19 @@ func (t *Wrapper) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingReque // The injected packet will not pass through outbound filters. // Injecting an empty packet is a no-op. func (t *Wrapper) InjectOutbound(packet []byte) error { + t.injectOutboundMu.Lock() + defer t.injectOutboundMu.Unlock() + if t.closed() { + return ErrClosed + } if len(packet) > MaxPacketSize { return errPacketTooBig } if len(packet) == 0 { return nil } - select { - case <-t.closed: - return ErrClosed - case t.outbound <- packet: - return nil - } + t.outbound <- tunReadResult{data: packet, injected: true} + return nil } // Unwrap returns the underlying tun.Device. diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index 1875e897b..65969a878 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -309,7 +309,7 @@ func TestFilter(t *testing.T) { var recvbuf []byte for { select { - case <-tun.closed: + case <-tun.done: return case recvbuf = <-chtun.Inbound: // continue