diff --git a/transport/twirp/client.go b/transport/twirp/client.go new file mode 100644 index 000000000..7ffb7e7ab --- /dev/null +++ b/transport/twirp/client.go @@ -0,0 +1,143 @@ +package twirp + +import ( + "context" + "fmt" + "github.com/go-kit/kit/endpoint" + "github.com/twitchtv/twirp" + "net/http" + "reflect" +) + +// Client wraps a Twirp client and provides a method that implements endpoint.Endpoint. +type Client struct { + client interface{} + method string + enc EncodeRequestFunc + dec DecodeResponseFunc + before []ClientRequestFunc + after []ClientResponseFunc + finalizer ClientFinalizerFunc +} + +// NewClient constructs a usable Client for a single remote method. +func NewClient( + client interface{}, + method string, + enc EncodeRequestFunc, + dec DecodeResponseFunc, + options ...ClientOption, +) *Client { + c := &Client{ + client: client, + method: method, + enc: enc, + dec: dec, + before: []ClientRequestFunc{}, + after: []ClientResponseFunc{}, + } + for _, option := range options { + option(c) + } + return c +} + +// ClientOption sets an optional parameter for clients. +type ClientOption func(*Client) + +// ClientBefore sets the ClientRequestFunc that are applied to the outgoing +// request before it's invoked. +func ClientBefore(before ...ClientRequestFunc) ClientOption { + return func(c *Client) { c.before = append(c.before, before...) } +} + +// ClientAfter sets the ClientResponseFuncs applied to the incoming +// request prior to it being decoded. This is useful for obtaining anything off +// of the response and adding onto the context prior to decoding. +func ClientAfter(after ...ClientResponseFunc) ClientOption { + return func(c *Client) { c.after = append(c.after, after...) } +} + +// ClientFinalizer is executed at the end of every request. +// By default, no finalizer is registered. +func ClientFinalizer(f ClientFinalizerFunc) ClientOption { + return func(s *Client) { s.finalizer = f } +} + +// Endpoint returns a usable endpoint that invokes the remote endpoint. +func (c Client) Endpoint() endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var ( + req interface{} + err error + ) + + // Process ClientFinalizers + if c.finalizer != nil { + defer func() { + c.finalizer(ctx, err) + }() + } + + // Encode + req, err = c.enc(ctx, request) + if err != nil { + return nil, err + } + + // Create an empty http.Header to hold the headers that we will accumulate in before functions. + var reqHeader http.Header + // Process ClientRequestFunctions + for _, f := range c.before { + ctx = f(ctx, &reqHeader) + } + + // Tell twirp to use these headers in the request. + ctx, err = twirp.WithHTTPRequestHeaders(ctx, reqHeader) + if err != nil { + return nil, err + } + + client := reflect.ValueOf(&c.client) + method := client.MethodByName(c.method) + if !method.IsValid() { + interfaceName := reflect.TypeOf(&c.client).Elem().Name() + return nil, fmt.Errorf("Invalid method specified: %s does not have method %s", interfaceName, c.method) + } + + args := make([]reflect.Value, 2) + args[0] = reflect.ValueOf(ctx) + args[1] = reflect.ValueOf(req) + + retVals := make([]reflect.Value, 2) + retVals = method.Call(args) + resp := retVals[0].Interface() + err = retVals[1].Interface().(error) + if err != nil { + return nil, err + } + + // Process ClientResponseFunctions + for _, f := range c.after { + ctx = f(ctx) + } + + // Decode + response, err := c.dec(ctx, resp) + if err != nil { + return nil, err + } + + return response, nil + } +} + +// ClientFinalizerFunc can be used to perform work at the end of a client +// request, after the response is returned. The principal +// intended use is for error logging. Note: err may be nil. +// There maybe also no additional response parameters depending on when +// an error occurs. +type ClientFinalizerFunc func(ctx context.Context, err error) diff --git a/transport/twirp/doc.go b/transport/twirp/doc.go new file mode 100644 index 000000000..21388dd04 --- /dev/null +++ b/transport/twirp/doc.go @@ -0,0 +1,2 @@ +// Package twirp provides a general purpose Twirp binding for endpoints. +package twirp diff --git a/transport/twirp/encode_decode.go b/transport/twirp/encode_decode.go new file mode 100644 index 000000000..d3d213166 --- /dev/null +++ b/transport/twirp/encode_decode.go @@ -0,0 +1,29 @@ +package twirp + +import ( + "context" +) + +// DecodeRequestFunc extracts a user-domain request object from a Twirp request. +// It's designed to be used in Twirp servers, for server-side endpoints. One +// straightforward DecodeRequestFunc could be something that decodes from the +// Twirp request message to the concrete request type. +type DecodeRequestFunc func(context.Context, interface{}) (request interface{}, err error) + +// EncodeRequestFunc encodes the passed request object into the Twirp request +// object. It's designed to be used in Twirp clients, for client-side endpoints. +// One straightforward EncodeRequestFunc could something that encodes the object +// directly to the Twirp request message. +type EncodeRequestFunc func(context.Context, interface{}) (request interface{}, err error) + +// EncodeResponseFunc encodes the passed response object to the Twirp response +// message. It's designed to be used in Twirp servers, for server-side endpoints. +// One straightforward EncodeResponseFunc could be something that encodes the +// object directly to the Twirp response message. +type EncodeResponseFunc func(context.Context, interface{}) (response interface{}, err error) + +// DecodeResponseFunc extracts a user-domain response object from a Twirp +// response object. It's designed to be used in Twirp clients, for client-side +// endpoints. One straightforward DecodeResponseFunc could be something that +// decodes from the Twirp response message to the concrete response type. +type DecodeResponseFunc func(context.Context, interface{}) (response interface{}, err error) diff --git a/transport/twirp/request_response_funcs.go b/transport/twirp/request_response_funcs.go new file mode 100644 index 000000000..2c2ebe8f9 --- /dev/null +++ b/transport/twirp/request_response_funcs.go @@ -0,0 +1,31 @@ +package twirp + +import ( + "context" + "net/http" +) + +// ClientRequestFunc may modify the context. ClientRequestFuncs are executed +// after creating the request but prior to sending the Twirp request to +// the server. +type ClientRequestFunc func(context.Context, *http.Header) context.Context + +// ServerRequestFunc may take information from the context. ServerRequestFuncs are +// executed prior to invoking the endpoint. +type ServerRequestFunc func(context.Context, http.Header) context.Context + +// ServerResponseFunc may modify the context. ServerResponseFuncs are only executed in +// servers, after invoking the endpoint but prior to writing a response. +type ServerResponseFunc func(context.Context) context.Context + +// ClientResponseFunc may take information from the context. ClientResponseFuncs are only executed in +// clients, after a request has been made, but prior to it being decoded. +type ClientResponseFunc func(context.Context) context.Context + +// SetRequestHeader returns a RequestFunc that sets the given header. It uses the standard net/http/header Add function and will append the specified value if others already exist. +func SetRequestHeader(key, val string) ClientRequestFunc { + return func(ctx context.Context, header *http.Header) context.Context { + header.Add(key, val) + return ctx + } +} diff --git a/transport/twirp/server.go b/transport/twirp/server.go new file mode 100644 index 000000000..1cfadad8c --- /dev/null +++ b/transport/twirp/server.go @@ -0,0 +1,127 @@ +package twirp + +import ( + "context" + + "errors" + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/twitchtv/twirp" + "net/http" +) + +// Handler which should be called from the Twirp binding of the service +// implementation. The incoming request parameter, and returned response +// parameter, are both Twirp types, not user-domain. +type Handler interface { + ServeTwirp(ctx context.Context, request interface{}) (context.Context, interface{}, error) +} + +// Server wraps an endpoint and implements Twirp Handler. +type Server struct { + e endpoint.Endpoint + dec DecodeRequestFunc + enc EncodeResponseFunc + before []ServerRequestFunc + after []ServerResponseFunc + finalizer ServerFinalizerFunc + logger log.Logger +} + +// NewServer constructs a new server, which implements wraps the provided +// endpoint and implements the Handler interface. Consumers should write +// bindings that adapt the concrete Twirp methods from their compiled protobuf +// definitions to individual handlers. Request and response objects are from the +// caller business domain, not Twirp request and reply types. +func NewServer( + e endpoint.Endpoint, + dec DecodeRequestFunc, + enc EncodeResponseFunc, + options ...ServerOption, +) *Server { + s := &Server{ + e: e, + dec: dec, + enc: enc, + logger: log.NewNopLogger(), + } + for _, option := range options { + option(s) + } + return s +} + +// ServerOption sets an optional parameter for servers. +type ServerOption func(*Server) + +// ServerBefore functions are executed on the HTTP request object before the +// request is decoded. +func ServerBefore(before ...ServerRequestFunc) ServerOption { + return func(s *Server) { s.before = append(s.before, before...) } +} + +// ServerAfter functions are executed on the HTTP response writer after the +// endpoint is invoked, but before anything is written to the client. +func ServerAfter(after ...ServerResponseFunc) ServerOption { + return func(s *Server) { s.after = append(s.after, after...) } +} + +// ServerErrorLogger is used to log non-terminal errors. By default, no errors +// are logged. +func ServerErrorLogger(logger log.Logger) ServerOption { + return func(s *Server) { s.logger = logger } +} + +// ServeTwirp implements the Handler interface. +func (s Server) ServeTwirp(ctx context.Context, req interface{}) (context.Context, interface{}, error) { + + // Process ServerFinalizerFunctions + if s.finalizer != nil { + defer func() { + s.finalizer(ctx, req) + }() + } + // Extract the headers from the ctx + var ( + reqHeader http.Header + ok bool + ) + reqHeader, ok = twirp.HTTPRequestHeaders(ctx) + if !ok { + err := errors.New("error extracting http headers from Twirp Context (twirptransport.HTTPRequestHeaders)") + s.logger.Log("err", err) + return ctx, nil, err + } + // Process ServerRequestFunctions + for _, f := range s.before { + ctx = f(ctx, reqHeader) + } + request, err := s.dec(ctx, req) + if err != nil { + s.logger.Log("err", err) + return ctx, nil, err + } + + response, err := s.e(ctx, request) + if err != nil { + s.logger.Log("err", err) + return ctx, nil, err + } + + // Process ServerResponseFunctions + for _, f := range s.after { + ctx = f(ctx) + } + twirpResp, err := s.enc(ctx, response) + if err != nil { + s.logger.Log("err", err) + return ctx, nil, err + } + + return ctx, twirpResp, nil +} + +// ServerFinalizerFunc can be used to perform work at the end of a +// request, after the response has been written to the client. The principal +// intended use is for request logging. +type ServerFinalizerFunc func(ctx context.Context, req interface{})