From 35d281d7030c81d7b018f1b60d0c23aa83a14880 Mon Sep 17 00:00:00 2001 From: Nathan Smith Date: Mon, 27 May 2019 18:45:28 -0700 Subject: [PATCH] Use given context in NATS endpoint Previous implementation used a new background context which prevented context cancellation. --- transport/nats/publisher.go | 2 +- transport/nats/publisher_test.go | 38 ++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/transport/nats/publisher.go b/transport/nats/publisher.go index 7baab4d2f..5bc75fa61 100644 --- a/transport/nats/publisher.go +++ b/transport/nats/publisher.go @@ -64,7 +64,7 @@ func PublisherTimeout(timeout time.Duration) PublisherOption { // Endpoint returns a usable endpoint that invokes the remote endpoint. func (p Publisher) Endpoint() endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { - ctx, cancel := context.WithTimeout(context.Background(), p.timeout) + ctx, cancel := context.WithTimeout(ctx, p.timeout) defer cancel() msg := nats.Msg{Subject: p.subject} diff --git a/transport/nats/publisher_test.go b/transport/nats/publisher_test.go index cbbe9d777..0272b1f15 100644 --- a/transport/nats/publisher_test.go +++ b/transport/nats/publisher_test.go @@ -186,6 +186,44 @@ func TestPublisherTimeout(t *testing.T) { } } +func TestPublisherCancellation(t *testing.T) { + var ( + testdata = "testdata" + encode = func(context.Context, *nats.Msg, interface{}) error { return nil } + decode = func(_ context.Context, msg *nats.Msg) (interface{}, error) { + return TestResponse{string(msg.Data), ""}, nil + } + ) + + nc := newNatsConn(t) + defer nc.Close() + + sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { + if err := nc.Publish(msg.Reply, []byte(testdata)); err != nil { + t.Fatal(err) + } + }) + if err != nil { + t.Fatal(err) + } + defer sub.Unsubscribe() + + publisher := natstransport.NewPublisher( + nc, + "natstransport.test", + encode, + decode, + ) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err = publisher.Endpoint()(ctx, struct{}{}) + if err != context.Canceled { + t.Errorf("want %s, have %s", context.Canceled, err) + } +} + func TestEncodeJSONRequest(t *testing.T) { var data string