Skip to content

Commit

Permalink
net/http: reverseproxy: forward 1xx responses
Browse files Browse the repository at this point in the history
Support for 1xx responses has recently been merged in
net/http (#42597).

As discussed in this CL
(https://go-review.googlesource.com/c/go/+/269997/comments/1ff70bef_c25a829a),
support for forwarding 1xx responses in ReverseProxy has been extracted
in this separate patch.

According to RFC 7231, "a proxy MUST forward 1xx responses unless the
proxy itself requested the generation of the 1xx response".
Consequently, all received 1xx responses are automatically forwarded as long as the
underlying transport supports ClientTrace.Got1xxResponse.

Fixes #26088
Fixes #51914
  • Loading branch information
dunglas committed May 31, 2022
1 parent cfd202c commit aa23135
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 1 deletion.
30 changes: 29 additions & 1 deletion src/net/http/httputil/reverseproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"mime"
"net"
"net/http"
"net/http/httptrace"
"net/http/internal/ascii"
"net/textproto"
"net/url"
Expand All @@ -40,6 +41,9 @@ import (
// To prevent IP spoofing, be sure to delete any pre-existing
// X-Forwarded-For header coming from the client or
// an untrusted proxy.
//
// 1xx responses are forwarded to the client if the underlying
// transport supports ClientTrace.Got1xxResponse.
type ReverseProxy struct {
// Director must be a function which modifies
// the request into a new request to be sent
Expand Down Expand Up @@ -307,6 +311,23 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}

var headerSet bool
trace := &httptrace.ClientTrace{
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
h := rw.Header()
copyHeader(h, http.Header(header))
rw.WriteHeader(code)

// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
for k, _ := range h {
h.Del(k)
}

return nil
},
}
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))

res, err := transport.RoundTrip(outreq)
if err != nil {
p.getErrorHandler()(rw, outreq, err)
Expand All @@ -332,7 +353,14 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}

copyHeader(rw.Header(), res.Header)
h := rw.Header()
if headerSet {
for k, _ := range h {
h.Del(k)
}
}

copyHeader(h, res.Header)

// The "Trailer" header isn't included in the Transport's response,
// at least for *http.Transport. Build it up from Trailer.
Expand Down
76 changes: 76 additions & 0 deletions src/net/http/httputil/reverseproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ import (
"log"
"net/http"
"net/http/httptest"
"net/http/httptrace"
"net/http/internal/ascii"
"net/textproto"
"net/url"
"os"
"reflect"
Expand Down Expand Up @@ -1537,3 +1539,77 @@ func TestJoinURLPath(t *testing.T) {
}
}
}

func Test1xxResponses(t *testing.T) {
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
h.Add("Link", "</style.css>; rel=preload; as=style")
h.Add("Link", "</script.js>; rel=preload; as=script")
w.WriteHeader(http.StatusEarlyHints)

h.Add("Link", "</foo.js>; rel=preload; as=script")
w.WriteHeader(http.StatusProcessing)

w.Write([]byte("Hello"))
}))
defer backend.Close()
backendURL, err := url.Parse(backend.URL)
if err != nil {
t.Fatal(err)
}
proxyHandler := NewSingleHostReverseProxy(backendURL)
proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
frontend := httptest.NewServer(proxyHandler)
defer frontend.Close()
frontendClient := frontend.Client()

checkLinkHeaders := func(t *testing.T, expected, got []string) {
t.Helper()

if len(expected) != len(got) {
t.Errorf("Expected %d link headers; got %d", len(expected), len(got))
}

for i := range expected {
if expected[i] != got[i] {
t.Errorf("Expected %q link header; got %q", expected[i], got[i])
}
}
}

var respCounter uint8
trace := &httptrace.ClientTrace{
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
switch code {
case http.StatusEarlyHints:
checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script"}, header["Link"])
case http.StatusProcessing:
checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, header["Link"])
default:
t.Error("Unexpected 1xx response")
}

respCounter++

return nil
},
}
req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), "GET", frontend.URL, nil)

res, err := frontendClient.Do(req)
if err != nil {
t.Fatalf("Get: %v", err)
}

defer res.Body.Close()

if respCounter != 2 {
t.Errorf("Excpected 2 1xx responses; got %d", respCounter)
}
checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])

body, _ := io.ReadAll(res.Body)
if string(body) != "Hello" {
t.Errorf("Read body %q; want Hello", body)
}
}

0 comments on commit aa23135

Please sign in to comment.