diff --git a/transport/http/example_test.go b/transport/http/example_test.go new file mode 100644 index 000000000..311b6c4c8 --- /dev/null +++ b/transport/http/example_test.go @@ -0,0 +1,38 @@ +package http + +import ( + "fmt" + "net/http" + "net/http/httptest" + + "golang.org/x/net/context" +) + +func ExamplePopulateRequestContext() { + handler := NewServer( + context.Background(), + func(ctx context.Context, request interface{}) (response interface{}, err error) { + fmt.Println("Method", ctx.Value(ContextKeyRequestMethod).(string)) + fmt.Println("RequestPath", ctx.Value(ContextKeyRequestPath).(string)) + fmt.Println("RequestURI", ctx.Value(ContextKeyRequestURI).(string)) + fmt.Println("X-Request-ID", ctx.Value(ContextKeyRequestXRequestID).(string)) + return struct{}{}, nil + }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, + ServerBefore(PopulateRequestContext), + ) + + server := httptest.NewServer(handler) + defer server.Close() + + req, _ := http.NewRequest("PATCH", fmt.Sprintf("%s/search?q=sympatico", server.URL), nil) + req.Header.Set("X-Request-Id", "a1b2c3d4e5") + http.DefaultClient.Do(req) + + // Output: + // Method PATCH + // RequestPath /search + // RequestURI /search?q=sympatico + // X-Request-ID a1b2c3d4e5 +} diff --git a/transport/http/request_response_funcs.go b/transport/http/request_response_funcs.go index 1a3ef9b23..88c6e7b66 100644 --- a/transport/http/request_response_funcs.go +++ b/transport/http/request_response_funcs.go @@ -22,13 +22,13 @@ type ServerResponseFunc func(context.Context, http.ResponseWriter) context.Conte // clients, after a request has been made, but prior to it being decoded. type ClientResponseFunc func(context.Context, *http.Response) context.Context -// SetContentType returns a ResponseFunc that sets the Content-Type header to -// the provided value. +// SetContentType returns a ServerResponseFunc that sets the Content-Type header +// to the provided value. func SetContentType(contentType string) ServerResponseFunc { return SetResponseHeader("Content-Type", contentType) } -// SetResponseHeader returns a ResponseFunc that sets the specified header. +// SetResponseHeader returns a ServerResponseFunc that sets the given header. func SetResponseHeader(key, val string) ServerResponseFunc { return func(ctx context.Context, w http.ResponseWriter) context.Context { w.Header().Set(key, val) @@ -36,10 +36,85 @@ func SetResponseHeader(key, val string) ServerResponseFunc { } } -// SetRequestHeader returns a RequestFunc that sets the specified header. +// SetRequestHeader returns a RequestFunc that sets the given header. func SetRequestHeader(key, val string) RequestFunc { return func(ctx context.Context, r *http.Request) context.Context { r.Header.Set(key, val) return ctx } } + +// PopulateRequestContext is a RequestFunc that populates several values into +// the context from the HTTP request. Those values may be extracted using the +// corresponding ContextKey type in this package. +func PopulateRequestContext(ctx context.Context, r *http.Request) context.Context { + for k, v := range map[contextKey]string{ + ContextKeyRequestMethod: r.Method, + ContextKeyRequestURI: r.RequestURI, + ContextKeyRequestPath: r.URL.Path, + ContextKeyRequestProto: r.Proto, + ContextKeyRequestHost: r.Host, + ContextKeyRequestRemoteAddr: r.RemoteAddr, + ContextKeyRequestXForwardedFor: r.Header.Get("X-Forwarded-For"), + ContextKeyRequestXForwardedProto: r.Header.Get("X-Forwarded-Proto"), + ContextKeyRequestAuthorization: r.Header.Get("Authorization"), + ContextKeyRequestReferer: r.Header.Get("Referer"), + ContextKeyRequestUserAgent: r.Header.Get("User-Agent"), + ContextKeyRequestXRequestID: r.Header.Get("X-Request-Id"), + } { + ctx = context.WithValue(ctx, k, v) + } + return ctx +} + +type contextKey int + +const ( + // ContextKeyRequestMethod is populated in the context by + // PopulateRequestContext. Its value is r.Method. + ContextKeyRequestMethod contextKey = iota + + // ContextKeyRequestURI is populated in the context by + // PopulateRequestContext. Its value is r.RequestURI. + ContextKeyRequestURI + + // ContextKeyRequestPath is populated in the context by + // PopulateRequestContext. Its value is r.URL.Path. + ContextKeyRequestPath + + // ContextKeyRequestProto is populated in the context by + // PopulateRequestContext. Its value is r.Proto. + ContextKeyRequestProto + + // ContextKeyRequestHost is populated in the context by + // PopulateRequestContext. Its value is r.Host. + ContextKeyRequestHost + + // ContextKeyRequestRemoteAddr is populated in the context by + // PopulateRequestContext. Its value is r.RemoteAddr. + ContextKeyRequestRemoteAddr + + // ContextKeyRequestXForwardedFor is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-For"). + ContextKeyRequestXForwardedFor + + // ContextKeyRequestXForwardedProto is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-Proto"). + ContextKeyRequestXForwardedProto + + // ContextKeyRequestAuthorization is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("Authorization"). + ContextKeyRequestAuthorization + + // ContextKeyRequestReferer is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("Referer"). + ContextKeyRequestReferer + + // ContextKeyRequestUserAgent is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("User-Agent"). + ContextKeyRequestUserAgent + + // ContextKeyRequestXRequestID is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("X-Request-Id"). + ContextKeyRequestXRequestID +) diff --git a/transport/http/server.go b/transport/http/server.go index 351c19880..e909e6b07 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -1,6 +1,7 @@ package http import ( + "encoding/json" "net/http" "golang.org/x/net/context" @@ -36,7 +37,7 @@ func NewServer( e: e, dec: dec, enc: enc, - errorEncoder: defaultErrorEncoder, + errorEncoder: DefaultErrorEncoder, logger: log.NewNopLogger(), } for _, option := range options { @@ -63,8 +64,7 @@ func ServerAfter(after ...ServerResponseFunc) ServerOption { // ServerErrorEncoder is used to encode errors to the http.ResponseWriter // whenever they're encountered in the processing of a request. Clients can // use this to provide custom error formatting and response codes. By default, -// errors will be written as plain text with an appropriate, if generic, -// status code. +// errors will be written with the DefaultErrorEncoder. func ServerErrorEncoder(ee ErrorEncoder) ServerOption { return func(s *Server) { s.errorEncoder = ee } } @@ -134,8 +134,66 @@ type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter) // intended use is for request logging. type ServerFinalizerFunc func(ctx context.Context, code int, r *http.Request) -func defaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { - http.Error(w, err.Error(), http.StatusInternalServerError) +// EncodeJSONResponse is a EncodeResponseFunc that serializes the response as a +// JSON object to the ResponseWriter. Many JSON-over-HTTP services can use it as +// a sensible default. If the response implements Headerer, the provided headers +// will be applied to the response. If the response implements StatusCoder, the +// provided StatusCode will be used instead of 200. +func EncodeJSONResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + if headerer, ok := response.(Headerer); ok { + for k := range headerer.Headers() { + w.Header().Set(k, headerer.Headers().Get(k)) + } + } + code := http.StatusOK + if sc, ok := response.(StatusCoder); ok { + code = sc.StatusCode() + } + w.WriteHeader(code) + return json.NewEncoder(w).Encode(response) +} + +// DefaultErrorEncoder writes the error to the ResponseWriter, by default a +// content type of text/plain, a body of the plain text of the error, and a +// status code of 500. If the error implements Headerer, the provided headers +// will be applied to the response. If the error implements json.Marshaler, and +// the marshaling succeeds, a content type of application/json and the JSON +// encoded form of the error will be used. If the error implements StatusCoder, +// the provided StatusCode will be used instead of 500. +func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { + contentType, body := "text/plain; charset=utf-8", []byte(err.Error()) + if marshaler, ok := err.(json.Marshaler); ok { + if jsonBody, marshalErr := marshaler.MarshalJSON(); marshalErr == nil { + contentType, body = "application/json; charset=utf-8", jsonBody + } + } + w.Header().Set("Content-Type", contentType) + if headerer, ok := err.(Headerer); ok { + for k := range headerer.Headers() { + w.Header().Set(k, headerer.Headers().Get(k)) + } + } + code := http.StatusInternalServerError + if sc, ok := err.(StatusCoder); ok { + code = sc.StatusCode() + } + w.WriteHeader(code) + w.Write(body) +} + +// StatusCoder is checked by DefaultErrorEncoder. If an error value implements +// StatusCoder, the StatusCode will be used when encoding the error. By default, +// StatusInternalServerError (500) is used. +type StatusCoder interface { + StatusCode() int +} + +// Headerer is checked by DefaultErrorEncoder. If an error value implements +// Headerer, the provided headers will be applied to the response writer, after +// the Content-Type is set. +type Headerer interface { + Headers() http.Header } type interceptingWriter struct { diff --git a/transport/http/server_test.go b/transport/http/server_test.go index 992d04733..8ea025f65 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "strings" "testing" "golang.org/x/net/context" @@ -122,6 +123,75 @@ func TestServerFinalizer(t *testing.T) { } } +type enhancedResponse struct { + Foo string `json:"foo"` +} + +func (e enhancedResponse) StatusCode() int { return http.StatusPaymentRequired } +func (e enhancedResponse) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} } + +func TestEncodeJSONResponse(t *testing.T) { + handler := httptransport.NewServer( + context.Background(), + func(context.Context, interface{}) (interface{}, error) { return enhancedResponse{Foo: "bar"}, nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + httptransport.EncodeJSONResponse, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + if want, have := http.StatusPaymentRequired, resp.StatusCode; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + if want, have := "Snowden", resp.Header.Get("X-Edward"); want != have { + t.Errorf("X-Edward: want %q, have %q", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(buf)); want != have { + t.Errorf("Body: want %s, have %s", want, have) + } +} + +type enhancedError struct{} + +func (e enhancedError) Error() string { return "enhanced error" } +func (e enhancedError) StatusCode() int { return http.StatusTeapot } +func (e enhancedError) MarshalJSON() ([]byte, error) { return []byte(`{"err":"enhanced"}`), nil } +func (e enhancedError) Headers() http.Header { return http.Header{"X-Enhanced": []string{"1"}} } + +func TestEnhancedError(t *testing.T) { + handler := httptransport.NewServer( + context.Background(), + func(context.Context, interface{}) (interface{}, error) { return nil, enhancedError{} }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(_ context.Context, w http.ResponseWriter, _ interface{}) error { return nil }, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if want, have := http.StatusTeapot, resp.StatusCode; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + if want, have := "1", resp.Header.Get("X-Enhanced"); want != have { + t.Errorf("X-Enhanced: want %q, have %q", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := `{"err":"enhanced"}`, strings.TrimSpace(string(buf)); want != have { + t.Errorf("Body: want %s, have %s", want, have) + } +} + func testServer(t *testing.T) (cancel, step func(), resp <-chan *http.Response) { var ( ctx, cancelfn = context.WithCancel(context.Background())