Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions transport/http/example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package http

import (
"fmt"
"net/http"
"net/http/httptest"

"golang.org/x/net/context"
)

func ExamplePopulateRequestContext() {
handler := NewServer(
context.Background(),
func(ctx context.Context, request interface{}) (response interface{}, err error) {
fmt.Println("Method", ctx.Value(ContextKeyRequestMethod).(string))
fmt.Println("RequestPath", ctx.Value(ContextKeyRequestPath).(string))
fmt.Println("RequestURI", ctx.Value(ContextKeyRequestURI).(string))
fmt.Println("X-Request-ID", ctx.Value(ContextKeyRequestXRequestID).(string))
return struct{}{}, nil
},
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
func(context.Context, http.ResponseWriter, interface{}) error { return nil },
ServerBefore(PopulateRequestContext),
)

server := httptest.NewServer(handler)
defer server.Close()

req, _ := http.NewRequest("PATCH", fmt.Sprintf("%s/search?q=sympatico", server.URL), nil)
req.Header.Set("X-Request-Id", "a1b2c3d4e5")
http.DefaultClient.Do(req)

// Output:
// Method PATCH
// RequestPath /search
// RequestURI /search?q=sympatico
// X-Request-ID a1b2c3d4e5
}
83 changes: 79 additions & 4 deletions transport/http/request_response_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,99 @@ type ServerResponseFunc func(context.Context, http.ResponseWriter) context.Conte
// clients, after a request has been made, but prior to it being decoded.
type ClientResponseFunc func(context.Context, *http.Response) context.Context

// SetContentType returns a ResponseFunc that sets the Content-Type header to
// the provided value.
// SetContentType returns a ServerResponseFunc that sets the Content-Type header
// to the provided value.
func SetContentType(contentType string) ServerResponseFunc {
return SetResponseHeader("Content-Type", contentType)
}

// SetResponseHeader returns a ResponseFunc that sets the specified header.
// SetResponseHeader returns a ServerResponseFunc that sets the given header.
func SetResponseHeader(key, val string) ServerResponseFunc {
return func(ctx context.Context, w http.ResponseWriter) context.Context {
w.Header().Set(key, val)
return ctx
}
}

// SetRequestHeader returns a RequestFunc that sets the specified header.
// SetRequestHeader returns a RequestFunc that sets the given header.
func SetRequestHeader(key, val string) RequestFunc {
return func(ctx context.Context, r *http.Request) context.Context {
r.Header.Set(key, val)
return ctx
}
}

// PopulateRequestContext is a RequestFunc that populates several values into
// the context from the HTTP request. Those values may be extracted using the
// corresponding ContextKey type in this package.
func PopulateRequestContext(ctx context.Context, r *http.Request) context.Context {
for k, v := range map[contextKey]string{
ContextKeyRequestMethod: r.Method,
ContextKeyRequestURI: r.RequestURI,
ContextKeyRequestPath: r.URL.Path,
ContextKeyRequestProto: r.Proto,
ContextKeyRequestHost: r.Host,
ContextKeyRequestRemoteAddr: r.RemoteAddr,
ContextKeyRequestXForwardedFor: r.Header.Get("X-Forwarded-For"),
ContextKeyRequestXForwardedProto: r.Header.Get("X-Forwarded-Proto"),
ContextKeyRequestAuthorization: r.Header.Get("Authorization"),
ContextKeyRequestReferer: r.Header.Get("Referer"),
ContextKeyRequestUserAgent: r.Header.Get("User-Agent"),
ContextKeyRequestXRequestID: r.Header.Get("X-Request-Id"),
} {
ctx = context.WithValue(ctx, k, v)
}
return ctx
}

type contextKey int

