From 32d8321d893730b530a5122dacd2ffe2d7206eb8 Mon Sep 17 00:00:00 2001 From: Peter Bourgon Date: Wed, 14 Dec 2016 18:19:23 +0100 Subject: [PATCH 1/3] transport/http: enhance the DefaultErrorEncoder - Errors may implement StatusCoder interface - Errors may implement Headerer interface - Errors may implement json.Marshaler interface --- transport/http/server.go | 52 +++++++++++++++++++++++++++++++---- transport/http/server_test.go | 35 +++++++++++++++++++++++ 2 files changed, 82 insertions(+), 5 deletions(-) diff --git a/transport/http/server.go b/transport/http/server.go index 351c19880..2077db856 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -1,6 +1,7 @@ package http import ( + "encoding/json" "net/http" "golang.org/x/net/context" @@ -36,7 +37,7 @@ func NewServer( e: e, dec: dec, enc: enc, - errorEncoder: defaultErrorEncoder, + errorEncoder: DefaultErrorEncoder, logger: log.NewNopLogger(), } for _, option := range options { @@ -63,8 +64,7 @@ func ServerAfter(after ...ServerResponseFunc) ServerOption { // ServerErrorEncoder is used to encode errors to the http.ResponseWriter // whenever they're encountered in the processing of a request. Clients can // use this to provide custom error formatting and response codes. By default, -// errors will be written as plain text with an appropriate, if generic, -// status code. +// errors will be written with the DefaultErrorEncoder. func ServerErrorEncoder(ee ErrorEncoder) ServerOption { return func(s *Server) { s.errorEncoder = ee } } @@ -134,8 +134,50 @@ type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter) // intended use is for request logging. type ServerFinalizerFunc func(ctx context.Context, code int, r *http.Request) -func defaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { - http.Error(w, err.Error(), http.StatusInternalServerError) +// DefaultErrorEncoder writes the error to the ResponseWriter, by default with +// status code 500, content type of text/plain, and the plain text of the error. +// If the error implements StatusCoder, the provided StatusCode will be used +// instead of 500. If the error implements Headerer, the provided headers will +// be applied to the response writer. If the error implements json.Marshaler, +// and the marshaling succeeds, a content type of application/json and the JSON +// encoded form of the error will be used. +func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { + code := http.StatusInternalServerError + if sc, ok := err.(StatusCoder); ok { + code = sc.StatusCode() + } + + contentType, body := "text/plain; charset=utf-8", []byte(err.Error()) + if marshaler, ok := err.(json.Marshaler); ok { + if jsonBody, marshalErr := marshaler.MarshalJSON(); marshalErr == nil { + contentType, body = "application/json; charset=utf-8", jsonBody + } + } + + w.Header().Set("Content-Type", contentType) + + if headerer, ok := err.(Headerer); ok { + for k := range headerer.Headers() { + w.Header().Set(k, headerer.Headers().Get(k)) + } + } + + w.WriteHeader(code) + w.Write(body) +} + +// StatusCoder is checked by DefaultErrorEncoder. If an error value implements +// StatusCoder, the StatusCode will be used when encoding the error. By default, +// StatusInternalServerError (500) is used. +type StatusCoder interface { + StatusCode() int +} + +// Headerer is checked by DefaultErrorEncoder. If an error value implements +// Headerer, the provided headers will be applied to the response writer, after +// the Content-Type is set. +type Headerer interface { + Headers() http.Header } type interceptingWriter struct { diff --git a/transport/http/server_test.go b/transport/http/server_test.go index 992d04733..bb7d41b45 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -122,6 +122,41 @@ func TestServerFinalizer(t *testing.T) { } } +type enhancedError struct{} + +func (e enhancedError) Error() string { return "enhanced error" } +func (e enhancedError) StatusCode() int { return http.StatusTeapot } +func (e enhancedError) MarshalJSON() ([]byte, error) { return []byte(`{"err":"enhanced"}`), nil } +func (e enhancedError) Headers() http.Header { return http.Header{"X-Enhanced": []string{"1"}} } + +func TestServerSpecialError(t *testing.T) { + handler := httptransport.NewServer( + context.Background(), + func(context.Context, interface{}) (interface{}, error) { return nil, enhancedError{} }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(_ context.Context, w http.ResponseWriter, _ interface{}) error { return nil }, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if want, have := http.StatusTeapot, resp.StatusCode; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + if want, have := "1", resp.Header.Get("X-Enhanced"); want != have { + t.Errorf("X-Enhanced: want %q, have %q", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := `{"err":"enhanced"}`, string(buf); want != have { + t.Errorf("Body: want %s, have %s", want, have) + } +} + func testServer(t *testing.T) (cancel, step func(), resp <-chan *http.Response) { var ( ctx, cancelfn = context.WithCancel(context.Background()) From 0abb92d4e1d0264d890e9e53b2ca5ee69c88857a Mon Sep 17 00:00:00 2001 From: Peter Bourgon Date: Wed, 14 Dec 2016 19:04:22 +0100 Subject: [PATCH 2/3] transport/http: provide EncodeJSONResponse With sane defaults. --- transport/http/server.go | 42 ++++++++++++++++++++++++----------- transport/http/server_test.go | 39 ++++++++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 15 deletions(-) diff --git a/transport/http/server.go b/transport/http/server.go index 2077db856..e909e6b07 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -134,34 +134,50 @@ type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter) // intended use is for request logging. type ServerFinalizerFunc func(ctx context.Context, code int, r *http.Request) -// DefaultErrorEncoder writes the error to the ResponseWriter, by default with -// status code 500, content type of text/plain, and the plain text of the error. -// If the error implements StatusCoder, the provided StatusCode will be used -// instead of 500. If the error implements Headerer, the provided headers will -// be applied to the response writer. If the error implements json.Marshaler, -// and the marshaling succeeds, a content type of application/json and the JSON -// encoded form of the error will be used. -func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { - code := http.StatusInternalServerError - if sc, ok := err.(StatusCoder); ok { +// EncodeJSONResponse is a EncodeResponseFunc that serializes the response as a +// JSON object to the ResponseWriter. Many JSON-over-HTTP services can use it as +// a sensible default. If the response implements Headerer, the provided headers +// will be applied to the response. If the response implements StatusCoder, the +// provided StatusCode will be used instead of 200. +func EncodeJSONResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + if headerer, ok := response.(Headerer); ok { + for k := range headerer.Headers() { + w.Header().Set(k, headerer.Headers().Get(k)) + } + } + code := http.StatusOK + if sc, ok := response.(StatusCoder); ok { code = sc.StatusCode() } + w.WriteHeader(code) + return json.NewEncoder(w).Encode(response) +} +// DefaultErrorEncoder writes the error to the ResponseWriter, by default a +// content type of text/plain, a body of the plain text of the error, and a +// status code of 500. If the error implements Headerer, the provided headers +// will be applied to the response. If the error implements json.Marshaler, and +// the marshaling succeeds, a content type of application/json and the JSON +// encoded form of the error will be used. If the error implements StatusCoder, +// the provided StatusCode will be used instead of 500. +func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) { contentType, body := "text/plain; charset=utf-8", []byte(err.Error()) if marshaler, ok := err.(json.Marshaler); ok { if jsonBody, marshalErr := marshaler.MarshalJSON(); marshalErr == nil { contentType, body = "application/json; charset=utf-8", jsonBody } } - w.Header().Set("Content-Type", contentType) - if headerer, ok := err.(Headerer); ok { for k := range headerer.Headers() { w.Header().Set(k, headerer.Headers().Get(k)) } } - + code := http.StatusInternalServerError + if sc, ok := err.(StatusCoder); ok { + code = sc.StatusCode() + } w.WriteHeader(code) w.Write(body) } diff --git a/transport/http/server_test.go b/transport/http/server_test.go index bb7d41b45..8ea025f65 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -5,6 +5,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "strings" "testing" "golang.org/x/net/context" @@ -122,6 +123,40 @@ func TestServerFinalizer(t *testing.T) { } } +type enhancedResponse struct { + Foo string `json:"foo"` +} + +func (e enhancedResponse) StatusCode() int { return http.StatusPaymentRequired } +func (e enhancedResponse) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} } + +func TestEncodeJSONResponse(t *testing.T) { + handler := httptransport.NewServer( + context.Background(), + func(context.Context, interface{}) (interface{}, error) { return enhancedResponse{Foo: "bar"}, nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + httptransport.EncodeJSONResponse, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + if want, have := http.StatusPaymentRequired, resp.StatusCode; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + if want, have := "Snowden", resp.Header.Get("X-Edward"); want != have { + t.Errorf("X-Edward: want %q, have %q", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(buf)); want != have { + t.Errorf("Body: want %s, have %s", want, have) + } +} + type enhancedError struct{} func (e enhancedError) Error() string { return "enhanced error" } @@ -129,7 +164,7 @@ func (e enhancedError) StatusCode() int { return http.StatusTeapot func (e enhancedError) MarshalJSON() ([]byte, error) { return []byte(`{"err":"enhanced"}`), nil } func (e enhancedError) Headers() http.Header { return http.Header{"X-Enhanced": []string{"1"}} } -func TestServerSpecialError(t *testing.T) { +func TestEnhancedError(t *testing.T) { handler := httptransport.NewServer( context.Background(), func(context.Context, interface{}) (interface{}, error) { return nil, enhancedError{} }, @@ -152,7 +187,7 @@ func TestServerSpecialError(t *testing.T) { t.Errorf("X-Enhanced: want %q, have %q", want, have) } buf, _ := ioutil.ReadAll(resp.Body) - if want, have := `{"err":"enhanced"}`, string(buf); want != have { + if want, have := `{"err":"enhanced"}`, strings.TrimSpace(string(buf)); want != have { t.Errorf("Body: want %s, have %s", want, have) } } From 8599541790b633f2c71526be2a106ccc6f075eb4 Mon Sep 17 00:00:00 2001 From: Peter Bourgon Date: Sun, 18 Dec 2016 15:27:36 +0100 Subject: [PATCH 3/3] transport/http: PopulateRequestContext This RequestFunc moves many HTTP request parameters to the context under specific keys. If wired in to the transport, those values are available to subsequent endpoints or service methods that receive the context. This can be used to e.g. log transport (HTTP) details in a service logging middleware. --- transport/http/example_test.go | 38 +++++++++++ transport/http/request_response_funcs.go | 83 ++++++++++++++++++++++-- 2 files changed, 117 insertions(+), 4 deletions(-) create mode 100644 transport/http/example_test.go diff --git a/transport/http/example_test.go b/transport/http/example_test.go new file mode 100644 index 000000000..311b6c4c8 --- /dev/null +++ b/transport/http/example_test.go @@ -0,0 +1,38 @@ +package http + +import ( + "fmt" + "net/http" + "net/http/httptest" + + "golang.org/x/net/context" +) + +func ExamplePopulateRequestContext() { + handler := NewServer( + context.Background(), + func(ctx context.Context, request interface{}) (response interface{}, err error) { + fmt.Println("Method", ctx.Value(ContextKeyRequestMethod).(string)) + fmt.Println("RequestPath", ctx.Value(ContextKeyRequestPath).(string)) + fmt.Println("RequestURI", ctx.Value(ContextKeyRequestURI).(string)) + fmt.Println("X-Request-ID", ctx.Value(ContextKeyRequestXRequestID).(string)) + return struct{}{}, nil + }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, http.ResponseWriter, interface{}) error { return nil }, + ServerBefore(PopulateRequestContext), + ) + + server := httptest.NewServer(handler) + defer server.Close() + + req, _ := http.NewRequest("PATCH", fmt.Sprintf("%s/search?q=sympatico", server.URL), nil) + req.Header.Set("X-Request-Id", "a1b2c3d4e5") + http.DefaultClient.Do(req) + + // Output: + // Method PATCH + // RequestPath /search + // RequestURI /search?q=sympatico + // X-Request-ID a1b2c3d4e5 +} diff --git a/transport/http/request_response_funcs.go b/transport/http/request_response_funcs.go index 1a3ef9b23..88c6e7b66 100644 --- a/transport/http/request_response_funcs.go +++ b/transport/http/request_response_funcs.go @@ -22,13 +22,13 @@ type ServerResponseFunc func(context.Context, http.ResponseWriter) context.Conte // clients, after a request has been made, but prior to it being decoded. type ClientResponseFunc func(context.Context, *http.Response) context.Context -// SetContentType returns a ResponseFunc that sets the Content-Type header to -// the provided value. +// SetContentType returns a ServerResponseFunc that sets the Content-Type header +// to the provided value. func SetContentType(contentType string) ServerResponseFunc { return SetResponseHeader("Content-Type", contentType) } -// SetResponseHeader returns a ResponseFunc that sets the specified header. +// SetResponseHeader returns a ServerResponseFunc that sets the given header. func SetResponseHeader(key, val string) ServerResponseFunc { return func(ctx context.Context, w http.ResponseWriter) context.Context { w.Header().Set(key, val) @@ -36,10 +36,85 @@ func SetResponseHeader(key, val string) ServerResponseFunc { } } -// SetRequestHeader returns a RequestFunc that sets the specified header. +// SetRequestHeader returns a RequestFunc that sets the given header. func SetRequestHeader(key, val string) RequestFunc { return func(ctx context.Context, r *http.Request) context.Context { r.Header.Set(key, val) return ctx } } + +// PopulateRequestContext is a RequestFunc that populates several values into +// the context from the HTTP request. Those values may be extracted using the +// corresponding ContextKey type in this package. +func PopulateRequestContext(ctx context.Context, r *http.Request) context.Context { + for k, v := range map[contextKey]string{ + ContextKeyRequestMethod: r.Method, + ContextKeyRequestURI: r.RequestURI, + ContextKeyRequestPath: r.URL.Path, + ContextKeyRequestProto: r.Proto, + ContextKeyRequestHost: r.Host, + ContextKeyRequestRemoteAddr: r.RemoteAddr, + ContextKeyRequestXForwardedFor: r.Header.Get("X-Forwarded-For"), + ContextKeyRequestXForwardedProto: r.Header.Get("X-Forwarded-Proto"), + ContextKeyRequestAuthorization: r.Header.Get("Authorization"), + ContextKeyRequestReferer: r.Header.Get("Referer"), + ContextKeyRequestUserAgent: r.Header.Get("User-Agent"), + ContextKeyRequestXRequestID: r.Header.Get("X-Request-Id"), + } { + ctx = context.WithValue(ctx, k, v) + } + return ctx +} + +type contextKey int + +const ( + // ContextKeyRequestMethod is populated in the context by + // PopulateRequestContext. Its value is r.Method. + ContextKeyRequestMethod contextKey = iota + + // ContextKeyRequestURI is populated in the context by + // PopulateRequestContext. Its value is r.RequestURI. + ContextKeyRequestURI + + // ContextKeyRequestPath is populated in the context by + // PopulateRequestContext. Its value is r.URL.Path. + ContextKeyRequestPath + + // ContextKeyRequestProto is populated in the context by + // PopulateRequestContext. Its value is r.Proto. + ContextKeyRequestProto + + // ContextKeyRequestHost is populated in the context by + // PopulateRequestContext. Its value is r.Host. + ContextKeyRequestHost + + // ContextKeyRequestRemoteAddr is populated in the context by + // PopulateRequestContext. Its value is r.RemoteAddr. + ContextKeyRequestRemoteAddr + + // ContextKeyRequestXForwardedFor is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-For"). + ContextKeyRequestXForwardedFor + + // ContextKeyRequestXForwardedProto is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-Proto"). + ContextKeyRequestXForwardedProto + + // ContextKeyRequestAuthorization is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("Authorization"). + ContextKeyRequestAuthorization + + // ContextKeyRequestReferer is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("Referer"). + ContextKeyRequestReferer + + // ContextKeyRequestUserAgent is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("User-Agent"). + ContextKeyRequestUserAgent + + // ContextKeyRequestXRequestID is populated in the context by + // PopulateRequestContext. Its value is r.Header.Get("X-Request-Id"). + ContextKeyRequestXRequestID +)