Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package ttrpc

import (
"bufio"
"context"
"encoding/binary"
"io"
"net"
Expand Down Expand Up @@ -98,7 +97,7 @@ func newChannel(conn net.Conn) *channel {
// returned will be valid and caller should send that along to
// the correct consumer. The bytes on the underlying channel
// will be discarded.
func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
func (ch *channel) recv() (messageHeader, []byte, error) {
mh, err := readMessageHeader(ch.hrbuf[:], ch.br)
if err != nil {
return messageHeader{}, nil, err
Expand All @@ -120,7 +119,7 @@ func (ch *channel) recv(ctx context.Context) (messageHeader, []byte, error) {
return mh, p, nil
}

func (ch *channel) send(ctx context.Context, streamID uint32, t messageType, p []byte) error {
func (ch *channel) send(streamID uint32, t messageType, p []byte) error {
if err := writeMessageHeader(ch.bw, ch.hwbuf[:], messageHeader{Length: uint32(len(p)), StreamID: streamID, Type: t}); err != nil {
return err
}
Expand Down
11 changes: 4 additions & 7 deletions channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package ttrpc

import (
"bytes"
"context"
"io"
"net"
"reflect"
Expand All @@ -31,7 +30,6 @@ import (

func TestReadWriteMessage(t *testing.T) {
var (
ctx = context.Background()
w, r = net.Pipe()
ch = newChannel(w)
rch = newChannel(r)
Expand All @@ -46,7 +44,7 @@ func TestReadWriteMessage(t *testing.T) {

go func() {
for i, msg := range messages {
if err := ch.send(ctx, uint32(i), 1, msg); err != nil {
if err := ch.send(uint32(i), 1, msg); err != nil {
errs <- err
return
}
Expand All @@ -56,7 +54,7 @@ func TestReadWriteMessage(t *testing.T) {
}()

for {
_, p, err := rch.recv(ctx)
_, p, err := rch.recv()
if err != nil {
if errors.Cause(err) != io.EOF {
t.Fatal(err)
Expand Down Expand Up @@ -91,20 +89,19 @@ func TestReadWriteMessage(t *testing.T) {

func TestMessageOversize(t *testing.T) {
var (
ctx = context.Background()
w, r = net.Pipe()
wch, rch = newChannel(w), newChannel(r)
msg = bytes.Repeat([]byte("a message of massive length"), 512<<10)
errs = make(chan error, 1)
)

go func() {
if err := wch.send(ctx, 1, 1, msg); err != nil {
if err := wch.send(1, 1, msg); err != nil {
errs <- err
}
}()

_, _, err := rch.recv(ctx)
_, _, err := rch.recv()
if err == nil {
t.Fatalf("error expected reading with small buffer")
}
Expand Down
186 changes: 110 additions & 76 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,13 @@ type Client struct {
channel *channel
calls chan *callRequest

closed chan struct{}
closeOnce sync.Once
closeFunc func()
done chan struct{}
ctx context.Context
closed func()

closeOnce sync.Once
userCloseFunc func()

errOnce sync.Once
err error
interceptor UnaryClientInterceptor
}
Expand All @@ -57,7 +60,7 @@ type ClientOpts func(c *Client)
// WithOnClose sets the close func whenever the client's Close() method is called
func WithOnClose(onClose func()) ClientOpts {
return func(c *Client) {
c.closeFunc = onClose
c.userCloseFunc = onClose
}
}

Expand All @@ -69,15 +72,16 @@ func WithUnaryClientInterceptor(i UnaryClientInterceptor) ClientOpts {
}

func NewClient(conn net.Conn, opts ...ClientOpts) *Client {
ctx, cancel := context.WithCancel(context.Background())
c := &Client{
codec: codec{},
conn: conn,
channel: newChannel(conn),
calls: make(chan *callRequest),
closed: make(chan struct{}),
done: make(chan struct{}),
closeFunc: func() {},
interceptor: defaultClientInterceptor,
codec: codec{},
conn: conn,
channel: newChannel(conn),
calls: make(chan *callRequest),
closed: cancel,
ctx: ctx,
userCloseFunc: func() {},
interceptor: defaultClientInterceptor,
}

for _, o := range opts {
Expand Down Expand Up @@ -150,25 +154,24 @@ func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) err
case <-ctx.Done():
return ctx.Err()
case c.calls <- call:
case <-c.done:
return c.err
case <-c.ctx.Done():
return c.error()
}

select {
case <-ctx.Done():
return ctx.Err()
case err := <-errs:
return filterCloseErr(err)
case <-c.done:
return c.err
case <-c.ctx.Done():
return c.error()
}
}

func (c *Client) Close() error {
c.closeOnce.Do(func() {
close(c.closed)
c.closed()
})

return nil
}

Expand All @@ -178,51 +181,82 @@ type message struct {
err error
}

func (c *Client) run() {
var (
streamID uint32 = 1
waiters = make(map[uint32]*callRequest)
calls = c.calls
incoming = make(chan *message)
shutdown = make(chan struct{})
shutdownErr error
)
type receiver struct {
wg *sync.WaitGroup
messages chan *message
err error
}

go func() {
defer close(shutdown)
func (r *receiver) run(ctx context.Context, c *channel) {
defer r.wg.Done()

// start one more goroutine to recv messages without blocking.
for {
mh, p, err := c.channel.recv(context.TODO())
for {
select {
case <-ctx.Done():
r.err = ctx.Err()
return
default:
mh, p, err := c.recv()
if err != nil {
_, ok := status.FromError(err)
if !ok {
// treat all errors that are not an rpc status as terminal.
// all others poison the connection.
shutdownErr = err
r.err = filterCloseErr(err)
return
}
}
select {
case incoming <- &message{
case r.messages <- &message{
messageHeader: mh,
p: p[:mh.Length],
err: err,
}:
case <-c.done:
case <-ctx.Done():
r.err = ctx.Err()
return
}
}
}
}

func (c *Client) run() {
var (
streamID uint32 = 1
waiters = make(map[uint32]*callRequest)
calls = c.calls
incoming = make(chan *message)
receiversDone = make(chan struct{})
wg sync.WaitGroup
)

// broadcast the shutdown error to the remaining waiters.
abortWaiters := func(wErr error) {
for _, waiter := range waiters {
waiter.errs <- wErr
}
}
recv := &receiver{
wg: &wg,
messages: incoming,
}
wg.Add(1)

go func() {
wg.Wait()
close(receiversDone)
}()
go recv.run(c.ctx, c.channel)

defer c.conn.Close()
defer close(c.done)
defer c.closeFunc()
defer func() {
c.conn.Close()
c.userCloseFunc()
}()

for {
select {
case call := <-calls:
if err := c.send(call.ctx, streamID, messageTypeRequest, call.req); err != nil {
if err := c.send(streamID, messageTypeRequest, call.req); err != nil {
call.errs <- err
continue
}
Expand All @@ -238,41 +272,42 @@ func (c *Client) run() {

call.errs <- c.recv(call.resp, msg)
delete(waiters, msg.StreamID)
case <-shutdown:
if shutdownErr != nil {
shutdownErr = filterCloseErr(shutdownErr)
} else {
shutdownErr = ErrClosed
}

shutdownErr = errors.Wrapf(shutdownErr, "ttrpc: client shutting down")

c.err = shutdownErr
for _, waiter := range waiters {
waiter.errs <- shutdownErr
case <-receiversDone:
// all the receivers have exited
if recv.err != nil {
c.setError(recv.err)
}
// don't return out, let the close of the context trigger the abort of waiters
c.Close()
return
case <-c.closed:
if c.err == nil {
c.err = ErrClosed
}
// broadcast the shutdown error to the remaining waiters.
for _, waiter := range waiters {
waiter.errs <- c.err
}
case <-c.ctx.Done():
abortWaiters(c.error())
return
}
}
}

func (c *Client) send(ctx context.Context, streamID uint32, mtype messageType, msg interface{}) error {
func (c *Client) error() error {
c.errOnce.Do(func() {
if c.err == nil {
c.err = ErrClosed
}
})
return c.err
}

func (c *Client) setError(err error) {
c.errOnce.Do(func() {
c.err = err
})
}

func (c *Client) send(streamID uint32, mtype messageType, msg interface{}) error {
p, err := c.codec.Marshal(msg)
if err != nil {
return err
}

return c.channel.send(ctx, streamID, mtype, p)
return c.channel.send(streamID, mtype, p)
}

func (c *Client) recv(resp *Response, msg *message) error {
Expand All @@ -293,22 +328,21 @@ func (c *Client) recv(resp *Response, msg *message) error {
//
// This purposely ignores errors with a wrapped cause.
func filterCloseErr(err error) error {
if err == nil {
switch {
case err == nil:
return nil
}

if err == io.EOF {
case err == io.EOF:
return ErrClosed
}

if strings.Contains(err.Error(), "use of closed network connection") {
case errors.Cause(err) == io.EOF:
return ErrClosed
}

// if we have an epipe on a write, we cast to errclosed
if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
return ErrClosed
case strings.Contains(err.Error(), "use of closed network connection"):
return ErrClosed
default:
// if we have an epipe on a write, we cast to errclosed
if oerr, ok := err.(*net.OpError); ok && oerr.Op == "write" {
if serr, ok := oerr.Err.(*os.SyscallError); ok && serr.Err == syscall.EPIPE {
return ErrClosed
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ func (c *serverConn) run(sctx context.Context) {
default: // proceed
}

mh, p, err := ch.recv(ctx)
mh, p, err := ch.recv()
if err != nil {
status, ok := status.FromError(err)
if !ok {
Expand Down Expand Up @@ -441,7 +441,7 @@ func (c *serverConn) run(sctx context.Context) {
return
}

if err := ch.send(ctx, response.id, messageTypeResponse, p); err != nil {
if err := ch.send(response.id, messageTypeResponse, p); err != nil {
logrus.WithError(err).Error("failed sending message on channel")
return
}
Expand Down
2 changes: 1 addition & 1 deletion server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ func TestClientEOF(t *testing.T) {
}

// shutdown the server so the client stops receiving stuff.
if err := server.Shutdown(ctx); err != nil {
if err := server.Close(); err != nil {
t.Fatal(err)
}
if err := <-errs; err != ErrServerClosed {
Expand Down