diff --git a/httpcontext.go b/httpcontext.go index 4751b55..80dddaa 100644 --- a/httpcontext.go +++ b/httpcontext.go @@ -3,6 +3,7 @@ package httpcontext import ( "io" "net/http" + "reflect" ) // Set sets a context value on req. @@ -79,7 +80,7 @@ func (crc *contextReadCloser) Context() map[interface{}]interface{} { } func getContextReadCloser(req *http.Request) ContextReadCloser { - crc, ok := req.Body.(ContextReadCloser) + crc, ok := findContextReadCloser(req.Body) if !ok { crc = &contextReadCloser{ ReadCloser: req.Body, @@ -89,3 +90,51 @@ func getContextReadCloser(req *http.Request) ContextReadCloser { } return crc } + +func findContextReadCloser(rc io.ReadCloser) (ContextReadCloser, bool) { + for { + if a, ok := rc.(ContextReadCloser); ok { + return a, true + } + rc = findNestedReadCloser(rc) + if rc == nil { + return nil, false + } + } +} + +func findNestedReadCloser(rc io.ReadCloser) io.ReadCloser { + if s := findStruct(rc); s != nil { + // try a struct field called ReadCloser first + if maybeRC := (*s).FieldByName("ReadCloser"); (maybeRC != reflect.Value{}) && maybeRC.CanInterface() { + if rc, ok := maybeRC.Interface().(io.ReadCloser); ok { + return rc + } + } + + // try all fields and see if we can find a ReadCloser + for i := 0; i < (*s).Type().NumField(); i++ { + if maybeRC := (*s).Field(i); maybeRC.CanInterface() { + if rc, ok := maybeRC.Interface().(io.ReadCloser); ok { + return rc + } + } + } + } + return nil +} + +func findStruct(rc io.ReadCloser) *reflect.Value { + maybeStruct := reflect.ValueOf(rc) + for { + switch maybeStruct.Kind() { + case reflect.Struct: + return &maybeStruct + case reflect.Ptr: + maybeStruct = reflect.Indirect(maybeStruct) + default: + return nil + } + } +} +