diff --git a/transport/amqp/publisher.go b/transport/amqp/publisher.go index a28ee941d..ea20e466f 100644 --- a/transport/amqp/publisher.go +++ b/transport/amqp/publisher.go @@ -15,13 +15,14 @@ const maxCorrelationIdLength = 255 // Publisher wraps an AMQP channel and queue, and provides a method that // implements endpoint.Endpoint. type Publisher struct { - ch Channel - q *amqp.Queue - enc EncodeRequestFunc - dec DecodeResponseFunc - before []RequestFunc - after []PublisherResponseFunc - timeout time.Duration + ch Channel + q *amqp.Queue + enc EncodeRequestFunc + dec DecodeResponseFunc + before []RequestFunc + after []PublisherResponseFunc + deliverer Deliverer + timeout time.Duration } // NewPublisher constructs a usable Publisher for a single remote method. @@ -33,11 +34,12 @@ func NewPublisher( options ...PublisherOption, ) *Publisher { p := &Publisher{ - ch: ch, - q: q, - enc: enc, - dec: dec, - timeout: 10 * time.Second, + ch: ch, + q: q, + enc: enc, + dec: dec, + deliverer: DefaultDeliverer, + timeout: 10 * time.Second, } for _, option := range options { option(p) @@ -61,6 +63,11 @@ func PublisherAfter(after ...PublisherResponseFunc) PublisherOption { return func(p *Publisher) { p.after = append(p.after, after...) } } +// PublisherDeliverer sets the deliverer function that the Publisher invokes. +func PublisherDeliverer(deliverer Deliverer) PublisherOption { + return func(p *Publisher) { p.deliverer = deliverer } +} + // PublisherTimeout sets the available timeout for an AMQP request. func PublisherTimeout(timeout time.Duration) PublisherOption { return func(p *Publisher) { p.timeout = timeout } @@ -85,7 +92,7 @@ func (p Publisher) Endpoint() endpoint.Endpoint { ctx = f(ctx, &pub) } - deliv, err := p.publishAndConsumeFirstMatchingResponse(ctx, &pub) + deliv, err := p.deliverer(ctx, p, &pub) if err != nil { return nil, err } @@ -102,11 +109,20 @@ func (p Publisher) Endpoint() endpoint.Endpoint { } } -// publishAndConsumeFirstMatchingResponse publishes the specified Publishing +// Deliverer is invoked by the Publisher to publish the specified Publishing, and to +// retrieve the appropriate response Delivery object. +type Deliverer func( + context.Context, + Publisher, + *amqp.Publishing, +) (*amqp.Delivery, error) + +// DefaultDeliverer is a deliverer that publishes the specified Publishing // and returns the first Delivery object with the matching correlationId. // If the context times out while waiting for a reply, an error will be returned. -func (p Publisher) publishAndConsumeFirstMatchingResponse( +func DefaultDeliverer( ctx context.Context, + p Publisher, pub *amqp.Publishing, ) (*amqp.Delivery, error) { err := p.ch.Publish( @@ -150,3 +166,22 @@ func (p Publisher) publishAndConsumeFirstMatchingResponse( } } + +// SendAndForgetDeliverer delivers the supplied publishing and +// returns a nil response. +// When using this deliverer please ensure that the supplied DecodeResponseFunc and +// PublisherResponseFunc are able to handle nil-type responses. +func SendAndForgetDeliverer( + ctx context.Context, + p Publisher, + pub *amqp.Publishing, +) (*amqp.Delivery, error) { + err := p.ch.Publish( + getPublishExchange(ctx), + getPublishKey(ctx), + false, //mandatory + false, //immediate + *pub, + ) + return nil, err +} diff --git a/transport/amqp/publisher_test.go b/transport/amqp/publisher_test.go index 5b6785c84..2c62be10b 100644 --- a/transport/amqp/publisher_test.go +++ b/transport/amqp/publisher_test.go @@ -224,3 +224,47 @@ func TestSuccessfulPublisher(t *testing.T) { t.Errorf("want %s, have %s", want, have) } } + +// TestSendAndForgetPublisher tests that the SendAndForgetDeliverer is working +func TestSendAndForgetPublisher(t *testing.T) { + ch := &mockChannel{ + f: nullFunc, + c: make(chan amqp.Publishing, 1), + deliveries: []amqp.Delivery{}, // no reply from mock subscriber + } + q := &amqp.Queue{Name: "some queue"} + + pub := amqptransport.NewPublisher( + ch, + q, + func(context.Context, *amqp.Publishing, interface{}) error { return nil }, + func(context.Context, *amqp.Delivery) (response interface{}, err error) { + return struct{}{}, nil + }, + amqptransport.PublisherDeliverer(amqptransport.SendAndForgetDeliverer), + amqptransport.PublisherTimeout(50*time.Millisecond), + ) + + var err error + errChan := make(chan error, 1) + finishChan := make(chan bool, 1) + go func() { + _, err := pub.Endpoint()(context.Background(), struct{}{}) + if err != nil { + errChan <- err + } else { + finishChan <- true + } + + }() + + select { + case <-finishChan: + break + case err = <-errChan: + t.Errorf("unexpected error %s", err) + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } + +} diff --git a/transport/amqp/subscriber.go b/transport/amqp/subscriber.go index 17e1b0f59..3bfe03b46 100644 --- a/transport/amqp/subscriber.go +++ b/transport/amqp/subscriber.go @@ -12,13 +12,14 @@ import ( // Subscriber wraps an endpoint and provides a handler for AMQP Delivery messages. type Subscriber struct { - e endpoint.Endpoint - dec DecodeRequestFunc - enc EncodeResponseFunc - before []RequestFunc - after []SubscriberResponseFunc - errorEncoder ErrorEncoder - logger log.Logger + e endpoint.Endpoint + dec DecodeRequestFunc + enc EncodeResponseFunc + before []RequestFunc + after []SubscriberResponseFunc + responsePublisher ResponsePublisher + errorEncoder ErrorEncoder + logger log.Logger } // NewSubscriber constructs a new subscriber, which provides a handler @@ -30,11 +31,12 @@ func NewSubscriber( options ...SubscriberOption, ) *Subscriber { s := &Subscriber{ - e: e, - dec: dec, - enc: enc, - errorEncoder: DefaultErrorEncoder, - logger: log.NewNopLogger(), + e: e, + dec: dec, + enc: enc, + responsePublisher: DefaultResponsePublisher, + errorEncoder: DefaultErrorEncoder, + logger: log.NewNopLogger(), } for _, option := range options { option(s) @@ -57,6 +59,13 @@ func SubscriberAfter(after ...SubscriberResponseFunc) SubscriberOption { return func(s *Subscriber) { s.after = append(s.after, after...) } } +// SubscriberResponsePublisher is used by the subscriber to deliver response +// objects to the original sender. +// By default, the DefaultResponsePublisher is used. +func SubscriberResponsePublisher(rp ResponsePublisher) SubscriberOption { + return func(s *Subscriber) { s.responsePublisher = rp } +} + // SubscriberErrorEncoder is used to encode errors to the subscriber reply // whenever they're encountered in the processing of a request. Clients can // use this to provide custom error formatting. By default, @@ -111,7 +120,7 @@ func (s Subscriber) ServeDelivery(ch Channel) func(deliv *amqp.Delivery) { return } - if err := s.publishResponse(ctx, deliv, ch, &pub); err != nil { + if err := s.responsePublisher(ctx, deliv, ch, &pub); err != nil { s.logger.Log("err", err) s.errorEncoder(ctx, err, deliv, ch, &pub) return @@ -120,7 +129,45 @@ func (s Subscriber) ServeDelivery(ch Channel) func(deliv *amqp.Delivery) { } -func (s Subscriber) publishResponse( +// EncodeJSONResponse marshals the response as JSON as part of the +// payload of the AMQP Publishing object. +func EncodeJSONResponse( + ctx context.Context, + pub *amqp.Publishing, + response interface{}, +) error { + b, err := json.Marshal(response) + if err != nil { + return err + } + pub.Body = b + return nil +} + +// EncodeNopResponse is a response function that does nothing. +func EncodeNopResponse( + ctx context.Context, + pub *amqp.Publishing, + response interface{}, +) error { + return nil +} + +// ResponsePublisher functions are executed by the subscriber to +// publish response object to the original sender. +// Please note that the word "publisher" does not refer +// to the publisher of pub/sub. +// Rather, publisher is merely a function that publishes, or sends responses. +type ResponsePublisher func( + context.Context, + *amqp.Delivery, + Channel, + *amqp.Publishing, +) error + +// DefaultResponsePublisher extracts the reply exchange and reply key +// from the request, and sends the response object to that destination. +func DefaultResponsePublisher( ctx context.Context, deliv *amqp.Delivery, ch Channel, @@ -145,26 +192,14 @@ func (s Subscriber) publishResponse( ) } -// EncodeJSONResponse marshals the response as JSON as part of the -// payload of the AMQP Publishing object. -func EncodeJSONResponse( - ctx context.Context, - pub *amqp.Publishing, - response interface{}, -) error { - b, err := json.Marshal(response) - if err != nil { - return err - } - pub.Body = b - return nil -} - -// EncodeNopResponse is a response function that does nothing. -func EncodeNopResponse( +// NopResponsePublisher does not deliver a response to the original sender. +// This response publisher is used when the user wants the subscriber to +// receive and forget. +func NopResponsePublisher( ctx context.Context, + deliv *amqp.Delivery, + ch Channel, pub *amqp.Publishing, - response interface{}, ) error { return nil } diff --git a/transport/amqp/subscriber_test.go b/transport/amqp/subscriber_test.go index 5aece6b70..18b496af5 100644 --- a/transport/amqp/subscriber_test.go +++ b/transport/amqp/subscriber_test.go @@ -12,7 +12,7 @@ import ( ) var ( - typeAssertionError = errors.New("type assertion error") + errTypeAssertion = errors.New("type assertion error") ) // mockChannel is a mock of *amqp.Channel. @@ -205,7 +205,7 @@ func TestSubscriberSuccess(t *testing.T) { } res, ok := response.(testRes) if !ok { - t.Error(typeAssertionError) + t.Error(errTypeAssertion) } if want, have := obj.Squadron, res.Squadron; want != have { @@ -216,6 +216,45 @@ func TestSubscriberSuccess(t *testing.T) { } } +// TestNopResponseSubscriber checks if setting responsePublisher to +// NopResponsePublisher works properly by disabling response. +func TestNopResponseSubscriber(t *testing.T) { + cid := "correlation" + replyTo := "sender" + obj := testReq{ + Squadron: 436, + } + b, err := json.Marshal(obj) + if err != nil { + t.Fatal(err) + } + + sub := amqptransport.NewSubscriber( + testEndpoint, + testReqDecoder, + amqptransport.EncodeJSONResponse, + amqptransport.SubscriberResponsePublisher(amqptransport.NopResponsePublisher), + amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder), + ) + + checkReplyToFunc := func(exchange, key string, mandatory, immediate bool) {} + + outputChan := make(chan amqp.Publishing, 1) + ch := &mockChannel{f: checkReplyToFunc, c: outputChan} + sub.ServeDelivery(ch)(&amqp.Delivery{ + CorrelationId: cid, + ReplyTo: replyTo, + Body: b, + }) + + select { + case <-outputChan: + t.Fatal("Subscriber with NopResponsePublisher replied.") + case <-time.After(100 * time.Millisecond): + break + } +} + // TestSubscriberMultipleBefore checks if options to set exchange, key, deliveryMode // are working. func TestSubscriberMultipleBefore(t *testing.T) { @@ -294,7 +333,7 @@ func TestDefaultContentMetaData(t *testing.T) { amqptransport.EncodeJSONResponse, amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder), ) - checkReplyToFunc := func(exch, k string, mandatory, immediate bool) { return } + checkReplyToFunc := func(exch, k string, mandatory, immediate bool) {} outputChan := make(chan amqp.Publishing, 1) ch := &mockChannel{f: checkReplyToFunc, c: outputChan} sub.ServeDelivery(ch)(&amqp.Delivery{}) @@ -344,7 +383,7 @@ type testRes struct { func testEndpoint(_ context.Context, request interface{}) (interface{}, error) { req, ok := request.(testReq) if !ok { - return nil, typeAssertionError + return nil, errTypeAssertion } name, prs := names[req.Squadron] if !prs {