diff --git a/http2/transport.go b/http2/transport.go index da53e83cb2..b9632380e7 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -1268,22 +1268,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { cancelRequest := func(cs *clientStream, err error) error { cs.cc.mu.Lock() - cs.abortStreamLocked(err) bodyClosed := cs.reqBodyClosed - if cs.ID != 0 { - // This request may have failed because of a problem with the connection, - // or for some unrelated reason. (For example, the user might have canceled - // the request without waiting for a response.) Mark the connection as - // not reusable, since trying to reuse a dead connection is worse than - // unnecessarily creating a new one. - // - // If cs.ID is 0, then the request was never allocated a stream ID and - // whatever went wrong was unrelated to the connection. We might have - // timed out waiting for a stream slot when StrictMaxConcurrentStreams - // is set, for example, in which case retrying on a different connection - // will not help. - cs.cc.doNotReuse = true - } cs.cc.mu.Unlock() // Wait for the request body to be closed. // @@ -1318,11 +1303,14 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) { return handleResponseHeaders() default: waitDone() - return nil, cancelRequest(cs, cs.abortErr) + return nil, cs.abortErr } case <-ctx.Done(): - return nil, cancelRequest(cs, ctx.Err()) + err := ctx.Err() + cs.abortStream(err) + return nil, cancelRequest(cs, err) case <-cs.reqCancel: + cs.abortStream(errRequestCanceled) return nil, cancelRequest(cs, errRequestCanceled) } } diff --git a/http2/transport_test.go b/http2/transport_test.go index 53999f6a04..d3156208cf 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -3873,10 +3873,11 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) { } } - server := func(count int, ct *clientTester) { + server := func(_ int, ct *clientTester) { ct.greet() var buf bytes.Buffer enc := hpack.NewEncoder(&buf) + var count int for { f, err := ct.fr.ReadFrame() if err != nil { @@ -3897,6 +3898,7 @@ func TestTransportRetryAfterRefusedStream(t *testing.T) { t.Errorf("headers should have END_HEADERS be ended: %v", f) return } + count++ if count == 1 { ct.fr.WriteRSTStream(f.StreamID, ErrCodeRefusedStream) } else { @@ -6364,97 +6366,3 @@ func TestTransportSlowClose(t *testing.T) { } res.Body.Close() } - -type blockReadConn struct { - net.Conn - blockc chan struct{} -} - -func (c *blockReadConn) Read(b []byte) (n int, err error) { - <-c.blockc - return c.Conn.Read(b) -} - -func TestTransportReuseAfterError(t *testing.T) { - serverReqc := make(chan struct{}, 3) - st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { - serverReqc <- struct{}{} - }, optOnlyServer) - defer st.Close() - - var ( - unblockOnce sync.Once - blockc = make(chan struct{}) - connCountMu sync.Mutex - connCount int - ) - tr := &Transport{ - TLSClientConfig: tlsConfigInsecure, - DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { - // The first connection dialed will block on reads until blockc is closed. - connCountMu.Lock() - defer connCountMu.Unlock() - connCount++ - conn, err := tls.Dial(network, addr, cfg) - if err != nil { - return nil, err - } - if connCount == 1 { - return &blockReadConn{ - Conn: conn, - blockc: blockc, - }, nil - } - return conn, nil - }, - } - defer tr.CloseIdleConnections() - defer unblockOnce.Do(func() { - // Ensure that reads on blockc are unblocked if we return early. - close(blockc) - }) - - req, _ := http.NewRequest("GET", st.ts.URL, nil) - - // Request 1 is made on conn 1. - // Reading the response will block. - // Wait until the server receives the request, and continue. - req1c := make(chan struct{}) - go func() { - defer close(req1c) - res1, err := tr.RoundTrip(req.Clone(context.Background())) - if err != nil { - t.Errorf("request 1: %v", err) - } else { - res1.Body.Close() - } - }() - <-serverReqc - - // Request 2 is also made on conn 1. - // Reading the response will block. - // The request is canceled once the server receives it. - // Conn 1 should now be flagged as unfit for reuse. - req2Ctx, cancel := context.WithCancel(context.Background()) - go func() { - <-serverReqc - cancel() - }() - _, err := tr.RoundTrip(req.Clone(req2Ctx)) - if err == nil { - t.Errorf("request 2 unexpectedly succeeded (want cancel)") - } - - // Request 3 is made on a new conn, and succeeds. - res3, err := tr.RoundTrip(req.Clone(context.Background())) - if err != nil { - t.Fatalf("request 3: %v", err) - } - res3.Body.Close() - - // Unblock conn 1, and verify that request 1 completes. - unblockOnce.Do(func() { - close(blockc) - }) - <-req1c -}