Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 116 additions & 75 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,72 +194,131 @@ type message struct {
err error
}

type receiver struct {
wg *sync.WaitGroup
messages chan *message
err error
// callMap provides access to a map of active calls, guarded by a mutex.
type callMap struct {
m sync.Mutex
activeCalls map[uint32]*callRequest
closeErr error
}

func (r *receiver) run(ctx context.Context, c *channel) {
defer r.wg.Done()
// newCallMap returns a new callMap with an empty set of active calls.
func newCallMap() *callMap {
return &callMap{
activeCalls: make(map[uint32]*callRequest),
}
}

for {
select {
case <-ctx.Done():
r.err = ctx.Err()
return
default:
mh, p, err := c.recv()
if err != nil {
_, ok := status.FromError(err)
if !ok {
// treat all errors that are not an rpc status as terminal.
// all others poison the connection.
r.err = filterCloseErr(err)
return
}
}
select {
case r.messages <- &message{
messageHeader: mh,
p: p[:mh.Length],
err: err,
}:
case <-ctx.Done():
r.err = ctx.Err()
return
}
}
// set adds a call entry to the map with the given streamID key.
func (cm *callMap) set(streamID uint32, cr *callRequest) error {
cm.m.Lock()
defer cm.m.Unlock()
if cm.closeErr != nil {
return cm.closeErr
}
cm.activeCalls[streamID] = cr
return nil
}

// get looks up the call entry for the given streamID key, then removes it
// from the map and returns it.
func (cm *callMap) get(streamID uint32) (cr *callRequest, ok bool, err error) {
cm.m.Lock()
defer cm.m.Unlock()
if cm.closeErr != nil {
return nil, false, cm.closeErr
}
cr, ok = cm.activeCalls[streamID]
if ok {
delete(cm.activeCalls, streamID)
}
return
}

// abort sends the given error to each active call, and clears the map.
// Once abort has been called, any subsequent calls to the callMap will return the error passed to abort.
func (cm *callMap) abort(err error) error {
cm.m.Lock()
defer cm.m.Unlock()
if cm.closeErr != nil {
return cm.closeErr
}
for streamID, call := range cm.activeCalls {
call.errs <- err
delete(cm.activeCalls, streamID)
}
cm.closeErr = err
return nil
}

func (c *Client) run() {
var (
streamID uint32 = 1
waiters = make(map[uint32]*callRequest)
calls = c.calls
incoming = make(chan *message)
receiversDone = make(chan struct{})
wg sync.WaitGroup
waiters = newCallMap()
receiverDone = make(chan struct{})
)

// broadcast the shutdown error to the remaining waiters.
abortWaiters := func(wErr error) {
for _, waiter := range waiters {
waiter.errs <- wErr
// Sender goroutine
// Receives calls from dispatch, adds them to the set of active calls, and sends them
// to the server.
go func() {
var streamID uint32 = 1
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not entirely new, but since you've modified and may have fresh memory - how does that work if streamID is getting bigger than what uint32 can represent?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it will only be a problem if a request with a previous ID is still pending. Given that a uint32 means we can do 100 requests per second for 500 days before we overflow, I think it's not an issue in practice.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I forgot we only use odd numbers for requests, so it's for 250 days, not 500. Still a long time though. :)

for {
select {
case <-c.ctx.Done():
return
Comment thread
kevpar marked this conversation as resolved.
case call := <-c.calls:
id := streamID
streamID += 2 // enforce odd client initiated request ids
if err := waiters.set(id, call); err != nil {
call.errs <- err // errs is buffered so should not block.
continue
}
if err := c.send(id, messageTypeRequest, call.req); err != nil {
call.errs <- err // errs is buffered so should not block.
waiters.get(id) // remove from waiters set
}
}
}
}
recv := &receiver{
wg: &wg,
messages: incoming,
}
wg.Add(1)
}()

// Receiver goroutine
// Receives responses from the server, looks up the call info in the set of active calls,
// and notifies the caller of the response.
go func() {
wg.Wait()
close(receiversDone)
defer close(receiverDone)
for {
select {
case <-c.ctx.Done():
c.setError(c.ctx.Err())
return
default:
mh, p, err := c.channel.recv()
if err != nil {
_, ok := status.FromError(err)
if !ok {
// treat all errors that are not an rpc status as terminal.
// all others poison the connection.
c.setError(filterCloseErr(err))
return
}
}
msg := &message{
messageHeader: mh,
p: p[:mh.Length],
err: err,
}
call, ok, err := waiters.get(mh.StreamID)
if err != nil {
logrus.Errorf("ttrpc: failed to look up active call: %s", err)
continue
}
if !ok {
logrus.Errorf("ttrpc: received message for unknown channel %v", mh.StreamID)
continue
}
call.errs <- c.recv(call.resp, msg)
}
}
}()
go recv.run(c.ctx, c.channel)

defer func() {
c.conn.Close()
Expand All @@ -269,32 +328,14 @@ func (c *Client) run() {

for {
select {
case call := <-calls:
if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
call.errs <- err
continue
}

waiters[streamID] = call
streamID += 2 // enforce odd client initiated request ids
case msg := <-incoming:
call, ok := waiters[msg.StreamID]
if !ok {
logrus.Errorf("ttrpc: received message for unknown channel %v", msg.StreamID)
continue
}

call.errs <- c.recv(call.resp, msg)
delete(waiters, msg.StreamID)
case <-receiversDone:
// all the receivers have exited
if recv.err != nil {
c.setError(recv.err)
}
case <-receiverDone:
// The receiver has exited.
// don't return out, let the close of the context trigger the abort of waiters
c.Close()
case <-c.ctx.Done():
abortWaiters(c.error())
// Abort all active calls. This will also prevent any new calls from being added
// to waiters.
waiters.abort(c.error())
return
}
}
Expand Down