Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion httpcontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package httpcontext
import (
"io"
"net/http"
"reflect"
)

// Set sets a context value on req.
Expand Down Expand Up @@ -79,7 +80,7 @@ func (crc *contextReadCloser) Context() map[interface{}]interface{} {
}

func getContextReadCloser(req *http.Request) ContextReadCloser {
crc, ok := req.Body.(ContextReadCloser)
crc, ok := findContextReadCloser(req.Body)
if !ok {
crc = &contextReadCloser{
ReadCloser: req.Body,
Expand All @@ -89,3 +90,51 @@ func getContextReadCloser(req *http.Request) ContextReadCloser {
}
return crc
}

func findContextReadCloser(rc io.ReadCloser) (ContextReadCloser, bool) {
for {
if a, ok := rc.(ContextReadCloser); ok {
return a, true
}
rc = findNestedReadCloser(rc)
if rc == nil {
return nil, false
}
}
}

func findNestedReadCloser(rc io.ReadCloser) io.ReadCloser {
if s := findStruct(rc); s != nil {
// try a struct field called ReadCloser first
if maybeRC := (*s).FieldByName("ReadCloser"); (maybeRC != reflect.Value{}) && maybeRC.CanInterface() {
if rc, ok := maybeRC.Interface().(io.ReadCloser); ok {
return rc
}
}

// try all fields and see if we can find a ReadCloser
for i := 0; i < (*s).Type().NumField(); i++ {
if maybeRC := (*s).Field(i); maybeRC.CanInterface() {
if rc, ok := maybeRC.Interface().(io.ReadCloser); ok {
return rc
}
}
}
}
return nil
}

func findStruct(rc io.ReadCloser) *reflect.Value {
maybeStruct := reflect.ValueOf(rc)
for {
switch maybeStruct.Kind() {
case reflect.Struct:
return &maybeStruct
case reflect.Ptr:
maybeStruct = reflect.Indirect(maybeStruct)
default:
return nil
}
}
}