Skip to content

Commit

Permalink
net/http: make Transport return Writable Response.Body on protocol sw…
Browse files Browse the repository at this point in the history
…itch

Updates #26937
Updates #17227

Change-Id: I79865938b05c219e1947822e60e4f52bb2604b70
Reviewed-on: https://go-review.googlesource.com/131279
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
  • Loading branch information
bradfitz committed Aug 25, 2018
1 parent 30b080e commit 5416204
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 2 deletions.
25 changes: 25 additions & 0 deletions src/net/http/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"golang_org/x/net/http/httpguts"
"io"
"net/textproto"
"net/url"
Expand Down Expand Up @@ -63,6 +64,10 @@ type Response struct {
//
// The Body is automatically dechunked if the server replied
// with a "chunked" Transfer-Encoding.
//
// As of Go 1.12, the Body will be also implement io.Writer
// on a successful "101 Switching Protocols" responses,
// as used by WebSockets and HTTP/2's "h2c" mode.
Body io.ReadCloser

// ContentLength records the length of the associated content. The
Expand Down Expand Up @@ -333,3 +338,23 @@ func (r *Response) closeBody() {
r.Body.Close()
}
}

// bodyIsWritable reports whether the Body supports writing. The
// Transport returns Writable bodies for 101 Switching Protocols
// responses.
// The Transport uses this method to determine whether a persistent
// connection is done being managed from its perspective. Once we
// return a writable response body to a user, the net/http package is
// done managing that connection.
func (r *Response) bodyIsWritable() bool {
_, ok := r.Body.(io.Writer)
return ok
}

// isProtocolSwitch reports whether r is a response to a successful
// protocol upgrade.
func (r *Response) isProtocolSwitch() bool {
return r.StatusCode == StatusSwitchingProtocols &&
r.Header.Get("Upgrade") != "" &&
httpguts.HeaderValuesContainsToken(r.Header["Connection"], "Upgrade")
}
52 changes: 50 additions & 2 deletions src/net/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -1607,6 +1607,11 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte
return err
}

// errCallerOwnsConn is an internal sentinel error used when we hand
// off a writable response.Body to the caller. We use this to prevent
// closing a net.Conn that is now owned by the caller.
var errCallerOwnsConn = errors.New("read loop ending; caller owns writable underlying conn")

func (pc *persistConn) readLoop() {
closeErr := errReadLoopExiting // default value, if not changed below
defer func() {
Expand Down Expand Up @@ -1681,9 +1686,10 @@ func (pc *persistConn) readLoop() {
pc.numExpectedResponses--
pc.mu.Unlock()

bodyWritable := resp.bodyIsWritable()
hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0

if resp.Close || rc.req.Close || resp.StatusCode <= 199 {
if resp.Close || rc.req.Close || resp.StatusCode <= 199 || bodyWritable {
// Don't do keep-alive on error if either party requested a close
// or we get an unexpected informational (1xx) response.
// StatusCode 100 is already handled above.
Expand All @@ -1704,6 +1710,10 @@ func (pc *persistConn) readLoop() {
pc.wroteRequest() &&
tryPutIdleConn(trace)

if bodyWritable {
closeErr = errCallerOwnsConn
}

select {
case rc.ch <- responseAndError{res: resp}:
case <-rc.callerGone:
Expand Down Expand Up @@ -1848,6 +1858,10 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr
}
break
}
if resp.isProtocolSwitch() {
resp.Body = newReadWriteCloserBody(pc.br, pc.conn)
}

resp.TLS = pc.tlsState
return
}
Expand All @@ -1874,6 +1888,38 @@ func (pc *persistConn) waitForContinue(continueCh <-chan struct{}) func() bool {
}
}

func newReadWriteCloserBody(br *bufio.Reader, rwc io.ReadWriteCloser) io.ReadWriteCloser {
body := &readWriteCloserBody{ReadWriteCloser: rwc}
if br.Buffered() != 0 {
body.br = br
}
return body
}

// readWriteCloserBody is the Response.Body type used when we want to
// give users write access to the Body through the underlying
// connection (TCP, unless using custom dialers). This is then
// the concrete type for a Response.Body on the 101 Switching
// Protocols response, as used by WebSockets, h2c, etc.
type readWriteCloserBody struct {
br *bufio.Reader // used until empty
io.ReadWriteCloser
}

func (b *readWriteCloserBody) Read(p []byte) (n int, err error) {
if b.br != nil {
if n := b.br.Buffered(); len(p) > n {
p = p[:n]
}
n, err = b.br.Read(p)
if b.br.Buffered() == 0 {
b.br = nil
}
return n, err
}
return b.ReadWriteCloser.Read(p)
}

// nothingWrittenError wraps a write errors which ended up writing zero bytes.
type nothingWrittenError struct {
error
Expand Down Expand Up @@ -2193,7 +2239,9 @@ func (pc *persistConn) closeLocked(err error) {
// freelist for http2. That's done by the
// alternate protocol's RoundTripper.
} else {
pc.conn.Close()
if err != errCallerOwnsConn {
pc.conn.Close()
}
close(pc.closech)
}
}
Expand Down
51 changes: 51 additions & 0 deletions src/net/http/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4836,3 +4836,54 @@ func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) {
t.Fatal("timeout")
}
}

func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {
setParallel(t)
defer afterTest(t)
done := make(chan struct{})
defer close(done)
cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
conn, _, err := w.(Hijacker).Hijack()
if err != nil {
t.Error(err)
return
}
defer conn.Close()
io.WriteString(conn, "HTTP/1.1 101 Switching Protocols Hi\r\nConnection: upgRADe\r\nUpgrade: foo\r\n\r\nSome buffered data\n")
bs := bufio.NewScanner(conn)
bs.Scan()
fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text()))
<-done
}))
defer cst.close()

req, _ := NewRequest("GET", cst.ts.URL, nil)
req.Header.Set("Upgrade", "foo")
req.Header.Set("Connection", "upgrade")
res, err := cst.c.Do(req)
if err != nil {
t.Fatal(err)
}
if res.StatusCode != 101 {
t.Fatalf("expected 101 switching protocols; got %v, %v", res.Status, res.Header)
}
rwc, ok := res.Body.(io.ReadWriteCloser)
if !ok {
t.Fatalf("expected a ReadWriteCloser; got a %T", res.Body)
}
defer rwc.Close()
bs := bufio.NewScanner(rwc)
if !bs.Scan() {
t.Fatalf("expected readable input")
}
if got, want := bs.Text(), "Some buffered data"; got != want {
t.Errorf("read %q; want %q", got, want)
}
io.WriteString(rwc, "echo\n")
if !bs.Scan() {
t.Fatalf("expected another line")
}
if got, want := bs.Text(), "ECHO"; got != want {
t.Errorf("read %q; want %q", got, want)
}
}

0 comments on commit 5416204

Please sign in to comment.