From 7ec43bd177bf781aa1917ff7df59b46d17e220d9 Mon Sep 17 00:00:00 2001
From: Klaus Post
Date: Thu, 8 Feb 2024 10:15:27 -0800
Subject: [PATCH] Fix blocked streams blocking reconnects (#19017)

We have observed cases where a blocked stream will block cancellations.

This happens when the response channel is blocked and we want to push an
error. That path holds the response mutex, which blocks all other
operations until upstream is unblocked.

Make this behavior non-blocking and, if blocked, spawn a goroutine that
sends the response and closes the output. Still a lot of "dancing".

Added a test for this and reviewed. Standalone sketches of the two
patterns used by the fix follow the diff.
---
 internal/grid/connection.go | 12 +++++
 internal/grid/debug.go      |  1 +
 internal/grid/grid_test.go  | 96 ++++++++++++++++++++++++++++++++++++
 internal/grid/muxclient.go  | 98 +++++++++++++++++++++----------------
 internal/grid/stream.go     |  3 +-
 5 files changed, 168 insertions(+), 42 deletions(-)

diff --git a/internal/grid/connection.go b/internal/grid/connection.go
index b9d7f893e..188ee04e1 100644
--- a/internal/grid/connection.go
+++ b/internal/grid/connection.go
@@ -1587,6 +1587,18 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
 		c.clientPingInterval = args[0].(time.Duration)
 	case debugAddToDeadline:
 		c.addDeadline = args[0].(time.Duration)
+	case debugIsOutgoingClosed:
+		// params: muxID uint64, isClosed func(bool)
+		muxID := args[0].(uint64)
+		resp := args[1].(func(b bool))
+		mid, ok := c.outgoing.Load(muxID)
+		if !ok || mid == nil {
+			resp(true)
+			return
+		}
+		mid.respMu.Lock()
+		resp(mid.closed)
+		mid.respMu.Unlock()
 	}
 }
 
diff --git a/internal/grid/debug.go b/internal/grid/debug.go
index eddb577e7..0172f87e2 100644
--- a/internal/grid/debug.go
+++ b/internal/grid/debug.go
@@ -49,6 +49,7 @@ const (
 	debugSetConnPingDuration
 	debugSetClientPingDuration
 	debugAddToDeadline
+	debugIsOutgoingClosed
 )
 
 // TestGrid contains a grid of servers for testing purposes.
diff --git a/internal/grid/grid_test.go b/internal/grid/grid_test.go
index 5c9942c9f..75ce9d35d 100644
--- a/internal/grid/grid_test.go
+++ b/internal/grid/grid_test.go
@@ -372,6 +372,12 @@ func TestStreamSuite(t *testing.T) {
 		assertNoActive(t, connRemoteLocal)
 		assertNoActive(t, connLocalToRemote)
 	})
+	t.Run("testServerStreamResponseBlocked", func(t *testing.T) {
+		defer timeout(1 * time.Minute)()
+		testServerStreamResponseBlocked(t, local, remote)
+		assertNoActive(t, connRemoteLocal)
+		assertNoActive(t, connLocalToRemote)
+	})
 }
 
 func testStreamRoundtrip(t *testing.T, local, remote *Manager) {
@@ -929,6 +935,96 @@ func testGenericsStreamRoundtripSubroute(t *testing.T, local, remote *Manager) {
 	t.Log("EOF.", payloads, " Roundtrips:", time.Since(start))
 }
 
+// testServerStreamResponseBlocked will test if the server can handle a blocked response stream.
+func testServerStreamResponseBlocked(t *testing.T, local, remote *Manager) {
+	defer testlogger.T.SetErrorTB(t)()
+	errFatal := func(err error) {
+		t.Helper()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	// We fake a local and remote server.
+	remoteHost := remote.HostName()
+
+	// Register a handler that emits many responses.
+	serverSent := make(chan struct{})
+	serverCanceled := make(chan struct{})
+	register := func(manager *Manager) {
+		errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
+			Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, resp chan<- []byte) *RemoteErr {
+				// Send many responses.
+				// Test that this doesn't block.
+				for i := byte(0); i < 100; i++ {
+					select {
+					case resp <- []byte{i}:
+						// ok
+					case <-ctx.Done():
+						close(serverCanceled)
+						return NewRemoteErr(ctx.Err())
+					}
+					if i == 1 {
+						close(serverSent)
+					}
+				}
+				return nil
+			},
+			OutCapacity: 1,
+			InCapacity:  0,
+		}))
+	}
+	register(local)
+	register(remote)
+
+	remoteConn := local.Connection(remoteHost)
+	const testPayload = "Hello Grid World!"
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+
+	st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
+	errFatal(err)
+
+	// Wait for the server to send the first response.
+	<-serverSent
+
+	// Read back from the stream and block.
+	nowBlocking := make(chan struct{})
+	stopBlocking := make(chan struct{})
+	defer close(stopBlocking)
+	go func() {
+		st.Results(func(b []byte) error {
+			close(nowBlocking)
+			// Block until the test is done.
+			<-stopBlocking
+			return nil
+		})
+	}()
+
+	<-nowBlocking
+	// Wait for the receiver channel to fill.
+	for len(st.responses) != cap(st.responses) {
+		time.Sleep(time.Millisecond)
+	}
+	cancel()
+	<-serverCanceled
+	local.debugMsg(debugIsOutgoingClosed, st.muxID, func(closed bool) {
+		if !closed {
+			t.Error("expected outgoing closed")
+		} else {
+			t.Log("outgoing was closed")
+		}
+	})
+
+	// Drain responses and check that the error propagated.
+	err = st.Results(func(b []byte) error {
+		return nil
+	})
+	if !errors.Is(err, context.Canceled) {
+		t.Error("expected context.Canceled, got", err)
+	}
+}
+
 func timeout(after time.Duration) (cancel func()) {
 	c := time.After(after)
 	cc := make(chan struct{})
diff --git a/internal/grid/muxclient.go b/internal/grid/muxclient.go
index 9425a86a6..1b380a341 100644
--- a/internal/grid/muxclient.go
+++ b/internal/grid/muxclient.go
@@ -50,6 +50,7 @@ type muxClient struct {
 	deadline     time.Duration
 	outBlock     chan struct{}
 	subroute     *subHandlerID
+	respErr      atomic.Pointer[error]
 }
 
 // Response is a response from the server.
@@ -250,25 +251,52 @@ func (m *muxClient) RequestStream(h HandlerID, payload []byte, requests chan []b
 
 	// Spawn simple disconnect
 	if requests == nil {
-		start := time.Now()
-		go m.handleOneWayStream(start, responseCh, responses)
-		return &Stream{responses: responseCh, Requests: nil, ctx: m.ctx, cancel: m.cancelFn}, nil
+		go m.handleOneWayStream(responseCh, responses)
+		return &Stream{responses: responseCh, Requests: nil, ctx: m.ctx, cancel: m.cancelFn, muxID: m.MuxID}, nil
 	}
 
 	// Deliver responses and send unblocks back to the server.
 	go m.handleTwowayResponses(responseCh, responses)
 	go m.handleTwowayRequests(responses, requests)
 
-	return &Stream{responses: responseCh, Requests: requests, ctx: m.ctx, cancel: m.cancelFn}, nil
+	return &Stream{responses: responseCh, Requests: requests, ctx: m.ctx, cancel: m.cancelFn, muxID: m.MuxID}, nil
 }
 
-func (m *muxClient) handleOneWayStream(start time.Time, respHandler chan<- Response, respServer <-chan Response) {
+func (m *muxClient) addErrorNonBlockingClose(respHandler chan<- Response, err error) {
+	m.respMu.Lock()
+	defer m.respMu.Unlock()
+	if !m.closed {
+		m.respErr.Store(&err)
+		// Do not block.
+		select {
+		case respHandler <- Response{Err: err}:
+			xioutil.SafeClose(respHandler)
+		default:
+			go func() {
+				respHandler <- Response{Err: err}
+				xioutil.SafeClose(respHandler)
+			}()
+		}
+		logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID}))
+		m.closed = true
+	}
+}
+
+// handleOneWayStream pipes responses from respServer to respHandler for streams without requests.
+func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <-chan Response) {
 	if debugPrint {
+		start := time.Now()
 		defer func() {
 			fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond))
 		}()
 	}
-	defer xioutil.SafeClose(respHandler)
+	defer func() {
+		// addErrorNonBlockingClose will close the response channel,
+		// possibly asynchronously, so we shouldn't do it here.
+		if m.respErr.Load() == nil {
+			xioutil.SafeClose(respHandler)
+		}
+	}()
 	var pingTimer <-chan time.Time
 	if m.deadline == 0 || m.deadline > clientPingInterval {
 		ticker := time.NewTicker(clientPingInterval)
@@ -283,13 +311,7 @@
 			if debugPrint {
 				fmt.Println("Client sending disconnect to mux", m.MuxID)
 			}
-			m.respMu.Lock()
-			defer m.respMu.Unlock() // We always return in this path.
-			if !m.closed {
-				respHandler <- Response{Err: context.Cause(m.ctx)}
-				logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID}))
-				m.closeLocked()
-			}
+			m.addErrorNonBlockingClose(respHandler, context.Cause(m.ctx))
 			return
 		case resp, ok := <-respServer:
 			if !ok {
@@ -308,13 +330,7 @@
 			}
 		case <-pingTimer:
 			if time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)) > clientPingInterval*2 {
-				m.respMu.Lock()
-				defer m.respMu.Unlock() // We always return in this path.
-				if !m.closed {
-					respHandler <- Response{Err: ErrDisconnected}
-					logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID}))
-					m.closeLocked()
-				}
+				m.addErrorNonBlockingClose(respHandler, ErrDisconnected)
 				return
 			}
 			// Send new ping.
@@ -323,19 +339,21 @@
 	}
 }
 
-func (m *muxClient) handleTwowayResponses(responseCh chan Response, responses chan Response) {
+// responseCh is the channel that goes to the requester.
+// internalResp is the channel that comes from the server.
+func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) {
 	defer m.parent.deleteMux(false, m.MuxID)
 	defer xioutil.SafeClose(responseCh)
-	for resp := range responses {
+	for resp := range internalResp {
 		responseCh <- resp
 		m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
 	}
 }
 
-func (m *muxClient) handleTwowayRequests(responses chan<- Response, requests chan []byte) {
+func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
 	var errState bool
-	start := time.Now()
 	if debugPrint {
+		start := time.Now()
 		defer func() {
 			fmt.Println("Mux", m.MuxID, "Request took", time.Since(start).Round(time.Millisecond))
 		}()
 	}
 
 	// Listen for client messages.
 	for {
+		if errState {
+			go func() {
+				// Drain requests.
+				for range requests {
+				}
+			}()
+			return
+		}
 		select {
 		case <-m.ctx.Done():
 			if debugPrint {
 				fmt.Println("Client sending disconnect to mux", m.MuxID)
 			}
-			m.respMu.Lock()
-			defer m.respMu.Unlock()
-			logger.LogIf(m.ctx, m.sendLocked(message{Op: OpDisconnectServerMux, MuxID: m.MuxID}))
-			if !m.closed {
-				responses <- Response{Err: context.Cause(m.ctx)}
-				m.closeLocked()
-			}
-			return
+			m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
+			errState = true
+			continue
 		case req, ok := <-requests:
 			if !ok {
 				// Done, send EOF
@@ -371,19 +392,14 @@
 				msg.setZeroPayloadFlag()
 				err := m.send(msg)
 				if err != nil {
-					m.respMu.Lock()
-					responses <- Response{Err: err}
-					m.closeLocked()
-					m.respMu.Unlock()
+					m.addErrorNonBlockingClose(internalResp, err)
 				}
 				return
 			}
-			if errState {
-				continue
-			}
 			// Grab a send token.
 			select {
 			case <-m.ctx.Done():
+				m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
 				errState = true
 				continue
 			case <-m.outBlock:
@@ -398,8 +414,7 @@
 			err := m.send(msg)
 			PutByteBuffer(req)
 			if err != nil {
-				responses <- Response{Err: err}
-				m.close()
+				m.addErrorNonBlockingClose(internalResp, err)
 				errState = true
 				continue
 			}
@@ -534,6 +549,7 @@ func (m *muxClient) closeLocked() {
 	if m.closed {
 		return
 	}
+	// We hold the lock, so nobody can modify m.respWait while we're closing.
 	if m.respWait != nil {
 		xioutil.SafeClose(m.respWait)
 		m.respWait = nil
diff --git a/internal/grid/stream.go b/internal/grid/stream.go
index d65313d77..a99b66643 100644
--- a/internal/grid/stream.go
+++ b/internal/grid/stream.go
@@ -41,7 +41,8 @@ type Stream struct {
 	// Requests sent cannot be used any further by the caller.
 	Requests chan<- []byte
 
-	ctx context.Context
+	muxID uint64
+	ctx   context.Context
 }
 
 // Send a payload to the remote server.
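
The core of the fix is the send-or-spawn shape of addErrorNonBlockingClose above: try a
non-blocking send of the final error, and if the receiver is stalled, hand both the send and
the close off to a goroutine so the caller can release respMu immediately. Below is a minimal,
self-contained sketch of that pattern; Response and pushErrorAndClose are illustrative
stand-ins, not the grid package's API.

package main

import (
	"errors"
	"fmt"
	"time"
)

// Response mimics a terminal stream message carrying an error.
type Response struct {
	Err error
}

// pushErrorAndClose delivers err on ch and then closes it without ever
// blocking the caller: if ch is full, the send and the close are handed
// off to a goroutine instead of waiting for the receiver.
func pushErrorAndClose(ch chan Response, err error) {
	select {
	case ch <- Response{Err: err}:
		// Receiver had room; safe to close inline.
		close(ch)
	default:
		// Receiver is blocked; finish asynchronously so a caller
		// holding a mutex can continue immediately.
		go func() {
			ch <- Response{Err: err}
			close(ch)
		}()
	}
}

func main() {
	ch := make(chan Response, 1)
	ch <- Response{} // fill the buffer so an inline send would block

	start := time.Now()
	pushErrorAndClose(ch, errors.New("disconnected"))
	fmt.Println("returned after", time.Since(start)) // returns immediately

	// Drain: the buffered value first, then the error, then the close ends the range.
	for r := range ch {
		fmt.Println("got:", r.Err)
	}
}

The handoff is only safe because pushErrorAndClose is the sole closer of the channel and runs
at most once; in the patch that is guaranteed by the m.closed flag checked under respMu.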
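
The other pattern the patch leans on is drain-on-error in handleTwowayRequests: once errState
is set, the requests channel must still be read, otherwise a producer writing into
Stream.Requests would block forever. A hedged sketch of that idiom follows; processRequests
and sendFn are hypothetical names, not the package's API.

package main

import (
	"errors"
	"fmt"
)

// processRequests forwards requests via sendFn until sendFn fails. After a
// failure it must not simply return: senders could block forever on an
// unread channel. Instead it drains the remaining requests in a goroutine
// so producers can finish and clean up.
func processRequests(requests <-chan []byte, sendFn func([]byte) error) {
	for req := range requests {
		if err := sendFn(req); err != nil {
			// Error state: drain asynchronously and stop processing.
			go func() {
				for range requests {
				}
			}()
			return
		}
	}
}

func main() {
	requests := make(chan []byte)
	done := make(chan struct{})
	go func() {
		defer close(done)
		// The producer keeps sending; without the drain it would block
		// on the send right after the failure below.
		for i := 0; i < 5; i++ {
			requests <- []byte{byte(i)}
		}
	}()

	calls := 0
	processRequests(requests, func(b []byte) error {
		calls++
		if calls == 2 {
			return errors.New("connection lost")
		}
		return nil
	})
	<-done // the producer finishes because the drain keeps reading
	fmt.Println("producer unblocked after", calls, "successful sends")
}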