diff --git a/httpcache.go b/httpcache.go index f6a2ec4..ece06e2 100644 --- a/httpcache.go +++ b/httpcache.go @@ -521,7 +521,8 @@ type cachingReadCloser struct { // Underlying ReadCloser. R io.ReadCloser // OnEOF is called with a copy of the content of R when EOF is reached. - OnEOF func(io.Reader) + OnEOF func(io.Reader) + eofOnce sync.Once buf bytes.Buffer // buf stores a copy of the content of R. } @@ -534,12 +535,21 @@ func (r *cachingReadCloser) Read(p []byte) (n int, err error) { n, err = r.R.Read(p) r.buf.Write(p[:n]) if err == io.EOF { - r.OnEOF(bytes.NewReader(r.buf.Bytes())) + r.eofOnce.Do(func() { + r.OnEOF(bytes.NewReader(r.buf.Bytes())) + }) } return n, err } +var dummyBuf = make([]byte, 1) + func (r *cachingReadCloser) Close() error { + r.eofOnce.Do(func() { + if n, err := r.R.Read(dummyBuf); n == 0 && err == io.EOF { + r.OnEOF(bytes.NewReader(r.buf.Bytes())) + } + }) return r.R.Close() } diff --git a/httpcache_test.go b/httpcache_test.go index a504641..069410e 100644 --- a/httpcache_test.go +++ b/httpcache_test.go @@ -2,6 +2,8 @@ package httpcache import ( "bytes" + "compress/gzip" + "encoding/json" "errors" "flag" "io" @@ -10,6 +12,7 @@ import ( "net/http/httptest" "os" "strconv" + "strings" "testing" "time" ) @@ -56,6 +59,30 @@ func setup() { w.Write([]byte(r.Method)) })) + var compressedJsonCounter = 0 + mux.HandleFunc("/compressedJson", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header.Get("accept-encoding"), "gzip") { + w.WriteHeader(http.StatusBadRequest) + return + } + + compressedJsonCounter++ + w.Header().Set("X-Counter", strconv.Itoa(compressedJsonCounter)) + + etag := "124567" + w.Header().Set("etag", etag) + if r.Header.Get("if-none-match") == etag { + w.WriteHeader(http.StatusNotModified) + return + } + + w.Header().Set("Content-Encoding", "gzip") + + gzw := gzip.NewWriter(w) + defer gzw.Close() + gzw.Write([]byte(`{"some": "json"}`)) + })) + mux.HandleFunc("/range", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { lm := "Fri, 14 Dec 2010 01:01:50 GMT" if r.Header.Get("if-modified-since") == lm { @@ -365,10 +392,97 @@ func TestDontStorePartialRangeInCache(t *testing.T) { } } +func TestRevalidateCompressedJSONResponses(t *testing.T) { + type some struct{ Some string } + readJsonResponse := func(rc io.ReadCloser) (some, error) { + defer rc.Close() + var got some + err := json.NewDecoder(rc).Decode(&got) + return got, err + } + { + req, err := http.NewRequest("GET", s.server.URL+"/compressedJson", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + got, err := readJsonResponse(resp.Body) + if err != nil { + t.Fatal(err) + } + want := some{"json"} + if got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if resp.Header.Get(XFromCache) != "" { + t.Error("XFromCache header isn't blank") + } + if resp.Header.Get("x-counter") != "1" { + t.Error("X-Counter header is not 1") + } + if resp.Header.Get("etag") == "" { + t.Error("ETag is blank") + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/compressedJson", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + _, err = readJsonResponse(resp.Body) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if resp.Header.Get(XFromCache) != "1" { + t.Errorf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + if resp.Header.Get("x-counter") != "2" { + t.Error("X-Counter header is not 2") + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/compressedJson", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("cache-control", "only-if-cached") + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + _, err = readJsonResponse(resp.Body) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if resp.Header.Get(XFromCache) != "1" { + t.Errorf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + if resp.Header.Get("x-counter") != "2" { + t.Error("X-Counter was not updated on revalidation") + } + } +} + func TestCacheOnlyIfBodyRead(t *testing.T) { resetTest() { - req, err := http.NewRequest("GET", s.server.URL, nil) + req, err := http.NewRequest("GET", s.server.URL+"/method", nil) if err != nil { t.Fatal(err) } @@ -383,7 +497,7 @@ func TestCacheOnlyIfBodyRead(t *testing.T) { resp.Body.Close() } { - req, err := http.NewRequest("GET", s.server.URL, nil) + req, err := http.NewRequest("GET", s.server.URL+"/method", nil) if err != nil { t.Fatal(err) }