const (
// ContextKeyRequestMethod is populated in the context by
// PopulateRequestContext. Its value is r.Method.
ContextKeyRequestMethod contextKey = iota

// ContextKeyRequestURI is populated in the context by
// PopulateRequestContext. Its value is r.RequestURI.
ContextKeyRequestURI

// ContextKeyRequestPath is populated in the context by
// PopulateRequestContext. Its value is r.URL.Path.
ContextKeyRequestPath

// ContextKeyRequestProto is populated in the context by
// PopulateRequestContext. Its value is r.Proto.
ContextKeyRequestProto

// ContextKeyRequestHost is populated in the context by
// PopulateRequestContext. Its value is r.Host.
ContextKeyRequestHost

// ContextKeyRequestRemoteAddr is populated in the context by
// PopulateRequestContext. Its value is r.RemoteAddr.
ContextKeyRequestRemoteAddr

// ContextKeyRequestXForwardedFor is populated in the context by
// PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-For").
ContextKeyRequestXForwardedFor

// ContextKeyRequestXForwardedProto is populated in the context by
// PopulateRequestContext. Its value is r.Header.Get("X-Forwarded-Proto").
ContextKeyRequestXForwardedProto

// ContextKeyRequestAuthorization is populated in the context by
// PopulateRequestContext. Its value is r.Header.Get("Authorization").
ContextKeyRequestAuthorization

// ContextKeyRequestReferer is populated in the context by
// PopulateRequestContext. Its value is r.Header.Get("Referer").
ContextKeyRequestReferer

// ContextKeyRequestUserAgent is populated in the context by
// PopulateRequestContext. Its value is r.Header.Get("User-Agent").
ContextKeyRequestUserAgent

// ContextKeyRequestXRequestID is populated in the context by
// PopulateRequestContext. Its value is r.Header.Get("X-Request-Id").
ContextKeyRequestXRequestID
)
68 changes: 63 additions & 5 deletions transport/http/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http

