diff --git a/examples/addsvc/cmd/addcli/addcli.go b/examples/addsvc/cmd/addcli/addcli.go index fe24fc278..7e4aab4fb 100644 --- a/examples/addsvc/cmd/addcli/addcli.go +++ b/examples/addsvc/cmd/addcli/addcli.go @@ -37,6 +37,7 @@ func main() { httpAddr = fs.String("http-addr", "", "HTTP address of addsvc") grpcAddr = fs.String("grpc-addr", "", "gRPC address of addsvc") thriftAddr = fs.String("thrift-addr", "", "Thrift address of addsvc") + jsonRPCAddr = fs.String("jsonrpc-addr", "", "JSON RPC address of addsvc") thriftProtocol = fs.String("thrift-protocol", "binary", "binary, compact, json, simplejson") thriftBuffer = fs.Int("thrift-buffer", 0, "0 for unbuffered") thriftFramed = fs.Bool("thrift-framed", false, "true to enable framing") @@ -102,6 +103,8 @@ func main() { } defer conn.Close() svc = addtransport.NewGRPCClient(conn, tracer, log.NewNopLogger()) + } else if *jsonRPCAddr != "" { + svc, err = addtransport.NewJSONRPCClient(*jsonRPCAddr, tracer, log.NewNopLogger()) } else if *thriftAddr != "" { // It's necessary to do all of this construction in the func main, // because (among other reasons) we need to control the lifecycle of the diff --git a/examples/addsvc/cmd/addsvc/addsvc.go b/examples/addsvc/cmd/addsvc/addsvc.go index b1886e2f7..71fe836b7 100644 --- a/examples/addsvc/cmd/addsvc/addsvc.go +++ b/examples/addsvc/cmd/addsvc/addsvc.go @@ -42,6 +42,7 @@ func main() { httpAddr = fs.String("http-addr", ":8081", "HTTP listen address") grpcAddr = fs.String("grpc-addr", ":8082", "gRPC listen address") thriftAddr = fs.String("thrift-addr", ":8083", "Thrift listen address") + jsonRPCAddr = fs.String("jsonrpc-addr", ":8084", "JSON RPC listen address") thriftProtocol = fs.String("thrift-protocol", "binary", "binary, compact, json, simplejson") thriftBuffer = fs.Int("thrift-buffer", 0, "0 for unbuffered") thriftFramed = fs.Bool("thrift-framed", false, "true to enable framing") @@ -135,11 +136,12 @@ func main() { // the interfaces that the transports expect. Note that we're not binding // them to ports or anything yet; we'll do that next. var ( - service = addservice.New(logger, ints, chars) - endpoints = addendpoint.New(service, logger, duration, tracer) - httpHandler = addtransport.NewHTTPHandler(endpoints, tracer, logger) - grpcServer = addtransport.NewGRPCServer(endpoints, tracer, logger) - thriftServer = addtransport.NewThriftServer(endpoints) + service = addservice.New(logger, ints, chars) + endpoints = addendpoint.New(service, logger, duration, tracer) + httpHandler = addtransport.NewHTTPHandler(endpoints, tracer, logger) + grpcServer = addtransport.NewGRPCServer(endpoints, tracer, logger) + thriftServer = addtransport.NewThriftServer(endpoints) + jsonrpcHandler = addtransport.NewJSONRPCHandler(endpoints, logger) ) // Now we're to the part of the func main where we want to start actually @@ -244,6 +246,19 @@ func main() { thriftSocket.Close() }) } + { + httpListener, err := net.Listen("tcp", *jsonRPCAddr) + if err != nil { + logger.Log("transport", "JSONRPC over HTTP", "during", "Listen", "err", err) + os.Exit(1) + } + g.Add(func() error { + logger.Log("transport", "JSONRPC over HTTP", "addr", *jsonRPCAddr) + return http.Serve(httpListener, jsonrpcHandler) + }, func(error) { + httpListener.Close() + }) + } { // This function just sits and waits for ctrl-C. cancelInterrupt := make(chan struct{}) diff --git a/examples/addsvc/pkg/addtransport/jsonrpc.go b/examples/addsvc/pkg/addtransport/jsonrpc.go new file mode 100644 index 000000000..9508e81e1 --- /dev/null +++ b/examples/addsvc/pkg/addtransport/jsonrpc.go @@ -0,0 +1,207 @@ +package addtransport + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "strings" + "time" + + "golang.org/x/time/rate" + + "github.com/go-kit/kit/circuitbreaker" + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/examples/addsvc/pkg/addendpoint" + "github.com/go-kit/kit/examples/addsvc/pkg/addservice" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/ratelimit" + "github.com/go-kit/kit/tracing/opentracing" + "github.com/go-kit/kit/transport/http/jsonrpc" + stdopentracing "github.com/opentracing/opentracing-go" + "github.com/sony/gobreaker" +) + +// NewJSONRPCHandler returns a JSON RPC Server/Handler that can be passed to http.Handle() +func NewJSONRPCHandler(endpoints addendpoint.Set, logger log.Logger) *jsonrpc.Server { + handler := jsonrpc.NewServer( + makeEndpointCodecMap(endpoints), + jsonrpc.ServerErrorLogger(logger), + ) + return handler +} + +// NewJSONRPCClient returns an addservice backed by a JSON RPC over HTTP server +// living at the remote instance. We expect instance to come from a service +// discovery system, so likely of the form "host:port". We bake-in certain +// middlewares, implementing the client library pattern. +func NewJSONRPCClient(instance string, tracer stdopentracing.Tracer, logger log.Logger) (addservice.Service, error) { + // Quickly sanitize the instance string. + if !strings.HasPrefix(instance, "http") { + instance = "http://" + instance + } + u, err := url.Parse(instance) + if err != nil { + return nil, err + } + + // We construct a single ratelimiter middleware, to limit the total outgoing + // QPS from this client to all methods on the remote instance. We also + // construct per-endpoint circuitbreaker middlewares to demonstrate how + // that's done, although they could easily be combined into a single breaker + // for the entire remote instance, too. + limiter := ratelimit.NewErroringLimiter(rate.NewLimiter(rate.Every(time.Second), 100)) + + var sumEndpoint endpoint.Endpoint + { + sumEndpoint = jsonrpc.NewClient( + u, + "sum", + jsonrpc.ClientRequestEncoder(encodeSumRequest), + jsonrpc.ClientResponseDecoder(decodeSumResponse), + ).Endpoint() + sumEndpoint = opentracing.TraceClient(tracer, "Sum")(sumEndpoint) + sumEndpoint = limiter(sumEndpoint) + sumEndpoint = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{ + Name: "Sum", + Timeout: 30 * time.Second, + }))(sumEndpoint) + } + + var concatEndpoint endpoint.Endpoint + { + concatEndpoint = jsonrpc.NewClient( + u, + "concat", + jsonrpc.ClientRequestEncoder(encodeConcatRequest), + jsonrpc.ClientResponseDecoder(decodeConcatResponse), + ).Endpoint() + concatEndpoint = opentracing.TraceClient(tracer, "Concat")(concatEndpoint) + concatEndpoint = limiter(concatEndpoint) + concatEndpoint = circuitbreaker.Gobreaker(gobreaker.NewCircuitBreaker(gobreaker.Settings{ + Name: "Concat", + Timeout: 30 * time.Second, + }))(concatEndpoint) + } + + // Returning the endpoint.Set as a service.Service relies on the + // endpoint.Set implementing the Service methods. That's just a simple bit + // of glue code. + return addendpoint.Set{ + SumEndpoint: sumEndpoint, + ConcatEndpoint: concatEndpoint, + }, nil + +} + +// makeEndpointCodecMap returns a codec map configured for the addsvc. +func makeEndpointCodecMap(endpoints addendpoint.Set) jsonrpc.EndpointCodecMap { + return jsonrpc.EndpointCodecMap{ + "sum": jsonrpc.EndpointCodec{ + Endpoint: endpoints.SumEndpoint, + Decode: decodeSumRequest, + Encode: encodeSumResponse, + }, + "concat": jsonrpc.EndpointCodec{ + Endpoint: endpoints.ConcatEndpoint, + Decode: decodeConcatRequest, + Encode: encodeConcatResponse, + }, + } +} + +func decodeSumRequest(_ context.Context, msg json.RawMessage) (interface{}, error) { + var req addendpoint.SumRequest + err := json.Unmarshal(msg, &req) + if err != nil { + return nil, &jsonrpc.Error{ + Code: -32000, + Message: fmt.Sprintf("couldn't unmarshal body to sum request: %s", err), + } + } + return req, nil +} + +func encodeSumResponse(_ context.Context, obj interface{}) (json.RawMessage, error) { + res, ok := obj.(addendpoint.SumResponse) + if !ok { + return nil, &jsonrpc.Error{ + Code: -32000, + Message: fmt.Sprintf("Asserting result to *SumResponse failed. Got %T, %+v", obj, obj), + } + } + b, err := json.Marshal(res) + if err != nil { + return nil, fmt.Errorf("couldn't marshal response: %s", err) + } + return b, nil +} + +func decodeSumResponse(_ context.Context, msg json.RawMessage) (interface{}, error) { + var res addendpoint.SumResponse + err := json.Unmarshal(msg, &res) + if err != nil { + return nil, fmt.Errorf("couldn't unmarshal body to SumResponse: %s", err) + } + return res, nil +} + +func encodeSumRequest(_ context.Context, obj interface{}) (json.RawMessage, error) { + req, ok := obj.(addendpoint.SumRequest) + if !ok { + return nil, fmt.Errorf("couldn't assert request as SumRequest, got %T", obj) + } + b, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("couldn't marshal request: %s", err) + } + return b, nil +} + +func decodeConcatRequest(_ context.Context, msg json.RawMessage) (interface{}, error) { + var req addendpoint.ConcatRequest + err := json.Unmarshal(msg, &req) + if err != nil { + return nil, &jsonrpc.Error{ + Code: -32000, + Message: fmt.Sprintf("couldn't unmarshal body to concat request: %s", err), + } + } + return req, nil +} + +func encodeConcatResponse(_ context.Context, obj interface{}) (json.RawMessage, error) { + res, ok := obj.(addendpoint.ConcatResponse) + if !ok { + return nil, &jsonrpc.Error{ + Code: -32000, + Message: fmt.Sprintf("Asserting result to *ConcatResponse failed. Got %T, %+v", obj, obj), + } + } + b, err := json.Marshal(res) + if err != nil { + return nil, fmt.Errorf("couldn't marshal response: %s", err) + } + return b, nil +} + +func decodeConcatResponse(_ context.Context, msg json.RawMessage) (interface{}, error) { + var res addendpoint.ConcatResponse + err := json.Unmarshal(msg, &res) + if err != nil { + return nil, fmt.Errorf("couldn't unmarshal body to ConcatResponse: %s", err) + } + return res, nil +} + +func encodeConcatRequest(_ context.Context, obj interface{}) (json.RawMessage, error) { + req, ok := obj.(addendpoint.ConcatRequest) + if !ok { + return nil, fmt.Errorf("couldn't assert request as ConcatRequest, got %T", obj) + } + b, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("couldn't marshal request: %s", err) + } + return b, nil +} diff --git a/transport/http/client.go b/transport/http/client.go index f1ca9c3a4..25c078a58 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -143,8 +143,8 @@ func (c Client) Endpoint() endpoint.Endpoint { // request, after the response is returned. The principal // intended use is for error logging. Additional response parameters are // provided in the context under keys with the ContextKeyResponse prefix. -// Note: err may be nil. There maybe also no additional response parameters depending on -// when an error occurs. +// Note: err may be nil. There maybe also no additional response parameters +// depending on when an error occurs. type ClientFinalizerFunc func(ctx context.Context, err error) // EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a diff --git a/transport/http/jsonrpc/README.md b/transport/http/jsonrpc/README.md new file mode 100644 index 000000000..d4140faff --- /dev/null +++ b/transport/http/jsonrpc/README.md @@ -0,0 +1,92 @@ +# JSON RPC + +[JSON RPC](http://www.jsonrpc.org) is "A light weight remote procedure call protocol". It allows for the creation of simple RPC-style APIs with human-readable messages that are front-end friendly. + +## Using JSON RPC with Go-Kit +Using JSON RPC and go-kit together is quite simple. + +A JSON RPC _server_ acts as an [HTTP Handler](https://godoc.org/net/http#Handler), receiving all requests to the JSON RPC's URL. The server looks at the `method` property of the [Request Object](http://www.jsonrpc.org/specification#request_object), and routes it to the corresponding code. + +Each JSON RPC _method_ is implemented as an `EndpointCodec`, a go-kit [Endpoint](https://godoc.org/github.com/go-kit/kit/endpoint#Endpoint), sandwiched between a decoder and encoder. The decoder picks apart the JSON RPC request params, which can be passed to your endpoint. The encoder receives the output from the endpoint and encodes a JSON-RPC result. + +## Example — Add Service +Let's say we want a service that adds two ints together. We'll serve this at `http://localhost/rpc`. So a request to our `sum` method will be a POST to `http://localhost/rpc` with a request body of: + + { + "id": 123, + "jsonrpc": "2.0", + "method": "sum", + "params": { + "A": 2, + "B": 2 + } + } + +### `EndpointCodecMap` +The routing table for incoming JSON RPC requests is the `EndpointCodecMap`. The key of the map is the JSON RPC method name. Here, we're routing the `sum` method to an `EndpointCodec` wrapped around `sumEndpoint`. + + jsonrpc.EndpointCodecMap{ + "sum": jsonrpc.EndpointCodec{ + Endpoint: sumEndpoint, + Decode: decodeSumRequest, + Encode: encodeSumResponse, + }, + } + +### Decoder + type DecodeRequestFunc func(context.Context, json.RawMessage) (request interface{}, err error) + +A `DecodeRequestFunc` is given the raw JSON from the `params` property of the Request object, _not_ the whole request object. It returns an object that will be the input to the Endpoint. For our purposes, the output should be a SumRequest, like this: + + type SumRequest struct { + A, B int + } + +So here's our decoder: + + func decodeSumRequest(ctx context.Context, msg json.RawMessage) (interface{}, error) { + var req SumRequest + err := json.Unmarshal(msg, &req) + if err != nil { + return nil, err + } + return req, nil + } + +So our `SumRequest` will now be passed to the endpoint. Once the endpoint has done its work, we hand over to the… + +### Encoder +The encoder takes the output of the endpoint, and builds the raw JSON message that will form the `result` field of a [Response Object](http://www.jsonrpc.org/specification#response_object). Our result is going to be a plain int. Here's our encoder: + + func encodeSumResponse(ctx context.Context, result interface{}) (json.RawMessage, error) { + sum, ok := result.(int) + if !ok { + return nil, errors.New("result is not an int") + } + b, err := json.Marshal(sum) + if err != nil { + return nil, err + } + return b, nil + } + +### Server +Now that we have an EndpointCodec with decoder, endpoint, and encoder, we can wire up the server: + + handler := jsonrpc.NewServer(jsonrpc.EndpointCodecMap{ + "sum": jsonrpc.EndpointCodec{ + Endpoint: sumEndpoint, + Decode: decodeSumRequest, + Encode: encodeSumResponse, + }, + }) + http.Handle("/rpc", handler) + http.ListenAndServe(":80", nil) + +With all of this done, our example request above should result in a response like this: + + { + "jsonrpc": "2.0", + "result": 4, + "error": null + } diff --git a/transport/http/jsonrpc/client.go b/transport/http/jsonrpc/client.go new file mode 100644 index 000000000..ca57bbf1d --- /dev/null +++ b/transport/http/jsonrpc/client.go @@ -0,0 +1,236 @@ +package jsonrpc + +import ( + "bytes" + "context" + "encoding/json" + "io/ioutil" + "net/http" + "net/url" + "sync/atomic" + + "github.com/go-kit/kit/endpoint" + httptransport "github.com/go-kit/kit/transport/http" +) + +// Client wraps a JSON RPC method and provides a method that implements endpoint.Endpoint. +type Client struct { + client *http.Client + + // JSON RPC endpoint URL + tgt *url.URL + + // JSON RPC method name. + method string + + enc EncodeRequestFunc + dec DecodeResponseFunc + before []httptransport.RequestFunc + after []httptransport.ClientResponseFunc + finalizer httptransport.ClientFinalizerFunc + requestID RequestIDGenerator + bufferedStream bool +} + +type clientRequest struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params json.RawMessage `json:"params"` + ID interface{} `json:"id"` +} + +// NewClient constructs a usable Client for a single remote method. +func NewClient( + tgt *url.URL, + method string, + options ...ClientOption, +) *Client { + c := &Client{ + client: http.DefaultClient, + method: method, + tgt: tgt, + enc: DefaultRequestEncoder, + dec: DefaultResponseDecoder, + before: []httptransport.RequestFunc{}, + after: []httptransport.ClientResponseFunc{}, + requestID: NewAutoIncrementID(0), + bufferedStream: false, + } + for _, option := range options { + option(c) + } + return c +} + +// DefaultRequestEncoder marshals the given request to JSON. +func DefaultRequestEncoder(_ context.Context, req interface{}) (json.RawMessage, error) { + return json.Marshal(req) +} + +// DefaultResponseDecoder unmarshals the given JSON to interface{}. +func DefaultResponseDecoder(_ context.Context, res json.RawMessage) (interface{}, error) { + var result interface{} + err := json.Unmarshal(res, &result) + if err != nil { + return nil, err + } + return result, nil +} + +// ClientOption sets an optional parameter for clients. +type ClientOption func(*Client) + +// SetClient sets the underlying HTTP client used for requests. +// By default, http.DefaultClient is used. +func SetClient(client *http.Client) ClientOption { + return func(c *Client) { c.client = client } +} + +// ClientBefore sets the RequestFuncs that are applied to the outgoing HTTP +// request before it's invoked. +func ClientBefore(before ...httptransport.RequestFunc) ClientOption { + return func(c *Client) { c.before = append(c.before, before...) } +} + +// ClientAfter sets the ClientResponseFuncs applied to the server's HTTP +// response prior to it being decoded. This is useful for obtaining anything +// from the response and adding onto the context prior to decoding. +func ClientAfter(after ...httptransport.ClientResponseFunc) ClientOption { + return func(c *Client) { c.after = append(c.after, after...) } +} + +// ClientFinalizer is executed at the end of every HTTP request. +// By default, no finalizer is registered. +func ClientFinalizer(f httptransport.ClientFinalizerFunc) ClientOption { + return func(c *Client) { c.finalizer = f } +} + +// ClientRequestEncoder sets the func used to encode the request params to JSON. +// If not set, DefaultRequestEncoder is used. +func ClientRequestEncoder(enc EncodeRequestFunc) ClientOption { + return func(c *Client) { c.enc = enc } +} + +// ClientResponseDecoder sets the func used to decode the response params from +// JSON. If not set, DefaultResponseDecoder is used. +func ClientResponseDecoder(dec DecodeResponseFunc) ClientOption { + return func(c *Client) { c.dec = dec } +} + +// RequestIDGenerator returns an ID for the request. +type RequestIDGenerator interface { + Generate() interface{} +} + +// ClientRequestIDGenerator is executed before each request to generate an ID +// for the request. +// By default, AutoIncrementRequestID is used. +func ClientRequestIDGenerator(g RequestIDGenerator) ClientOption { + return func(c *Client) { c.requestID = g } +} + +// BufferedStream sets whether the Response.Body is left open, allowing it +// to be read from later. Useful for transporting a file as a buffered stream. +func BufferedStream(buffered bool) ClientOption { + return func(c *Client) { c.bufferedStream = buffered } +} + +// Endpoint returns a usable endpoint that invokes the remote endpoint. +func (c Client) Endpoint() endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var ( + resp *http.Response + err error + ) + if c.finalizer != nil { + defer func() { + if resp != nil { + ctx = context.WithValue(ctx, httptransport.ContextKeyResponseHeaders, resp.Header) + ctx = context.WithValue(ctx, httptransport.ContextKeyResponseSize, resp.ContentLength) + } + c.finalizer(ctx, err) + }() + } + + var params json.RawMessage + if params, err = c.enc(ctx, request); err != nil { + return nil, err + } + rpcReq := clientRequest{ + JSONRPC: "", + Method: c.method, + Params: params, + ID: c.requestID.Generate(), + } + + req, err := http.NewRequest("POST", c.tgt.String(), nil) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json; charset=utf-8") + var b bytes.Buffer + req.Body = ioutil.NopCloser(&b) + err = json.NewEncoder(&b).Encode(rpcReq) + if err != nil { + return nil, err + } + + for _, f := range c.before { + ctx = f(ctx, req) + } + + resp, err = c.client.Do(req.WithContext(ctx)) + if err != nil { + return nil, err + } + + if !c.bufferedStream { + defer resp.Body.Close() + } + + // Decode the body into an object + var rpcRes Response + err = json.NewDecoder(resp.Body).Decode(&rpcRes) + if err != nil { + return nil, err + } + + for _, f := range c.after { + ctx = f(ctx, resp) + } + + return c.dec(ctx, rpcRes.Result) + } +} + +// ClientFinalizerFunc can be used to perform work at the end of a client HTTP +// request, after the response is returned. The principal +// intended use is for error logging. Additional response parameters are +// provided in the context under keys with the ContextKeyResponse prefix. +// Note: err may be nil. There maybe also no additional response parameters +// depending on when an error occurs. +type ClientFinalizerFunc func(ctx context.Context, err error) + +// autoIncrementID is a RequestIDGenerator that generates +// auto-incrementing integer IDs. +type autoIncrementID struct { + v *uint64 +} + +// NewAutoIncrementID returns an auto-incrementing request ID generator, +// initialised with the given value. +func NewAutoIncrementID(init uint64) RequestIDGenerator { + // Offset by one so that the first generated value = init. + v := init - 1 + return &autoIncrementID{v: &v} +} + +// Generate satisfies RequestIDGenerator +func (i *autoIncrementID) Generate() interface{} { + id := atomic.AddUint64(i.v, 1) + return id +} diff --git a/transport/http/jsonrpc/client_test.go b/transport/http/jsonrpc/client_test.go new file mode 100644 index 000000000..fe42f239d --- /dev/null +++ b/transport/http/jsonrpc/client_test.go @@ -0,0 +1,226 @@ +package jsonrpc_test + +import ( + "context" + "encoding/json" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/go-kit/kit/transport/http/jsonrpc" +) + +type TestResponse struct { + Body io.ReadCloser + String string +} + +func TestCanCallBeforeFunc(t *testing.T) { + called := false + u, _ := url.Parse("http://senseye.io/jsonrpc") + sut := jsonrpc.NewClient( + u, + "add", + jsonrpc.ClientBefore(func(ctx context.Context, req *http.Request) context.Context { + called = true + return ctx + }), + ) + + sut.Endpoint()(context.TODO(), "foo") + + if !called { + t.Fatal("Expected client before func to be called. Wasn't.") + } +} + +type staticIDGenerator int + +func (g staticIDGenerator) Generate() interface{} { return g } + +func TestClientHappyPath(t *testing.T) { + var ( + afterCalledKey = "AC" + beforeHeaderKey = "BF" + beforeHeaderValue = "beforeFuncWozEre" + testbody = `{"jsonrpc":"2.0", "result":5}` + requestBody []byte + beforeFunc = func(ctx context.Context, r *http.Request) context.Context { + r.Header.Add(beforeHeaderKey, beforeHeaderValue) + return ctx + } + encode = func(ctx context.Context, req interface{}) (json.RawMessage, error) { + return json.Marshal(req) + } + afterFunc = func(ctx context.Context, r *http.Response) context.Context { + return context.WithValue(ctx, afterCalledKey, true) + } + finalizerCalled = false + fin = func(ctx context.Context, err error) { + finalizerCalled = true + } + decode = func(ctx context.Context, res json.RawMessage) (interface{}, error) { + if ac := ctx.Value(afterCalledKey); ac == nil { + t.Fatal("after not called") + } + var result int + err := json.Unmarshal(res, &result) + if err != nil { + return nil, err + } + return result, nil + } + + wantID = 666 + gen = staticIDGenerator(wantID) + ) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get(beforeHeaderKey) != beforeHeaderValue { + t.Fatal("Header not set by before func.") + } + + b, err := ioutil.ReadAll(r.Body) + if err != nil && err != io.EOF { + t.Fatal(err) + } + requestBody = b + + w.WriteHeader(http.StatusOK) + w.Write([]byte(testbody)) + })) + + sut := jsonrpc.NewClient( + mustParse(server.URL), + "add", + jsonrpc.ClientRequestEncoder(encode), + jsonrpc.ClientResponseDecoder(decode), + jsonrpc.ClientBefore(beforeFunc), + jsonrpc.ClientAfter(afterFunc), + jsonrpc.ClientRequestIDGenerator(gen), + jsonrpc.ClientFinalizer(fin), + jsonrpc.SetClient(http.DefaultClient), + jsonrpc.BufferedStream(false), + ) + + type addRequest struct { + A int + B int + } + + in := addRequest{2, 2} + + result, err := sut.Endpoint()(context.Background(), in) + if err != nil { + t.Fatal(err) + } + ri, ok := result.(int) + if !ok { + t.Fatalf("result is not int: (%T)%+v", result, result) + } + if ri != 5 { + t.Fatalf("want=5, got=%d", ri) + } + + var requestAtServer jsonrpc.Request + err = json.Unmarshal(requestBody, &requestAtServer) + if err != nil { + t.Fatal(err) + } + if id, _ := requestAtServer.ID.Int(); id != wantID { + t.Fatalf("Request ID at server: want=%d, got=%d", wantID, id) + } + + var paramsAtServer addRequest + err = json.Unmarshal(requestAtServer.Params, ¶msAtServer) + if err != nil { + t.Fatal(err) + } + + if paramsAtServer != in { + t.Fatalf("want=%+v, got=%+v", in, paramsAtServer) + } + + if !finalizerCalled { + t.Fatal("Expected finalizer to be called. Wasn't.") + } +} + +func TestCanUseDefaults(t *testing.T) { + var ( + testbody = `{"jsonrpc":"2.0", "result":"boogaloo"}` + requestBody []byte + ) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + if err != nil && err != io.EOF { + t.Fatal(err) + } + requestBody = b + + w.WriteHeader(http.StatusOK) + w.Write([]byte(testbody)) + })) + + sut := jsonrpc.NewClient( + mustParse(server.URL), + "add", + ) + + type addRequest struct { + A int + B int + } + + in := addRequest{2, 2} + + result, err := sut.Endpoint()(context.Background(), in) + if err != nil { + t.Fatal(err) + } + rs, ok := result.(string) + if !ok { + t.Fatalf("result is not string: (%T)%+v", result, result) + } + if rs != "boogaloo" { + t.Fatalf("want=boogaloo, got=%s", rs) + } + + var requestAtServer jsonrpc.Request + err = json.Unmarshal(requestBody, &requestAtServer) + if err != nil { + t.Fatal(err) + } + var paramsAtServer addRequest + err = json.Unmarshal(requestAtServer.Params, ¶msAtServer) + if err != nil { + t.Fatal(err) + } + + if paramsAtServer != in { + t.Fatalf("want=%+v, got=%+v", in, paramsAtServer) + } +} + +func TestDefaultAutoIncrementer(t *testing.T) { + sut := jsonrpc.NewAutoIncrementID(0) + var want uint64 + for ; want < 100; want++ { + got := sut.Generate() + if got != want { + t.Fatalf("want=%d, got=%d", want, got) + } + } +} + +func mustParse(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(err) + } + return u +} diff --git a/transport/http/jsonrpc/doc.go b/transport/http/jsonrpc/doc.go new file mode 100644 index 000000000..0e2bd52a1 --- /dev/null +++ b/transport/http/jsonrpc/doc.go @@ -0,0 +1,3 @@ +// Package jsonrpc provides a JSON RPC (v2.0) binding for endpoints. +// See http://www.jsonrpc.org/specification +package jsonrpc diff --git a/transport/http/jsonrpc/encode_decode.go b/transport/http/jsonrpc/encode_decode.go new file mode 100644 index 000000000..ab7612e5b --- /dev/null +++ b/transport/http/jsonrpc/encode_decode.go @@ -0,0 +1,48 @@ +package jsonrpc + +import ( + "encoding/json" + + "github.com/go-kit/kit/endpoint" + + "context" +) + +// Server-Side Codec + +// EndpointCodec defines a server Endpoint and its associated codecs +type EndpointCodec struct { + Endpoint endpoint.Endpoint + Decode DecodeRequestFunc + Encode EncodeResponseFunc +} + +// EndpointCodecMap maps the Request.Method to the proper EndpointCodec +type EndpointCodecMap map[string]EndpointCodec + +// DecodeRequestFunc extracts a user-domain request object from an raw JSON +// It's designed to be used in HTTP servers, for server-side endpoints. +// One straightforward DecodeRequestFunc could be something that unmarshals +// JSON from the request body to the concrete request type. +type DecodeRequestFunc func(context.Context, json.RawMessage) (request interface{}, err error) + +// EncodeResponseFunc encodes the passed response object to a JSON RPC response. +// It's designed to be used in HTTP servers, for server-side endpoints. +// One straightforward EncodeResponseFunc could be something that JSON encodes +// the object directly. +type EncodeResponseFunc func(context.Context, interface{}) (response json.RawMessage, err error) + +// Client-Side Codec + +// EncodeRequestFunc encodes the passed request object to raw JSON. +// It's designed to be used in JSON RPC clients, for client-side +// endpoints. One straightforward EncodeResponseFunc could be something that +// JSON encodes the object directly. +type EncodeRequestFunc func(context.Context, interface{}) (request json.RawMessage, err error) + +// DecodeResponseFunc extracts a user-domain response object from an HTTP +// request object. It's designed to be used in JSON RPC clients, for +// client-side endpoints. One straightforward DecodeRequestFunc could be +// something that JSON decodes from the request body to the concrete +// response type. +type DecodeResponseFunc func(context.Context, json.RawMessage) (response interface{}, err error) diff --git a/transport/http/jsonrpc/error.go b/transport/http/jsonrpc/error.go new file mode 100644 index 000000000..f3b9e3a3e --- /dev/null +++ b/transport/http/jsonrpc/error.go @@ -0,0 +1,100 @@ +package jsonrpc + +// Error defines a JSON RPC error that can be returned +// in a Response from the spec +// http://www.jsonrpc.org/specification#error_object +type Error struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// Error implements error. +func (e Error) Error() string { + if e.Message != "" { + return e.Message + } + return errorMessage[e.Code] +} + +// ErrorCode returns the JSON RPC error code associated with the error. +func (e Error) ErrorCode() int { + return e.Code +} + +const ( + // ParseError defines invalid JSON was received by the server. + // An error occurred on the server while parsing the JSON text. + ParseError int = -32700 + + // InvalidRequestError defines the JSON sent is not a valid Request object. + InvalidRequestError int = -32600 + + // MethodNotFoundError defines the method does not exist / is not available. + MethodNotFoundError int = -32601 + + // InvalidParamsError defines invalid method parameter(s). + InvalidParamsError int = -32602 + + // InternalError defines a server error + InternalError int = -32603 +) + +var errorMessage = map[int]string{ + ParseError: "An error occurred on the server while parsing the JSON text.", + InvalidRequestError: "The JSON sent is not a valid Request object.", + MethodNotFoundError: "The method does not exist / is not available.", + InvalidParamsError: "Invalid method parameter(s).", + InternalError: "Internal JSON-RPC error.", +} + +// ErrorMessage returns a message for the JSON RPC error code. It returns the empty +// string if the code is unknown. +func ErrorMessage(code int) string { + return errorMessage[code] +} + +type parseError string + +func (e parseError) Error() string { + return string(e) +} +func (e parseError) ErrorCode() int { + return ParseError +} + +type invalidRequestError string + +func (e invalidRequestError) Error() string { + return string(e) +} +func (e invalidRequestError) ErrorCode() int { + return InvalidRequestError +} + +type methodNotFoundError string + +func (e methodNotFoundError) Error() string { + return string(e) +} +func (e methodNotFoundError) ErrorCode() int { + return MethodNotFoundError +} + +type invalidParamsError string + +func (e invalidParamsError) Error() string { + return string(e) +} +func (e invalidParamsError) ErrorCode() int { + return InvalidParamsError +} + +type internalError string + +func (e internalError) Error() string { + return string(e) +} +func (e internalError) ErrorCode() int { + return InternalError +} diff --git a/transport/http/jsonrpc/error_test.go b/transport/http/jsonrpc/error_test.go new file mode 100644 index 000000000..02f8bcf29 --- /dev/null +++ b/transport/http/jsonrpc/error_test.go @@ -0,0 +1,54 @@ +package jsonrpc + +import "testing" + +func TestError(t *testing.T) { + wantCode := ParseError + sut := Error{ + Code: wantCode, + } + + gotCode := sut.ErrorCode() + if gotCode != wantCode { + t.Fatalf("want=%d, got=%d", gotCode, wantCode) + } + + if sut.Error() == "" { + t.Fatal("Empty error string.") + } + + want := "override" + sut.Message = want + got := sut.Error() + if sut.Error() != want { + t.Fatalf("overridden error message: want=%s, got=%s", want, got) + } + +} +func TestErrorsSatisfyError(t *testing.T) { + errs := []interface{}{ + parseError("parseError"), + invalidRequestError("invalidRequestError"), + methodNotFoundError("methodNotFoundError"), + invalidParamsError("invalidParamsError"), + internalError("internalError"), + } + for _, e := range errs { + err, ok := e.(error) + if !ok { + t.Fatalf("Couldn't assert %s as error.", e) + } + errString := err.Error() + if errString == "" { + t.Fatal("Empty error string") + } + + ec, ok := e.(ErrorCoder) + if !ok { + t.Fatalf("Couldn't assert %s as ErrorCoder.", e) + } + if ErrorMessage(ec.ErrorCode()) == "" { + t.Fatalf("Error type %s returned code of %d, which does not map to error string", e, ec.ErrorCode()) + } + } +} diff --git a/transport/http/jsonrpc/request_response_types.go b/transport/http/jsonrpc/request_response_types.go new file mode 100644 index 000000000..8ea7ddc38 --- /dev/null +++ b/transport/http/jsonrpc/request_response_types.go @@ -0,0 +1,70 @@ +package jsonrpc + +import "encoding/json" + +// Request defines a JSON RPC request from the spec +// http://www.jsonrpc.org/specification#request_object +type Request struct { + JSONRPC string `json:"jsonrpc"` + Method string `json:"method"` + Params json.RawMessage `json:"params"` + ID *RequestID `json:"id"` +} + +// RequestID defines a request ID that can be string, number, or null. +// An identifier established by the Client that MUST contain a String, +// Number, or NULL value if included. +// If it is not included it is assumed to be a notification. +// The value SHOULD normally not be Null and +// Numbers SHOULD NOT contain fractional parts. +type RequestID struct { + intValue int + intError error + floatValue float32 + floatError error + stringValue string + stringError error +} + +// UnmarshalJSON satisfies json.Unmarshaler +func (id *RequestID) UnmarshalJSON(b []byte) error { + id.intError = json.Unmarshal(b, &id.intValue) + id.floatError = json.Unmarshal(b, &id.floatValue) + id.stringError = json.Unmarshal(b, &id.stringValue) + + return nil +} + +// Int returns the ID as an integer value. +// An error is returned if the ID can't be treated as an int. +func (id *RequestID) Int() (int, error) { + return id.intValue, id.intError +} + +// Float32 returns the ID as a float value. +// An error is returned if the ID can't be treated as an float. +func (id *RequestID) Float32() (float32, error) { + return id.floatValue, id.floatError +} + +// String returns the ID as a string value. +// An error is returned if the ID can't be treated as an string. +func (id *RequestID) String() (string, error) { + return id.stringValue, id.stringError +} + +// Response defines a JSON RPC response from the spec +// http://www.jsonrpc.org/specification#response_object +type Response struct { + JSONRPC string `json:"jsonrpc"` + Result json.RawMessage `json:"result,omitempty"` + Error *Error `json:"error,omitemty"` +} + +const ( + // Version defines the version of the JSON RPC implementation + Version string = "2.0" + + // ContentType defines the content type to be served. + ContentType string = "application/json; charset=utf-8" +) diff --git a/transport/http/jsonrpc/request_response_types_test.go b/transport/http/jsonrpc/request_response_types_test.go new file mode 100644 index 000000000..4f4abf3d6 --- /dev/null +++ b/transport/http/jsonrpc/request_response_types_test.go @@ -0,0 +1,111 @@ +package jsonrpc_test + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/go-kit/kit/transport/http/jsonrpc" +) + +func TestCanUnMarshalID(t *testing.T) { + cases := []struct { + JSON string + expType string + expValue interface{} + }{ + {`12345`, "int", 12345}, + {`12345.6`, "float", 12345.6}, + {`"stringaling"`, "string", "stringaling"}, + } + + for _, c := range cases { + r := jsonrpc.Request{} + JSON := fmt.Sprintf(`{"id":%s}`, c.JSON) + + var foo interface{} + _ = json.Unmarshal([]byte(JSON), &foo) + + err := json.Unmarshal([]byte(JSON), &r) + if err != nil { + t.Fatalf("Unexpected error unmarshaling JSON into request: %s\n", err) + } + id := r.ID + + switch c.expType { + case "int": + want := c.expValue.(int) + got, err := id.Int() + if err != nil { + t.Fatal(err) + } + if got != want { + t.Fatalf("'%s' Int(): want %d, got %d.", c.JSON, want, got) + } + + // Allow an int ID to be interpreted as a float. + wantf := float32(c.expValue.(int)) + gotf, err := id.Float32() + if err != nil { + t.Fatal(err) + } + if gotf != wantf { + t.Fatalf("'%s' Int value as Float32(): want %f, got %f.", c.JSON, wantf, gotf) + } + + _, err = id.String() + if err == nil { + t.Fatal("Expected String() to error for int value. Didn't.") + } + case "string": + want := c.expValue.(string) + got, err := id.String() + if err != nil { + t.Fatal(err) + } + if got != want { + t.Fatalf("'%s' String(): want %s, got %s.", c.JSON, want, got) + } + + _, err = id.Int() + if err == nil { + t.Fatal("Expected Int() to error for string value. Didn't.") + } + _, err = id.Float32() + if err == nil { + t.Fatal("Expected Float32() to error for string value. Didn't.") + } + case "float32": + want := c.expValue.(float32) + got, err := id.Float32() + if err != nil { + t.Fatal(err) + } + if got != want { + t.Fatalf("'%s' Float32(): want %f, got %f.", c.JSON, want, got) + } + + _, err = id.String() + if err == nil { + t.Fatal("Expected String() to error for float value. Didn't.") + } + _, err = id.Int() + if err == nil { + t.Fatal("Expected Int() to error for float value. Didn't.") + } + } + } +} + +func TestCanUnmarshalNullID(t *testing.T) { + r := jsonrpc.Request{} + JSON := `{"id":null}` + err := json.Unmarshal([]byte(JSON), &r) + if err != nil { + t.Fatalf("Unexpected error unmarshaling JSON into request: %s\n", err) + } + + if r.ID != nil { + t.Fatalf("Expected ID to be nil, got %+v.\n", r.ID) + } +} diff --git a/transport/http/jsonrpc/server.go b/transport/http/jsonrpc/server.go new file mode 100644 index 000000000..1b49fe7b5 --- /dev/null +++ b/transport/http/jsonrpc/server.go @@ -0,0 +1,206 @@ +package jsonrpc + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/go-kit/kit/log" + httptransport "github.com/go-kit/kit/transport/http" +) + +// Server wraps an endpoint and implements http.Handler. +type Server struct { + ecm EndpointCodecMap + before []httptransport.RequestFunc + after []httptransport.ServerResponseFunc + errorEncoder httptransport.ErrorEncoder + finalizer httptransport.ServerFinalizerFunc + logger log.Logger +} + +// NewServer constructs a new server, which implements http.Server. +func NewServer( + ecm EndpointCodecMap, + options ...ServerOption, +) *Server { + s := &Server{ + ecm: ecm, + errorEncoder: DefaultErrorEncoder, + logger: log.NewNopLogger(), + } + for _, option := range options { + option(s) + } + return s +} + +// ServerOption sets an optional parameter for servers. +type ServerOption func(*Server) + +// ServerBefore functions are executed on the HTTP request object before the +// request is decoded. +func ServerBefore(before ...httptransport.RequestFunc) ServerOption { + return func(s *Server) { s.before = append(s.before, before...) } +} + +// ServerAfter functions are executed on the HTTP response writer after the +// endpoint is invoked, but before anything is written to the client. +func ServerAfter(after ...httptransport.ServerResponseFunc) ServerOption { + return func(s *Server) { s.after = append(s.after, after...) } +} + +// ServerErrorEncoder is used to encode errors to the http.ResponseWriter +// whenever they're encountered in the processing of a request. Clients can +// use this to provide custom error formatting and response codes. By default, +// errors will be written with the DefaultErrorEncoder. +func ServerErrorEncoder(ee httptransport.ErrorEncoder) ServerOption { + return func(s *Server) { s.errorEncoder = ee } +} + +// ServerErrorLogger is used to log non-terminal errors. By default, no errors +// are logged. This is intended as a diagnostic measure. Finer-grained control +// of error handling, including logging in more detail, should be performed in a +// custom ServerErrorEncoder or ServerFinalizer, both of which have access to +// the context. +func ServerErrorLogger(logger log.Logger) ServerOption { + return func(s *Server) { s.logger = logger } +} + +// ServerFinalizer is executed at the end of every HTTP request. +// By default, no finalizer is registered. +func ServerFinalizer(f httptransport.ServerFinalizerFunc) ServerOption { + return func(s *Server) { s.finalizer = f } +} + +// ServeHTTP implements http.Handler. +func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + _, _ = io.WriteString(w, "405 must POST\n") + return + } + ctx := r.Context() + + if s.finalizer != nil { + iw := &interceptingWriter{w, http.StatusOK} + defer func() { s.finalizer(ctx, iw.code, r) }() + w = iw + } + + for _, f := range s.before { + ctx = f(ctx, r) + } + + // Decode the body into an object + var req Request + err := json.NewDecoder(r.Body).Decode(&req) + if err != nil { + rpcerr := parseError("JSON could not be decoded: " + err.Error()) + s.logger.Log("err", rpcerr) + s.errorEncoder(ctx, rpcerr, w) + return + } + + // Get the endpoint and codecs from the map using the method + // defined in the JSON object + ecm, ok := s.ecm[req.Method] + if !ok { + err := methodNotFoundError(fmt.Sprintf("Method %s was not found.", req.Method)) + s.logger.Log("err", err) + s.errorEncoder(ctx, err, w) + return + } + + // Decode the JSON "params" + reqParams, err := ecm.Decode(ctx, req.Params) + if err != nil { + s.logger.Log("err", err) + s.errorEncoder(ctx, err, w) + return + } + + // Call the Endpoint with the params + response, err := ecm.Endpoint(ctx, reqParams) + if err != nil { + s.logger.Log("err", err) + s.errorEncoder(ctx, err, w) + return + } + + for _, f := range s.after { + ctx = f(ctx, w) + } + + res := Response{ + JSONRPC: Version, + } + + // Encode the response from the Endpoint + resParams, err := ecm.Encode(ctx, response) + if err != nil { + s.logger.Log("err", err) + s.errorEncoder(ctx, err, w) + return + } + + res.Result = resParams + + w.Header().Set("Content-Type", ContentType) + _ = json.NewEncoder(w).Encode(res) +} + +// DefaultErrorEncoder writes the error to the ResponseWriter, +// as a json-rpc error response, with an InternalError status code. +// The Error() string of the error will be used as the response error message. +// If the error implements ErrorCoder, the provided code will be set on the +// response error. +// If the error implements Headerer, the given headers will be set. +func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { + w.Header().Set("Content-Type", ContentType) + if headerer, ok := err.(httptransport.Headerer); ok { + for k := range headerer.Headers() { + w.Header().Set(k, headerer.Headers().Get(k)) + } + } + + e := Error{ + Code: InternalError, + Message: err.Error(), + } + if sc, ok := err.(ErrorCoder); ok { + e.Code = sc.ErrorCode() + } + + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(Response{ + JSONRPC: Version, + Error: &e, + }) +} + +// ErrorCoder is checked by DefaultErrorEncoder. If an error value implements +// ErrorCoder, the integer result of ErrorCode() will be used as the JSONRPC +// error code when encoding the error. +// +// By default, InternalError (-32603) is used. +type ErrorCoder interface { + ErrorCode() int +} + +// interceptingWriter intercepts calls to WriteHeader, so that a finalizer +// can be given the correct status code. +type interceptingWriter struct { + http.ResponseWriter + code int +} + +// WriteHeader may not be explicitly called, so care must be taken to +// initialize w.code to its default value of http.StatusOK. +func (w *interceptingWriter) WriteHeader(code int) { + w.code = code + w.ResponseWriter.WriteHeader(code) +} diff --git a/transport/http/jsonrpc/server_test.go b/transport/http/jsonrpc/server_test.go new file mode 100644 index 000000000..d7960fe05 --- /dev/null +++ b/transport/http/jsonrpc/server_test.go @@ -0,0 +1,335 @@ +package jsonrpc_test + +import ( + "context" + "encoding/json" + "errors" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/transport/http/jsonrpc" +) + +func addBody() io.Reader { + return body(`{"jsonrpc": "2.0", "method": "add", "params": [3, 2], "id": 1}`) +} + +func body(in string) io.Reader { + return strings.NewReader(in) +} + +func expectErrorCode(t *testing.T, want int, body []byte) { + var r jsonrpc.Response + err := json.Unmarshal(body, &r) + if err != nil { + t.Fatalf("Cant' decode response. err=%s, body=%s", err, body) + } + if r.Error == nil { + t.Fatalf("Expected error on response. Got none: %s", body) + } + if have := r.Error.Code; want != have { + t.Fatalf("Unexpected error code. Want %d, have %d: %s", want, have, body) + } +} + +func nopDecoder(context.Context, json.RawMessage) (interface{}, error) { return struct{}{}, nil } +func nopEncoder(context.Context, interface{}) (json.RawMessage, error) { return []byte("[]"), nil } + +type mockLogger struct { + Called bool + LastArgs []interface{} +} + +func (l *mockLogger) Log(keyvals ...interface{}) error { + l.Called = true + l.LastArgs = append(l.LastArgs, keyvals) + return nil +} + +func TestServerBadDecode(t *testing.T) { + ecm := jsonrpc.EndpointCodecMap{ + "add": jsonrpc.EndpointCodec{ + Endpoint: endpoint.Nop, + Decode: func(context.Context, json.RawMessage) (interface{}, error) { return struct{}{}, errors.New("oof") }, + Encode: nopEncoder, + }, + } + logger := mockLogger{} + handler := jsonrpc.NewServer(ecm, jsonrpc.ServerErrorLogger(&logger)) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Post(server.URL, "application/json", addBody()) + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := http.StatusOK, resp.StatusCode; want != have { + t.Errorf("want %d, have %d: %s", want, have, buf) + } + expectErrorCode(t, jsonrpc.InternalError, buf) + if !logger.Called { + t.Fatal("Expected logger to be called with error. Wasn't.") + } +} + +func TestServerBadEndpoint(t *testing.T) { + ecm := jsonrpc.EndpointCodecMap{ + "add": jsonrpc.EndpointCodec{ + Endpoint: func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("oof") }, + Decode: nopDecoder, + Encode: nopEncoder, + }, + } + handler := jsonrpc.NewServer(ecm) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Post(server.URL, "application/json", addBody()) + if want, have := http.StatusOK, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + expectErrorCode(t, jsonrpc.InternalError, buf) +} + +func TestServerBadEncode(t *testing.T) { + ecm := jsonrpc.EndpointCodecMap{ + "add": jsonrpc.EndpointCodec{ + Endpoint: endpoint.Nop, + Decode: nopDecoder, + Encode: func(context.Context, interface{}) (json.RawMessage, error) { return []byte{}, errors.New("oof") }, + }, + } + handler := jsonrpc.NewServer(ecm) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Post(server.URL, "application/json", addBody()) + if want, have := http.StatusOK, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + expectErrorCode(t, jsonrpc.InternalError, buf) +} + +func TestServerErrorEncoder(t *testing.T) { + errTeapot := errors.New("teapot") + code := func(err error) int { + if err == errTeapot { + return http.StatusTeapot + } + return http.StatusInternalServerError + } + ecm := jsonrpc.EndpointCodecMap{ + "add": jsonrpc.EndpointCodec{ + Endpoint: func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot }, + Decode: nopDecoder, + Encode: nopEncoder, + }, + } + handler := jsonrpc.NewServer( + ecm, + jsonrpc.ServerErrorEncoder(func(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(code(err)) }), + ) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Post(server.URL, "application/json", addBody()) + if want, have := http.StatusTeapot, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestCanRejectNonPostRequest(t *testing.T) { + ecm := jsonrpc.EndpointCodecMap{} + handler := jsonrpc.NewServer(ecm) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Get(server.URL) + if want, have := http.StatusMethodNotAllowed, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } +} + +func TestCanRejectInvalidJSON(t *testing.T) { + ecm := jsonrpc.EndpointCodecMap{} + handler := jsonrpc.NewServer(ecm) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Post(server.URL, "application/json", body("clearlynotjson")) + if want, have := http.StatusOK, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + expectErrorCode(t, jsonrpc.ParseError, buf) +} + +func TestServerUnregisteredMethod(t *testing.T) { + ecm := jsonrpc.EndpointCodecMap{} + handler := jsonrpc.NewServer(ecm) + server := httptest.NewServer(handler) + defer server.Close() + resp, _ := http.Post(server.URL, "application/json", addBody()) + if want, have := http.StatusOK, resp.StatusCode; want != have { + t.Errorf("want %d, have %d", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + expectErrorCode(t, jsonrpc.MethodNotFoundError, buf) +} + +func TestServerHappyPath(t *testing.T) { + step, response := testServer(t) + step() + resp := <-response + defer resp.Body.Close() // nolint + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := http.StatusOK, resp.StatusCode; want != have { + t.Errorf("want %d, have %d (%s)", want, have, buf) + } + var r jsonrpc.Response + err := json.Unmarshal(buf, &r) + if err != nil { + t.Fatalf("Cant' decode response. err=%s, body=%s", err, buf) + } + if r.JSONRPC != jsonrpc.Version { + t.Fatalf("JSONRPC Version: want=%s, got=%s", jsonrpc.Version, r.JSONRPC) + } + if r.Error != nil { + t.Fatalf("Unxpected error on response: %s", buf) + } +} + +func TestMultipleServerBefore(t *testing.T) { + var done = make(chan struct{}) + ecm := jsonrpc.EndpointCodecMap{ + "add": jsonrpc.EndpointCodec{ + Endpoint: endpoint.Nop, + Decode: nopDecoder, + Encode: nopEncoder, + }, + } + handler := jsonrpc.NewServer( + ecm, + jsonrpc.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { + ctx = context.WithValue(ctx, "one", 1) + + return ctx + }), + jsonrpc.ServerBefore(func(ctx context.Context, r *http.Request) context.Context { + if _, ok := ctx.Value("one").(int); !ok { + t.Error("Value was not set properly when multiple ServerBefores are used") + } + + close(done) + return ctx + }), + ) + server := httptest.NewServer(handler) + defer server.Close() + http.Post(server.URL, "application/json", addBody()) // nolint + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") + } +} + +func TestMultipleServerAfter(t *testing.T) { + var done = make(chan struct{}) + ecm := jsonrpc.EndpointCodecMap{ + "add": jsonrpc.EndpointCodec{ + Endpoint: endpoint.Nop, + Decode: nopDecoder, + Encode: nopEncoder, + }, + } + handler := jsonrpc.NewServer( + ecm, + jsonrpc.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context { + ctx = context.WithValue(ctx, "one", 1) + + return ctx + }), + jsonrpc.ServerAfter(func(ctx context.Context, w http.ResponseWriter) context.Context { + if _, ok := ctx.Value("one").(int); !ok { + t.Error("Value was not set properly when multiple ServerAfters are used") + } + + close(done) + return ctx + }), + ) + server := httptest.NewServer(handler) + defer server.Close() + http.Post(server.URL, "application/json", addBody()) // nolint + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") + } +} + +func TestCanFinalize(t *testing.T) { + var done = make(chan struct{}) + var finalizerCalled bool + ecm := jsonrpc.EndpointCodecMap{ + "add": jsonrpc.EndpointCodec{ + Endpoint: endpoint.Nop, + Decode: nopDecoder, + Encode: nopEncoder, + }, + } + handler := jsonrpc.NewServer( + ecm, + jsonrpc.ServerFinalizer(func(ctx context.Context, code int, req *http.Request) { + finalizerCalled = true + close(done) + }), + ) + server := httptest.NewServer(handler) + defer server.Close() + http.Post(server.URL, "application/json", addBody()) // nolint + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for finalizer") + } + + if !finalizerCalled { + t.Fatal("Finalizer was not called.") + } +} + +func testServer(t *testing.T) (step func(), resp <-chan *http.Response) { + var ( + stepch = make(chan bool) + endpoint = func(ctx context.Context, request interface{}) (response interface{}, err error) { + <-stepch + return struct{}{}, nil + } + response = make(chan *http.Response) + ecm = jsonrpc.EndpointCodecMap{ + "add": jsonrpc.EndpointCodec{ + Endpoint: endpoint, + Decode: nopDecoder, + Encode: nopEncoder, + }, + } + handler = jsonrpc.NewServer(ecm) + ) + go func() { + server := httptest.NewServer(handler) + defer server.Close() + rb := strings.NewReader(`{"jsonrpc": "2.0", "method": "add", "params": [3, 2], "id": 1}`) + resp, err := http.Post(server.URL, "application/json", rb) + if err != nil { + t.Error(err) + return + } + response <- resp + }() + return func() { stepch <- true }, response +}