Skip to content
This repository has been archived by the owner on Apr 20, 2023. It is now read-only.

Commit

Permalink
Merge pull request #52 from gregjones/dont-store-uncacheable-range-re…
Browse files Browse the repository at this point in the history
…sponses

Fix invalid caching of uncacheable range requests.
  • Loading branch information
gregjones committed May 24, 2016
2 parents 37c2ad6 + 4082b08 commit 16db777
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 4 deletions.
8 changes: 4 additions & 4 deletions httpcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,9 @@ func (t *Transport) setModReq(orig, mod *http.Request) {
// will be returned.
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
cacheKey := cacheKey(req)
cacheableMethod := req.Method == "GET" || req.Method == "HEAD"
cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == ""
var cachedResp *http.Response
if cacheableMethod {
if cacheable {
cachedResp, err = CachedResponse(t.Cache, req)
} else {
// Need to invalidate an existing value
Expand All @@ -194,7 +194,7 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
transport = http.DefaultTransport
}

if cachedResp != nil && err == nil && cacheableMethod && req.Header.Get("range") == "" {
if cacheable && cachedResp != nil && err == nil {
if t.MarkCachedResponses {
cachedResp.Header.Set(XFromCache, "1")
}
Expand Down Expand Up @@ -281,7 +281,7 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
}
}

if cacheableMethod && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) {
if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) {
for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") {
varyKey = http.CanonicalHeaderKey(varyKey)
fakeHeader := "X-Varied-" + varyKey
Expand Down
127 changes: 127 additions & 0 deletions httpcache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@ func setup() {
w.Write([]byte(r.Method))
}))

mux.HandleFunc("/range", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
lm := "Fri, 14 Dec 2010 01:01:50 GMT"
if r.Header.Get("if-modified-since") == lm {
w.WriteHeader(http.StatusNotModified)
return
}
w.Header().Set("last-modified", lm)
if r.Header.Get("range") == "bytes=4-9" {
w.WriteHeader(http.StatusPartialContent)
w.Write([]byte(" text "))
return
}
w.Write([]byte("Some text content"))
}))

mux.HandleFunc("/nostore", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-store")
}))
Expand Down Expand Up @@ -184,6 +199,118 @@ func TestCacheableMethod(t *testing.T) {
}
}

func TestDontStorePartialRangeInCache(t *testing.T) {
resetTest()
{
req, err := http.NewRequest("GET", s.server.URL+"/range", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("range", "bytes=4-9")
resp, err := s.client.Do(req)
if err != nil {
t.Fatal(err)
}
var buf bytes.Buffer
_, err = io.Copy(&buf, resp.Body)
if err != nil {
t.Fatal(err)
}
err = resp.Body.Close()
if err != nil {
t.Fatal(err)
}
if got, want := buf.String(), " text "; got != want {
t.Errorf("got %q, want %q", got, want)
}
if resp.StatusCode != http.StatusPartialContent {
t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode)
}
}
{
req, err := http.NewRequest("GET", s.server.URL+"/range", nil)
if err != nil {
t.Fatal(err)
}
resp, err := s.client.Do(req)
if err != nil {
t.Fatal(err)
}
var buf bytes.Buffer
_, err = io.Copy(&buf, resp.Body)
if err != nil {
t.Fatal(err)
}
err = resp.Body.Close()
if err != nil {
t.Fatal(err)
}
if got, want := buf.String(), "Some text content"; got != want {
t.Errorf("got %q, want %q", got, want)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode)
}
if resp.Header.Get(XFromCache) != "" {
t.Error("XFromCache header isn't blank")
}
}
{
req, err := http.NewRequest("GET", s.server.URL+"/range", nil)
if err != nil {
t.Fatal(err)
}
resp, err := s.client.Do(req)
if err != nil {
t.Fatal(err)
}
var buf bytes.Buffer
_, err = io.Copy(&buf, resp.Body)
if err != nil {
t.Fatal(err)
}
err = resp.Body.Close()
if err != nil {
t.Fatal(err)
}
if got, want := buf.String(), "Some text content"; got != want {
t.Errorf("got %q, want %q", got, want)
}
if resp.StatusCode != http.StatusOK {
t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode)
}
if resp.Header.Get(XFromCache) != "1" {
t.Errorf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache))
}
}
{
req, err := http.NewRequest("GET", s.server.URL+"/range", nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("range", "bytes=4-9")
resp, err := s.client.Do(req)
if err != nil {
t.Fatal(err)
}
var buf bytes.Buffer
_, err = io.Copy(&buf, resp.Body)
if err != nil {
t.Fatal(err)
}
err = resp.Body.Close()
if err != nil {
t.Fatal(err)
}
if got, want := buf.String(), " text "; got != want {
t.Errorf("got %q, want %q", got, want)
}
if resp.StatusCode != http.StatusPartialContent {
t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode)
}
}
}

func TestGetOnlyIfCachedHit(t *testing.T) {
resetTest()
{
Expand Down

0 comments on commit 16db777

Please sign in to comment.