import (
"encoding/json"
"net/http"

"golang.org/x/net/context"
Expand Down Expand Up @@ -36,7 +37,7 @@ func NewServer(
e: e,
dec: dec,
enc: enc,
errorEncoder: defaultErrorEncoder,
errorEncoder: DefaultErrorEncoder,
logger: log.NewNopLogger(),
}
for _, option := range options {
Expand All @@ -63,8 +64,7 @@ func ServerAfter(after ...ServerResponseFunc) ServerOption {
// ServerErrorEncoder is used to encode errors to the http.ResponseWriter
// whenever they're encountered in the processing of a request. Clients can
// use this to provide custom error formatting and response codes. By default,
// errors will be written as plain text with an appropriate, if generic,
// status code.
// errors will be written with the DefaultErrorEncoder.
func ServerErrorEncoder(ee ErrorEncoder) ServerOption {
return func(s *Server) { s.errorEncoder = ee }
}
Expand Down Expand Up @@ -134,8 +134,66 @@ type ErrorEncoder func(ctx context.Context, err error, w http.ResponseWriter)
// intended use is for request logging.
type ServerFinalizerFunc func(ctx context.Context, code int, r *http.Request)

func defaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) {
http.Error(w, err.Error(), http.StatusInternalServerError)
// EncodeJSONResponse is a EncodeResponseFunc that serializes the response as a
// JSON object to the ResponseWriter. Many JSON-over-HTTP services can use it as
// a sensible default. If the response implements Headerer, the provided headers
// will be applied to the response. If the response implements StatusCoder, the
// provided StatusCode will be used instead of 200.
func EncodeJSONResponse(_ context.Context, w http.ResponseWriter, response interface{}) error {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
if headerer, ok := response.(Headerer); ok {
for k := range headerer.Headers() {
w.Header().Set(k, headerer.Headers().Get(k))
}
}
code := http.StatusOK
if sc, ok := response.(StatusCoder); ok {
code = sc.StatusCode()
}
w.WriteHeader(code)
return json.NewEncoder(w).Encode(response)
}

// DefaultErrorEncoder writes the error to the ResponseWriter, by default a
// content type of text/plain, a body of the plain text of the error, and a
// status code of 500. If the error implements Headerer, the provided headers
// will be applied to the response. If the error implements json.Marshaler, and
// the marshaling succeeds, a content type of application/json and the JSON
// encoded form of the error will be used. If the error implements StatusCoder,
// the provided StatusCode will be used instead of 500.
func DefaultErrorEncoder(_ context.Context, err error, w http.ResponseWriter) {
contentType, body := "text/plain; charset=utf-8", []byte(err.Error())
if marshaler, ok := err.(json.Marshaler); ok {
if jsonBody, marshalErr := marshaler.MarshalJSON(); marshalErr == nil {
contentType, body = "application/json; charset=utf-8", jsonBody
}
}
w.Header().Set("Content-Type", contentType)
if headerer, ok := err.(Headerer); ok {
for k := range headerer.Headers() {
w.Header().Set(k, headerer.Headers().Get(k))
}
}
code := http.StatusInternalServerError
if sc, ok := err.(StatusCoder); ok {
code = sc.StatusCode()
}
w.WriteHeader(code)
w.Write(body)
}

// StatusCoder is checked by DefaultErrorEncoder. If an error value implements
// StatusCoder, the StatusCode will be used when encoding the error. By default,
// StatusInternalServerError (500) is used.
type StatusCoder interface {
StatusCode() int
}

// Headerer is checked by DefaultErrorEncoder. If an error value implements
// Headerer, the provided headers will be applied to the response writer, after
// the Content-Type is set.
type Headerer interface {
Headers() http.Header
}

type interceptingWriter struct {
Expand Down
70 changes: 70 additions & 0 deletions transport/http/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"

"golang.org/x/net/context"
Expand Down Expand Up @@ -122,6 +123,75 @@ func TestServerFinalizer(t *testing.T) {
}
}

type enhancedResponse struct {
Foo string `json:"foo"`
}

func (e enhancedResponse) StatusCode() int { return http.StatusPaymentRequired }
func (e enhancedResponse) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} }

func TestEncodeJSONResponse(t *testing.T) {
handler := httptransport.NewServer(
context.Background(),
func(context.Context, interface{}) (interface{}, error) { return enhancedResponse{Foo: "bar"}, nil },
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
httptransport.EncodeJSONResponse,
)

server := httptest.NewServer(handler)
defer server.Close()

resp, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
}
if want, have := http.StatusPaymentRequired, resp.StatusCode; want != have {
t.Errorf("StatusCode: want %d, have %d", want, have)
}
if want, have := "Snowden", resp.Header.Get("X-Edward"); want != have {
t.Errorf("X-Edward: want %q, have %q", want, have)
}
buf, _ := ioutil.ReadAll(resp.Body)
if want, have := `{"foo":"bar"}`, strings.TrimSpace(string(buf)); want != have {
t.Errorf("Body: want %s, have %s", want, have)
}
}

type enhancedError struct{}

func (e enhancedError) Error() string { return "enhanced error" }
func (e enhancedError) StatusCode() int { return http.StatusTeapot }
func (e enhancedError) MarshalJSON() ([]byte, error) { return []byte(`{"err":"enhanced"}`), nil }
func (e enhancedError) Headers() http.Header { return http.Header{"X-Enhanced": []string{"1"}} }

func TestEnhancedError(t *testing.T) {
handler := httptransport.NewServer(
context.Background(),
func(context.Context, interface{}) (interface{}, error) { return nil, enhancedError{} },
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
func(_ context.Context, w http.ResponseWriter, _ interface{}) error { return nil },
)

server := httptest.NewServer(handler)
defer server.Close()

resp, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if want, have := http.StatusTeapot, resp.StatusCode; want != have {
t.Errorf("StatusCode: want %d, have %d", want, have)
}
if want, have := "1", resp.Header.Get("X-Enhanced"); want != have {
t.Errorf("X-Enhanced: want %q, have %q", want, have)
}
buf, _ := ioutil.ReadAll(resp.Body)
if want, have := `{"err":"enhanced"}`, strings.TrimSpace(string(buf)); want != have {
t.Errorf("Body: want %s, have %s", want, have)
}
}

func testServer(t *testing.T) (cancel, step func(), resp <-chan *http.Response) {
var (
ctx, cancelfn = context.WithCancel(context.Background())
Expand Down