Skip to content

Commit

Permalink
Fix error on writeless handle (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
Maksim authored Apr 3, 2022
1 parent a3e798f commit de47dc9
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
13 changes: 13 additions & 0 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ func HandlerWithKey(cacheSize int, ttl time.Duration, keyFunc ...func(r *http.Re
headers: ww.Header(),
status: ww.Status(),
body: buf.Bytes(),

// the handler may not write header and body in some logic,
// while writing only the body, an attempt is made to write the default header (http.StatusOK)
skip: ww.IsHeaderWrong(),
}
return val, nil
})
Expand All @@ -121,6 +125,10 @@ func HandlerWithKey(cacheSize int, ttl time.Duration, keyFunc ...func(r *http.Re
panic("stampede: handler received unexpected response value type")
}

if resp.skip {
return
}

header := w.Header()

nextHeader:
Expand Down Expand Up @@ -149,6 +157,7 @@ type responseValue struct {
headers http.Header
status int
body []byte
skip bool
}

type responseWriter struct {
Expand All @@ -167,6 +176,10 @@ func (b *responseWriter) WriteHeader(code int) {
}
}

func (b *responseWriter) IsHeaderWrong() bool {
return !b.wroteHeader && (b.code < 100 || b.code > 999)
}

func (b *responseWriter) Write(buf []byte) (int, error) {
b.maybeWriteHeader()
n, err := b.ResponseWriter.Write(buf)
Expand Down
36 changes: 36 additions & 0 deletions stampede_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,39 @@ func TestIssue6_BypassCORSHeaders(t *testing.T) {
// expect to have only one actual hit
assert.Equal(t, uint64(1), count)
}

func TestPanic(t *testing.T) {
mux := http.NewServeMux()
middleware := stampede.Handler(100, 1*time.Hour)
mux.Handle("/", middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
t.Log(r.Method, r.URL)
})))

ts := httptest.NewServer(mux)
defer ts.Close()

{
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
if err != nil {
t.Fatal(err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
t.Log(resp.StatusCode)
}
{
req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
if err != nil {
t.Fatal(err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
t.Log(resp.StatusCode)
}
}

0 comments on commit de47dc9

Please sign in to comment.