diff --git a/transport/nats/publisher_test.go b/transport/nats/publisher_test.go index 8468f1b2f..cbbe9d777 100644 --- a/transport/nats/publisher_test.go +++ b/transport/nats/publisher_test.go @@ -1,13 +1,13 @@ package nats_test import ( - "testing" "context" - "time" "strings" + "testing" + "time" - "github.com/nats-io/go-nats" natstransport "github.com/go-kit/kit/transport/nats" + "github.com/nats-io/go-nats" ) func TestPublisher(t *testing.T) { @@ -19,10 +19,7 @@ func TestPublisher(t *testing.T) { } ) - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { @@ -66,10 +63,7 @@ func TestPublisherBefore(t *testing.T) { } ) - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { @@ -117,10 +111,7 @@ func TestPublisherAfter(t *testing.T) { } ) - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { @@ -167,10 +158,7 @@ func TestPublisherTimeout(t *testing.T) { } ) - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() ch := make(chan struct{}) @@ -195,18 +183,13 @@ func TestPublisherTimeout(t *testing.T) { _, err = publisher.Endpoint()(context.Background(), struct{}{}) if err != context.DeadlineExceeded { t.Errorf("want %s, have %s", context.DeadlineExceeded, err) - } - } func TestEncodeJSONRequest(t *testing.T) { var data string - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", func(msg *nats.Msg) { @@ -237,7 +220,9 @@ func TestEncodeJSONRequest(t *testing.T) { {1.2, "1.2"}, {true, "true"}, {"test", "\"test\""}, - {struct{ Foo string `json:"foo"` }{"foo"}, "{\"foo\":\"foo\"}"}, + {struct { + Foo string `json:"foo"` + }{"foo"}, "{\"foo\":\"foo\"}"}, } { if _, err := publisher(context.Background(), test.value); err != nil { t.Fatal(err) diff --git a/transport/nats/subscriber_test.go b/transport/nats/subscriber_test.go index a2ba160f3..4d44dbbd0 100644 --- a/transport/nats/subscriber_test.go +++ b/transport/nats/subscriber_test.go @@ -1,19 +1,19 @@ package nats_test import ( - "testing" "context" + "encoding/json" "errors" - "time" - "sync" "strings" - "encoding/json" + "sync" + "testing" + "time" - "github.com/nats-io/go-nats" "github.com/nats-io/gnatsd/server" + "github.com/nats-io/go-nats" - natstransport "github.com/go-kit/kit/transport/nats" "github.com/go-kit/kit/endpoint" + natstransport "github.com/go-kit/kit/transport/nats" ) type TestResponse struct { @@ -21,9 +21,13 @@ type TestResponse struct { Error string `json:"err"` } +var natsServer *server.Server + func init() { - opts := server.Options{Host: "localhost", Port: 4222} - natsServer := server.New(&opts) + natsServer = server.New(&server.Options{ + Host: "localhost", + Port: 4222, + }) go func() { natsServer.Start() @@ -34,11 +38,32 @@ func init() { } } -func TestSubscriberBadDecode(t *testing.T) { - nc, err := nats.Connect(nats.DefaultURL) +func newNatsConn(t *testing.T) *nats.Conn { + // Subscriptions and connections are closed asynchronously, so it's possible + // that there's still a subscription from an old connection that must be closed + // before the current test can be run. + for tries := 20; tries > 0; tries-- { + if natsServer.NumSubscriptions() == 0 { + break + } + + time.Sleep(5 * time.Millisecond) + } + + if n := natsServer.NumSubscriptions(); n > 0 { + t.Fatalf("found %d active subscriptions on the server", n) + } + + nc, err := nats.Connect("nats://"+natsServer.Addr().String(), nats.Name(t.Name())) if err != nil { - t.Fatal(err) + t.Fatalf("failed to connect to gnatsd server: %s", err) } + + return nc +} + +func TestSubscriberBadDecode(t *testing.T) { + nc := newNatsConn(t) defer nc.Close() handler := natstransport.NewSubscriber( @@ -56,10 +81,7 @@ func TestSubscriberBadDecode(t *testing.T) { } func TestSubscriberBadEndpoint(t *testing.T) { - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() handler := natstransport.NewSubscriber( @@ -76,10 +98,7 @@ func TestSubscriberBadEndpoint(t *testing.T) { } func TestSubscriberBadEncode(t *testing.T) { - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() handler := natstransport.NewSubscriber( @@ -96,10 +115,7 @@ func TestSubscriberBadEncode(t *testing.T) { } func TestSubscriberErrorEncoder(t *testing.T) { - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() errTeapot := errors.New("teapot") @@ -152,10 +168,7 @@ func TestSubscriberHappySubject(t *testing.T) { } func TestMultipleSubscriberBefore(t *testing.T) { - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() var ( @@ -216,10 +229,7 @@ func TestMultipleSubscriberBefore(t *testing.T) { } func TestMultipleSubscriberAfter(t *testing.T) { - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() var ( @@ -280,14 +290,15 @@ func TestMultipleSubscriberAfter(t *testing.T) { } func TestEncodeJSONResponse(t *testing.T) { - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() handler := natstransport.NewSubscriber( - func(context.Context, interface{}) (interface{}, error) { return struct{ Foo string `json:"foo"` }{"bar"}, nil }, + func(context.Context, interface{}) (interface{}, error) { + return struct { + Foo string `json:"foo"` + }{"bar"}, nil + }, func(context.Context, *nats.Msg) (interface{}, error) { return struct{}{}, nil }, natstransport.EncodeJSONResponse, ) @@ -317,13 +328,12 @@ func (m responseError) Error() string { } func TestErrorEncoder(t *testing.T) { - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() - errResp := struct{ Error string `json:"err"` }{"oh no"} + errResp := struct { + Error string `json:"err"` + }{"oh no"} handler := natstransport.NewSubscriber( func(context.Context, interface{}) (interface{}, error) { return nil, responseError{msg: errResp.Error} @@ -355,10 +365,7 @@ func TestErrorEncoder(t *testing.T) { type noContentResponse struct{} func TestEncodeNoContent(t *testing.T) { - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() handler := natstransport.NewSubscriber( @@ -384,10 +391,7 @@ func TestEncodeNoContent(t *testing.T) { } func TestNoOpRequestDecoder(t *testing.T) { - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() handler := natstransport.NewSubscriber( @@ -420,7 +424,10 @@ func TestNoOpRequestDecoder(t *testing.T) { func testSubscriber(t *testing.T) (step func(), resp <-chan *nats.Msg) { var ( stepch = make(chan bool) - endpoint = func(context.Context, interface{}) (interface{}, error) { <-stepch; return struct{}{}, nil } + endpoint = func(context.Context, interface{}) (interface{}, error) { + <-stepch + return struct{}{}, nil + } response = make(chan *nats.Msg) handler = natstransport.NewSubscriber( endpoint, @@ -432,10 +439,7 @@ func testSubscriber(t *testing.T) (step func(), resp <-chan *nats.Msg) { ) go func() { - nc, err := nats.Connect(nats.DefaultURL) - if err != nil { - t.Fatal(err) - } + nc := newNatsConn(t) defer nc.Close() sub, err := nc.QueueSubscribe("natstransport.test", "natstransport", handler.ServeMsg(nc))