Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,5 @@ sudo: false

language: go
go:
- 1.5.4
- 1.6.2
- 1.7.5
- tip
50 changes: 6 additions & 44 deletions route/route.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,12 @@
package route

import (
"fmt"
"net/http"
"sync"

"github.com/julienschmidt/httprouter"
"golang.org/x/net/context"
)

var (
mtx = sync.RWMutex{}
ctxts = map[*http.Request]context.Context{}
)

// Context returns the context for the request.
func Context(r *http.Request) context.Context {
mtx.RLock()
defer mtx.RUnlock()
return ctxts[r]
}

type param string

// Param returns param p for the context.
Expand All @@ -33,59 +19,35 @@ func WithParam(ctx context.Context, p, v string) context.Context {
return context.WithValue(ctx, param(p), v)
}

// ContextFunc returns a new context for a request.
type ContextFunc func(r *http.Request) (context.Context, error)

// Router wraps httprouter.Router and adds support for prefixed sub-routers
// and per-request context injections.
type Router struct {
rtr *httprouter.Router
prefix string
ctxFn ContextFunc
}

// New returns a new Router.
func New(ctxFn ContextFunc) *Router {
if ctxFn == nil {
ctxFn = func(r *http.Request) (context.Context, error) {
return context.Background(), nil
}
}
func New() *Router {
return &Router{
rtr: httprouter.New(),
ctxFn: ctxFn,
rtr: httprouter.New(),
}
}

// WithPrefix returns a router that prefixes all registered routes with prefix.
func (r *Router) WithPrefix(prefix string) *Router {
return &Router{rtr: r.rtr, prefix: r.prefix + prefix, ctxFn: r.ctxFn}
return &Router{rtr: r.rtr, prefix: r.prefix + prefix}
}

// handle turns a HandlerFunc into an httprouter.Handle.
func (r *Router) handle(h http.HandlerFunc) httprouter.Handle {
return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) {
reqCtx, err := r.ctxFn(req)
if err != nil {
http.Error(w, fmt.Sprintf("Error creating request context: %v", err), http.StatusBadRequest)
return
}
ctx, cancel := context.WithCancel(reqCtx)
ctx, cancel := context.WithCancel(req.Context())
defer cancel()

for _, p := range params {
ctx = context.WithValue(ctx, param(p.Key), p.Value)
}

mtx.Lock()
ctxts[req] = ctx
mtx.Unlock()

h(w, req)

mtx.Lock()
delete(ctxts, req)
mtx.Unlock()
h(w, req.WithContext(ctx))
}
}

Expand Down Expand Up @@ -132,7 +94,7 @@ func FileServe(dir string) http.HandlerFunc {
fs := http.FileServer(http.Dir(dir))

return func(w http.ResponseWriter, r *http.Request) {
r.URL.Path = Param(Context(r), "filepath")
r.URL.Path = Param(r.Context(), "filepath")
fs.ServeHTTP(w, r)
}
}
45 changes: 7 additions & 38 deletions route/route_test.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
package route

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"

"golang.org/x/net/context"
)

func TestRedirect(t *testing.T) {
router := New(nil).WithPrefix("/test/prefix")
router := New().WithPrefix("/test/prefix")
w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "http://localhost:9090/foo", nil)
if err != nil {
Expand All @@ -29,47 +26,19 @@ func TestRedirect(t *testing.T) {
}
}

func TestContextFunc(t *testing.T) {
router := New(func(r *http.Request) (context.Context, error) {
return context.WithValue(context.Background(), "testkey", "testvalue"), nil
})

router.Get("/test", func(w http.ResponseWriter, r *http.Request) {
want := "testvalue"
got := Context(r).Value("testkey")
func TestContext(t *testing.T) {
router := New()
router.Get("/test/:foo/", func(w http.ResponseWriter, r *http.Request) {
want := "bar"
got := Param(r.Context(), "foo")
if want != got {
t.Fatalf("Unexpected context value: want %q, got %q", want, got)
}
})

r, err := http.NewRequest("GET", "http://localhost:9090/test", nil)
r, err := http.NewRequest("GET", "http://localhost:9090/test/bar/", nil)
if err != nil {
t.Fatalf("Error building test request: %s", err)
}
router.ServeHTTP(nil, r)
}

func TestContextFnError(t *testing.T) {
router := New(func(r *http.Request) (context.Context, error) {
return context.Background(), fmt.Errorf("test error")
})

router.Get("/test", func(w http.ResponseWriter, r *http.Request) {})

r, err := http.NewRequest("GET", "http://localhost:9090/test", nil)
if err != nil {
t.Fatalf("Error building test request: %s", err)
}
w := httptest.NewRecorder()
router.ServeHTTP(w, r)

if w.Code != http.StatusBadRequest {
t.Fatalf("Unexpected response status: got %q, want %q", w.Code, http.StatusBadRequest)
}

want := "Error creating request context: test error\n"
got := w.Body.String()
if want != got {
t.Fatalf("Unexpected response body: got %q, want %q", got, want)
}
}