Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

caddyhttp: Use sendfile(2) by implementing ReadFrom on wrappers #5022

Merged
merged 2 commits into from
Sep 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions modules/caddyhttp/responsewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ func (rww *ResponseWriterWrapper) Push(target string, opts *http.PushOptions) er
return ErrNotImplemented
}

// ReadFrom implements io.ReaderFrom. It simply calls the underlying
// ResponseWriter's ReadFrom method if there is one, otherwise it defaults
// to io.Copy.
func (rww *ResponseWriterWrapper) ReadFrom(r io.Reader) (n int64, err error) {
if rf, ok := rww.ResponseWriter.(io.ReaderFrom); ok {
return rf.ReadFrom(r)
}
return io.Copy(rww.ResponseWriter, r)
}

// HTTPInterfaces mix all the interfaces that middleware ResponseWriters need to support.
type HTTPInterfaces interface {
http.ResponseWriter
Expand Down Expand Up @@ -188,9 +198,26 @@ func (rr *responseRecorder) Write(data []byte) (int, error) {
} else {
n, err = rr.buf.Write(data)
}
if err == nil {
rr.size += n

rr.size += n
return n, err
}

func (rr *responseRecorder) ReadFrom(r io.Reader) (int64, error) {
rr.WriteHeader(http.StatusOK)
var n int64
var err error
if rr.stream {
if rf, ok := rr.ResponseWriter.(io.ReaderFrom); ok {
n, err = rf.ReadFrom(r)
} else {
n, err = io.Copy(rr.ResponseWriter, r)
}
} else {
n, err = rr.buf.ReadFrom(r)
}

rr.size += int(n)
return n, err
}

Expand Down Expand Up @@ -251,4 +278,10 @@ type ShouldBufferFunc func(status int, header http.Header) bool
var (
_ HTTPInterfaces = (*ResponseWriterWrapper)(nil)
_ ResponseRecorder = (*responseRecorder)(nil)

// Implementing ReaderFrom can be such a significant
// optimization that it should probably be required!
// see PR #5022 (25%-50% speedup)
_ io.ReaderFrom = (*ResponseWriterWrapper)(nil)
_ io.ReaderFrom = (*responseRecorder)(nil)
)
165 changes: 165 additions & 0 deletions modules/caddyhttp/responsewriter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package caddyhttp

import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"testing"
)

type responseWriterSpy interface {
http.ResponseWriter
Written() string
CalledReadFrom() bool
}

var (
_ responseWriterSpy = (*baseRespWriter)(nil)
_ responseWriterSpy = (*readFromRespWriter)(nil)
)

// a barebones http.ResponseWriter mock
type baseRespWriter []byte

func (brw *baseRespWriter) Write(d []byte) (int, error) {
*brw = append(*brw, d...)
return len(d), nil
}
func (brw *baseRespWriter) Header() http.Header { return nil }
func (brw *baseRespWriter) WriteHeader(statusCode int) {}
func (brw *baseRespWriter) Written() string { return string(*brw) }
func (brw *baseRespWriter) CalledReadFrom() bool { return false }

// an http.ResponseWriter mock that supports ReadFrom
type readFromRespWriter struct {
baseRespWriter
called bool
}

func (rf *readFromRespWriter) ReadFrom(r io.Reader) (int64, error) {
rf.called = true
return io.Copy(&rf.baseRespWriter, r)
}

func (rf *readFromRespWriter) CalledReadFrom() bool { return rf.called }

func TestResponseWriterWrapperReadFrom(t *testing.T) {
tests := map[string]struct {
responseWriter responseWriterSpy
wantReadFrom bool
}{
"no ReadFrom": {
responseWriter: &baseRespWriter{},
wantReadFrom: false,
},
"has ReadFrom": {
responseWriter: &readFromRespWriter{},
wantReadFrom: true,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
// what we expect middlewares to do:
type myWrapper struct {
*ResponseWriterWrapper
}

wrapped := myWrapper{
ResponseWriterWrapper: &ResponseWriterWrapper{ResponseWriter: tt.responseWriter},
}

const srcData = "boo!"
// hides everything but Read, since strings.Reader implements WriteTo it would
// take precedence over our ReadFrom.
src := struct{ io.Reader }{strings.NewReader(srcData)}

fmt.Println(name)
if _, err := io.Copy(wrapped, src); err != nil {
t.Errorf("Copy() err = %v", err)
}

if got := tt.responseWriter.Written(); got != srcData {
t.Errorf("data = %q, want %q", got, srcData)
}

if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom {
if tt.wantReadFrom {
t.Errorf("ReadFrom() should have been called")
} else {
t.Errorf("ReadFrom() should not have been called")
}
}
})
}
}

func TestResponseRecorderReadFrom(t *testing.T) {
tests := map[string]struct {
responseWriter responseWriterSpy
shouldBuffer bool
wantReadFrom bool
}{
"buffered plain": {
responseWriter: &baseRespWriter{},
shouldBuffer: true,
wantReadFrom: false,
},
"streamed plain": {
responseWriter: &baseRespWriter{},
shouldBuffer: false,
wantReadFrom: false,
},
"buffered ReadFrom": {
responseWriter: &readFromRespWriter{},
shouldBuffer: true,
wantReadFrom: false,
},
"streamed ReadFrom": {
responseWriter: &readFromRespWriter{},
shouldBuffer: false,
wantReadFrom: true,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
var buf bytes.Buffer

rr := NewResponseRecorder(tt.responseWriter, &buf, func(status int, header http.Header) bool {
return tt.shouldBuffer
})

const srcData = "boo!"
// hides everything but Read, since strings.Reader implements WriteTo it would
// take precedence over our ReadFrom.
src := struct{ io.Reader }{strings.NewReader(srcData)}

if _, err := io.Copy(rr, src); err != nil {
t.Errorf("Copy() err = %v", err)
}

wantStreamed := srcData
wantBuffered := ""
if tt.shouldBuffer {
wantStreamed = ""
wantBuffered = srcData
}

if got := tt.responseWriter.Written(); got != wantStreamed {
t.Errorf("streamed data = %q, want %q", got, wantStreamed)
}
if got := buf.String(); got != wantBuffered {
t.Errorf("buffered data = %q, want %q", got, wantBuffered)
}

if tt.responseWriter.CalledReadFrom() != tt.wantReadFrom {
if tt.wantReadFrom {
t.Errorf("ReadFrom() should have been called")
} else {
t.Errorf("ReadFrom() should not have been called")
}
}
})
}
}