diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index f62ff72d624a..e8142a7a69c9 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -59,6 +59,8 @@ import ( // atomically. var clientConnectionCounter uint64 +var goAwayLoopyWriterTimeout = 5 * time.Second + var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool)) // http2Client implements the ClientTransport interface with HTTP2. @@ -983,6 +985,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2. // only once on a transport. Once it is called, the transport should not be // accessed anymore. func (t *http2Client) Close(err error) { + t.conn.SetWriteDeadline(time.Now().Add(time.Second * 10)) t.mu.Lock() // Make sure we only close once. if t.state == closing { @@ -1006,10 +1009,20 @@ func (t *http2Client) Close(err error) { t.kpDormancyCond.Signal() } t.mu.Unlock() + // Per HTTP/2 spec, a GOAWAY frame must be sent before closing the - // connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY. + // connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY. It + // also waits for loopyWriter to be closed with a timer to avoid the + // long blocking in case the connection is blackholed, i.e. TCP is + // just stuck. t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte("client transport shutdown"), closeConn: err}) - <-t.writerDone + timer := time.NewTimer(goAwayLoopyWriterTimeout) + defer timer.Stop() + select { + case <-t.writerDone: // success + case <-timer.C: + t.logger.Infof("Failed to write a GOAWAY frame as part of connection close after %s. Giving up and closing the transport.", goAwayLoopyWriterTimeout) + } t.cancel() t.conn.Close() channelz.RemoveEntry(t.channelz.ID) diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index b4082ff47d23..3292700c8a4d 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -32,6 +32,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "testing" "time" @@ -2424,7 +2425,7 @@ func (s) TestClientHandshakeInfo(t *testing.T) { TransportCredentials: creds, ChannelzParent: channelzSubChannel(t), } - tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {}) + tr, err := NewClientTransport(ctx, ctx, addr, copts, func(GoAwayReason) {}) if err != nil { t.Fatalf("NewClientTransport(): %v", err) } @@ -2465,7 +2466,7 @@ func (s) TestClientHandshakeInfoDialer(t *testing.T) { Dialer: dialer, ChannelzParent: channelzSubChannel(t), } - tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {}) + tr, err := NewClientTransport(ctx, ctx, addr, copts, func(GoAwayReason) {}) if err != nil { t.Fatalf("NewClientTransport(): %v", err) } @@ -2725,7 +2726,7 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) { } }() - ct, err := NewClientTransport(ctx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func(GoAwayReason) {}) + ct, err := NewClientTransport(ctx, ctx, resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func(GoAwayReason) {}) if err != nil { t.Fatalf("Error while creating client transport: %v", err) } @@ -2746,3 +2747,68 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) { t.Errorf("Context timed out") } } + +// hangingConn is a net.Conn wrapper for testing, simulating hanging connections +// after a GOAWAY frame is sent, of which Write operations pause until explicitly +// signaled or a timeout occurs. +type hangingConn struct { + net.Conn + hangConn chan struct{} + startHanging *atomic.Bool +} + +func (hc *hangingConn) Write(b []byte) (n int, err error) { + n, err = hc.Conn.Write(b) + if hc.startHanging.Load() { + <-hc.hangConn + } + return n, err +} + +// Tests the scenario where a client transport is closed and writing of the +// GOAWAY frame as part of the close does not complete because of a network +// hang. The test verifies that the client transport is closed without waiting +// for too long. +func (s) TestClientCloseReturnsEarlyWhenGoAwayWriteHangs(t *testing.T) { + // Override timer for writing GOAWAY to 0 so that the connection write + // always times out. It is equivalent of real network hang when conn + // write for goaway doesn't finish in specified deadline + origGoAwayLoopyTimeout := goAwayLoopyWriterTimeout + goAwayLoopyWriterTimeout = time.Millisecond + defer func() { + goAwayLoopyWriterTimeout = origGoAwayLoopyTimeout + }() + + // Create the server set up. + connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + server := setUpServerOnly(t, 0, &ServerConfig{}, normal) + defer server.stop() + addr := resolver.Address{Addr: "localhost:" + server.port} + isGreetingDone := &atomic.Bool{} + hangConn := make(chan struct{}) + defer close(hangConn) + dialer := func(_ context.Context, addr string) (net.Conn, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + return &hangingConn{Conn: conn, hangConn: hangConn, startHanging: isGreetingDone}, nil + } + copts := ConnectOptions{Dialer: dialer} + copts.ChannelzParent = channelzSubChannel(t) + // Create client transport with custom dialer + ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {}) + if connErr != nil { + t.Fatalf("failed to create transport: %v", connErr) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil { + t.Fatalf("Failed to open stream: %v", err) + } + + isGreetingDone.Store(true) + ct.Close(errors.New("manually closed by client")) +}