From 19defaf6dd292cf5fe021f092d2fe642a1d98877 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 26 Apr 2022 09:00:02 -0700 Subject: [PATCH] wgengine/netstack: close forwarded TCP connections when incoming TCP dies Updates #4522 Change-Id: I31a430da422b1e5fab834a2a670cddf448889ee6 Signed-off-by: Brad Fitzpatrick --- wgengine/netstack/netstack.go | 37 +++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 852c2a720..a8a80f0cf 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -715,29 +715,46 @@ func (ns *Impl) forwardTCP(client *gonet.TCPConn, clientRemoteIP netaddr.IP, wq cancel() }() var stdDialer net.Dialer - server, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr) + s, err := stdDialer.DialContext(ctx, "tcp", dialAddrStr) if err != nil { ns.logf("netstack: could not connect to local server at %s: %v", dialAddrStr, err) return } + server := s.(*net.TCPConn) defer server.Close() backendLocalAddr := server.LocalAddr().(*net.TCPAddr) backendLocalIPPort, _ := netaddr.FromStdAddr(backendLocalAddr.IP, backendLocalAddr.Port, backendLocalAddr.Zone) ns.e.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP) defer ns.e.UnregisterIPPortIdentity(backendLocalIPPort) - connClosed := make(chan error, 2) + go func() { - _, err := io.Copy(server, client) - connClosed <- err + <-ctx.Done() + // Inform the server we won't read anymore, EOF will happen. + server.CloseRead() + // Shutdown the client, but don't close it, otherwise buffers would be dropped. + client.CloseRead() + client.CloseWrite() }() + + var wg sync.WaitGroup + + wg.Add(1) go func() { - _, err := io.Copy(client, server) - connClosed <- err + io.Copy(server, client) + server.CloseRead() + client.CloseWrite() + wg.Done() }() - err = <-connClosed - if err != nil { - ns.logf("proxy connection closed with error: %v", err) - } + + wg.Add(1) + go func() { + io.Copy(client, server) + server.CloseWrite() + client.CloseRead() + wg.Done() + }() + + wg.Wait() ns.logf("[v2] netstack: forwarder connection to %s closed", dialAddrStr) }