diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 852c2a720..a8a80f0cf 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -715,29 +715,46 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, clientRemoteIP netaddr.IP, wq cancel() }() var stdDialer net.Dialer - server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr) + s, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr) if err != nil { ns.logf("netstack: could not connect to local server at %s: %v", dialAddrStr, err) return } + server := s.(*net.TCPConn) defer server.Close() backendLocalAddr := server.LocalAddr().(*net.TCPAddr) backendLocalIPPort, _ := netaddr.FromStdAddr(backendLocalAddr.IP, backendLocalAddr.Port, backendLocalAddr.Zone) ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP) defer ns.e.UnregisterIPPortIdentity(backendLocalIPPort) - connClosed := make(chan error, 2) + go func() { - _, err := io.Copy(server, client) - connClosed <- err + <-ctx.Done() + // Inform the server we won't read anymore, EOF will happen. + server.CloseRead() + // Shutdown the client, but don't close it, otherwise buffers would be dropped. + client.CloseRead() + client.CloseWrite() }() + + var wg sync.WaitGroup + + wg.Add(1) go func() { - _, err := io.Copy(client, server) - connClosed <- err + io.Copy(server, client) + server.CloseRead() + client.CloseWrite() + wg.Done() }() - err = <-connClosed - if err != nil { - ns.logf("proxy connection closed with error: %v", err) - } + + wg.Add(1) + go func() { + io.Copy(client, server) + server.CloseWrite() + client.CloseRead() + wg.Done() + }() + + wg.Wait() ns.logf("[v2] netstack: forwarder connection to %s closed", dialAddrStr) }