From 9536df609f4d0c75182c8af4b21b5c178461fbe6 Mon Sep 17 00:00:00 2001 From: Kevin Parsons Date: Thu, 3 Dec 2020 00:59:54 -0800 Subject: [PATCH] client/server: Don't block the main connection loop for transport IO Restructures both the client and server connection management so that sending messages on the transport is done by a separate "sender" goroutine. The receiving end was already split out like this. Without this change, it is possible for a send to block if the other end isn't reading fast enough, which then would block the main connection loop and prevent incoming messages from being processed. Signed-off-by: Kevin Parsons --- client.go | 51 ++++++++++++++++++++++++++++++++++++++++----------- server.go | 43 ++++++++++++++++++++++++++++++++++--------- 2 files changed, 74 insertions(+), 20 deletions(-) diff --git a/client.go b/client.go index 30c9b73f3..7c87d4683 100644 --- a/client.go +++ b/client.go @@ -234,13 +234,19 @@ func (r *receiver) run(ctx context.Context, c *channel) { } func (c *Client) run() { + type streamCall struct { + streamID uint32 + call *callRequest + } var ( - streamID uint32 = 1 - waiters = make(map[uint32]*callRequest) - calls = c.calls - incoming = make(chan *message) - receiversDone = make(chan struct{}) - wg sync.WaitGroup + streamID uint32 = 1 + waiters = make(map[uint32]*callRequest) + calls = c.calls + requests = make(chan streamCall) + requestsFailed = make(chan streamCall) + incoming = make(chan *message) + receiversDone = make(chan struct{}) + wg sync.WaitGroup ) // broadcast the shutdown error to the remaining waiters. @@ -261,6 +267,21 @@ func (c *Client) run() { }() go recv.run(c.ctx, c.channel) + go func(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case streamCall := <-requests: + if err := c.send(streamCall.streamID, messageTypeRequest, streamCall.call.req); err != nil { + streamCall.call.errs <- err // errs is buffered so should not block. + requestsFailed <- streamCall + continue + } + } + } + }(c.ctx) + defer func() { c.conn.Close() c.userCloseFunc() @@ -270,13 +291,21 @@ func (c *Client) run() { for { select { case call := <-calls: - if err := c.send(streamID, messageTypeRequest, call.req); err != nil { - call.errs <- err - continue - } - + go func(streamID uint32, call *callRequest) { + sc := streamCall{ + streamID: streamID, + call: call, + } + select { + case <-c.ctx.Done(): + case requests <- sc: + } + }(streamID, call) waiters[streamID] = call streamID += 2 // enforce odd client initiated request ids + case streamCall := <-requestsFailed: + // Sending the request failed, so stop tracking this stream ID. + delete(waiters, streamCall.streamID) case msg := <-incoming: call, ok := waiters[msg.StreamID] if !ok { diff --git a/server.go b/server.go index c18b4e43b..4ecf8b322 100644 --- a/server.go +++ b/server.go @@ -318,6 +318,7 @@ func (c *serverConn) run(sctx context.Context) { active int state connState = connStateIdle responses = make(chan response) + responseErr = make(chan error) requests = make(chan request) recvErr = make(chan error, 1) shutdown = c.shutdown @@ -412,6 +413,36 @@ func (c *serverConn) run(sctx context.Context) { } }(recvErr) + go func(responseErr chan error) { + for { + select { + // We don't want a case for c.shutdown here, as that would cause us to exit + // immediately when it is signaled, rather than waiting for any active requests + // to complete first. Instead, once all the active requests have completed, + // the main loop will return and close done, which will cause us to exit as well. + case <-done: + return + case response := <-responses: + p, err := c.server.codec.Marshal(response.resp) + if err != nil { + logrus.WithError(err).Error("failed marshaling response") + responseErr <- err + return + } + + if err := ch.send(response.id, messageTypeResponse, p); err != nil { + logrus.WithError(err).Error("failed sending message on channel") + responseErr <- err + return + } + + // Send a nil error so that the main loop knows an active request has + // completed successfully. + responseErr <- nil + } + } + }(responseErr) + for { newstate := state switch { @@ -449,18 +480,12 @@ func (c *serverConn) run(sctx context.Context) { case <-done: } }(request.id) - case response := <-responses: - p, err := c.server.codec.Marshal(response.resp) + case err := <-responseErr: + // responseErr sends nil if no error occurred in sending the response. + // In that case we just decrement the active count and continue. if err != nil { - logrus.WithError(err).Error("failed marshaling response") return } - - if err := ch.send(response.id, messageTypeResponse, p); err != nil { - logrus.WithError(err).Error("failed sending message on channel") - return - } - active-- case err := <-recvErr: // TODO(stevvooe): Not wildly clear what we should do in this