Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 40 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,19 @@ func (r *receiver) run(ctx context.Context, c *channel) {
}

func (c *Client) run() {
type streamCall struct {
streamID uint32
call *callRequest
}
var (
streamID uint32 = 1
waiters = make(map[uint32]*callRequest)
calls = c.calls
incoming = make(chan *message)
receiversDone = make(chan struct{})
wg sync.WaitGroup
streamID uint32 = 1
waiters = make(map[uint32]*callRequest)
calls = c.calls
requests = make(chan streamCall)
requestsFailed = make(chan streamCall)
incoming = make(chan *message)
receiversDone = make(chan struct{})
wg sync.WaitGroup
)

// broadcast the shutdown error to the remaining waiters.
Expand All @@ -261,6 +267,21 @@ func (c *Client) run() {
}()
go recv.run(c.ctx, c.channel)

go func(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case streamCall := <-requests:
if err := c.send(streamCall.streamID, messageTypeRequest, streamCall.call.req); err != nil {
streamCall.call.errs <- err // errs is buffered so should not block.
requestsFailed <- streamCall
continue
}
}
}
}(c.ctx)

defer func() {
c.conn.Close()
c.userCloseFunc()
Expand All @@ -270,13 +291,21 @@ func (c *Client) run() {
for {
select {
case call := <-calls:
if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
call.errs <- err
continue
}

go func(streamID uint32, call *callRequest) {
sc := streamCall{
streamID: streamID,
call: call,
}
select {
case <-c.ctx.Done():
case requests <- sc:
}
}(streamID, call)
waiters[streamID] = call
streamID += 2 // enforce odd client initiated request ids
case streamCall := <-requestsFailed:
// Sending the request failed, so stop tracking this stream ID.
delete(waiters, streamCall.streamID)
case msg := <-incoming:
call, ok := waiters[msg.StreamID]
if !ok {
Expand Down
43 changes: 34 additions & 9 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ func (c *serverConn) run(sctx context.Context) {
active int
state connState = connStateIdle
responses = make(chan response)
responseErr = make(chan error)
requests = make(chan request)
recvErr = make(chan error, 1)
shutdown = c.shutdown
Expand Down Expand Up @@ -412,6 +413,36 @@ func (c *serverConn) run(sctx context.Context) {
}
}(recvErr)

go func(responseErr chan error) {
Comment thread
kevpar marked this conversation as resolved.
for {
select {
// We don't want a case for c.shutdown here, as that would cause us to exit
// immediately when it is signaled, rather than waiting for any active requests
// to complete first. Instead, once all the active requests have completed,
// the main loop will return and close done, which will cause us to exit as well.
case <-done:
return
case response := <-responses:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be slightly clearer to defer close(responses) and just have this be for response := range responses.

Copy link
Copy Markdown

@jstarks jstarks Dec 8, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, maybe that's not practical since responses might still be referenced in the call goroutine.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, generally you don't want to close from the read side.

p, err := c.server.codec.Marshal(response.resp)
if err != nil {
logrus.WithError(err).Error("failed marshaling response")
responseErr <- err
return
}

if err := ch.send(response.id, messageTypeResponse, p); err != nil {
logrus.WithError(err).Error("failed sending message on channel")
responseErr <- err
return
}

// Send a nil error so that the main loop knows an active request has
// completed successfully.
responseErr <- nil
}
}
}(responseErr)

for {
newstate := state
switch {
Expand Down Expand Up @@ -449,18 +480,12 @@ func (c *serverConn) run(sctx context.Context) {
case <-done:
}
}(request.id)
case response := <-responses:
p, err := c.server.codec.Marshal(response.resp)
case err := <-responseErr:
// responseErr sends nil if no error occurred in sending the response.
// In that case we just decrement the active count and continue.
if err != nil {
logrus.WithError(err).Error("failed marshaling response")
return
}

if err := ch.send(response.id, messageTypeResponse, p); err != nil {
logrus.WithError(err).Error("failed sending message on channel")
return
}

active--
case err := <-recvErr:
// TODO(stevvooe): Not wildly clear what we should do in this
Expand Down