Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

middleware: add Discard method to WrapResponseWriter #926

Merged
merged 5 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions middleware/wrap_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

// NewWrapResponseWriter wraps an http.ResponseWriter, returning a proxy that allows you to
// hook into various parts of the response process.
func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) WrapResponseWriter {
func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) wrapResponseWriter {
_, fl := w.(http.Flusher)

bw := basicWriter{ResponseWriter: w}
Expand Down Expand Up @@ -63,6 +63,14 @@ type WrapResponseWriter interface {
Unwrap() http.ResponseWriter
}

type wrapResponseWriter interface {
WrapResponseWriter

// Discard causes all writes to the original ResponseWriter be discarded,
// instead writing only to the tee'd writer if it's set.
Discard()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is technically a breaking API change -- if there's anyone who implemented this interface externally. But I can't think of any use case why would anyone do that. And it's an easy fix.

So, I'm OK with this change if we bump minor version only. I just wanted to call it out for visibility.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking of this some more, I think we could achieve backward-compatibility by

  1. Creating a new private interface wrapResponseWriter that extends the original interface with the new Discard() method.
    type wrapResponseWriter interface {
         WrapResponseWriter
        
         // Discard causes all writes to the original ResponseWriter be discarded instead writing only to the tee'd writer if it's set.
         Discard()
    }
  2. Returning it from NewWrapResponseWriter() constructor
    - func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) WrapResponseWriter {
    + func NewWrapResponseWriter(w http.ResponseWriter, protoMajor int) wrapResponseWriter {

I don't see a reason why we'd need to return a public interface in this case. I don't expect anyone re-implementing wrapResponseWriter externally?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, I suppose the issue is not people implementing the interface (as they're free to do that) but rather using it to type variables or function arguments. IMO in that case the correct way would rather be to use the http.ResponseWriter or define a custom interface with the methods they need.

Good idea with the private interface, wouldn't it be confusing with the now two similarly named interfaces though?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: We decided to update the original interface, as we don't expect anyone using it directly for anything.

}

// basicWriter wraps a http.ResponseWriter that implements the minimal
// http.ResponseWriter interface.
type basicWriter struct {
Expand All @@ -71,6 +79,7 @@ type basicWriter struct {
code int
bytes int
tee io.Writer
discard bool
}

func (b *basicWriter) WriteHeader(code int) {
Expand All @@ -81,15 +90,21 @@ func (b *basicWriter) WriteHeader(code int) {
}
}

func (b *basicWriter) Write(buf []byte) (int, error) {
func (b *basicWriter) Write(buf []byte) (n int, err error) {
b.maybeWriteHeader()
n, err := b.ResponseWriter.Write(buf)
if b.tee != nil {
_, err2 := b.tee.Write(buf[:n])
// Prefer errors generated by the proxied writer.
if err == nil {
err = err2
if !b.discard {
n, err = b.ResponseWriter.Write(buf)
if b.tee != nil {
_, err2 := b.tee.Write(buf[:n])
// Prefer errors generated by the proxied writer.
if err == nil {
err = err2
}
}
} else if b.tee != nil {
n, err = b.tee.Write(buf)
} else {
n, err = io.Discard.Write(buf)
VojtechVitek marked this conversation as resolved.
Show resolved Hide resolved
}
b.bytes += n
return n, err
Expand Down Expand Up @@ -117,6 +132,10 @@ func (b *basicWriter) Unwrap() http.ResponseWriter {
return b.ResponseWriter
}

func (b *basicWriter) Discard() {
b.discard = true
}

// flushWriter ...
type flushWriter struct {
basicWriter
Expand Down
46 changes: 46 additions & 0 deletions middleware/wrap_writer_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package middleware

import (
"bytes"
"net/http/httptest"
"testing"
)
Expand All @@ -22,3 +23,48 @@ func TestHttp2FancyWriterRemembersWroteHeaderWhenFlushed(t *testing.T) {
t.Fatal("want Flush to have set wroteHeader=true")
}
}

func TestBasicWritesTeesWritesWithoutDiscard(t *testing.T) {
original := httptest.NewRecorder()
wrap := &basicWriter{ResponseWriter: original}

var buf bytes.Buffer
wrap.Tee(&buf)

_, err := wrap.Write([]byte("hello world"))
assertNoError(t, err)

assertEqual(t, []byte("hello world"), original.Body.Bytes())
assertEqual(t, []byte("hello world"), buf.Bytes())
assertEqual(t, 11, wrap.BytesWritten())
}

func TestBasicWriterDiscardsWritesToOriginalResponseWriter(t *testing.T) {
t.Run("With Tee", func(t *testing.T) {
original := httptest.NewRecorder()
wrap := &basicWriter{ResponseWriter: original}

var buf bytes.Buffer
wrap.Tee(&buf)
wrap.Discard()

_, err := wrap.Write([]byte("hello world"))
assertNoError(t, err)

assertEqual(t, 0, original.Body.Len())
assertEqual(t, []byte("hello world"), buf.Bytes())
assertEqual(t, 11, wrap.BytesWritten())
})

t.Run("Without Tee", func(t *testing.T) {
original := httptest.NewRecorder()
wrap := &basicWriter{ResponseWriter: original}
wrap.Discard()

_, err := wrap.Write([]byte("hello world"))
assertNoError(t, err)

assertEqual(t, 0, original.Body.Len())
VojtechVitek marked this conversation as resolved.
Show resolved Hide resolved
assertEqual(t, 11, wrap.BytesWritten())
})
}
Loading