Skip to content

Commit

Permalink
cherry-pick: transport: add timeout for writing GOAWAY on http2Client…
Browse files Browse the repository at this point in the history
….Close() #7371 (#7540)
  • Loading branch information
aranjans authored Aug 20, 2024
1 parent 63853fd commit 35e915e
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 5 deletions.
17 changes: 15 additions & 2 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ import (
// atomically.
var clientConnectionCounter uint64

var goAwayLoopyWriterTimeout = 5 * time.Second

var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool))

// http2Client implements the ClientTransport interface with HTTP2.
Expand Down Expand Up @@ -983,6 +985,7 @@ func (t *http2Client) closeStream(s *Stream, err error, rst bool, rstCode http2.
// only once on a transport. Once it is called, the transport should not be
// accessed anymore.
func (t *http2Client) Close(err error) {
t.conn.SetWriteDeadline(time.Now().Add(time.Second * 10))
t.mu.Lock()
// Make sure we only close once.
if t.state == closing {
Expand All @@ -1006,10 +1009,20 @@ func (t *http2Client) Close(err error) {
t.kpDormancyCond.Signal()
}
t.mu.Unlock()

// Per HTTP/2 spec, a GOAWAY frame must be sent before closing the
// connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY.
// connection. See https://httpwg.org/specs/rfc7540.html#GOAWAY. It
// also waits for loopyWriter to be closed with a timer to avoid the
// long blocking in case the connection is blackholed, i.e. TCP is
// just stuck.
t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte("client transport shutdown"), closeConn: err})
<-t.writerDone
timer := time.NewTimer(goAwayLoopyWriterTimeout)
defer timer.Stop()
select {
case <-t.writerDone: // success
case <-timer.C:
t.logger.Infof("Failed to write a GOAWAY frame as part of connection close after %s. Giving up and closing the transport.", goAwayLoopyWriterTimeout)
}
t.cancel()
t.conn.Close()
channelz.RemoveEntry(t.channelz.ID)
Expand Down
72 changes: 69 additions & 3 deletions internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -2424,7 +2425,7 @@ func (s) TestClientHandshakeInfo(t *testing.T) {
TransportCredentials: creds,
ChannelzParent: channelzSubChannel(t),
}
tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {})
tr, err := NewClientTransport(ctx, ctx, addr, copts, func(GoAwayReason) {})
if err != nil {
t.Fatalf("NewClientTransport(): %v", err)
}
Expand Down Expand Up @@ -2465,7 +2466,7 @@ func (s) TestClientHandshakeInfoDialer(t *testing.T) {
Dialer: dialer,
ChannelzParent: channelzSubChannel(t),
}
tr, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {})
tr, err := NewClientTransport(ctx, ctx, addr, copts, func(GoAwayReason) {})
if err != nil {
t.Fatalf("NewClientTransport(): %v", err)
}
Expand Down Expand Up @@ -2725,7 +2726,7 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) {
}
}()

ct, err := NewClientTransport(ctx, context.Background(), resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func(GoAwayReason) {})
ct, err := NewClientTransport(ctx, ctx, resolver.Address{Addr: lis.Addr().String()}, ConnectOptions{}, func(GoAwayReason) {})
if err != nil {
t.Fatalf("Error while creating client transport: %v", err)
}
Expand All @@ -2746,3 +2747,68 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) {
t.Errorf("Context timed out")
}
}

// hangingConn is a net.Conn wrapper for testing, simulating hanging connections
// after a GOAWAY frame is sent, of which Write operations pause until explicitly
// signaled or a timeout occurs.
type hangingConn struct {
net.Conn
hangConn chan struct{}
startHanging *atomic.Bool
}

func (hc *hangingConn) Write(b []byte) (n int, err error) {
n, err = hc.Conn.Write(b)
if hc.startHanging.Load() {
<-hc.hangConn
}
return n, err
}

// Tests the scenario where a client transport is closed and writing of the
// GOAWAY frame as part of the close does not complete because of a network
// hang. The test verifies that the client transport is closed without waiting
// for too long.
func (s) TestClientCloseReturnsEarlyWhenGoAwayWriteHangs(t *testing.T) {
// Override timer for writing GOAWAY to 0 so that the connection write
// always times out. It is equivalent of real network hang when conn
// write for goaway doesn't finish in specified deadline
origGoAwayLoopyTimeout := goAwayLoopyWriterTimeout
goAwayLoopyWriterTimeout = time.Millisecond
defer func() {
goAwayLoopyWriterTimeout = origGoAwayLoopyTimeout
}()

// Create the server set up.
connectCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
server := setUpServerOnly(t, 0, &ServerConfig{}, normal)
defer server.stop()
addr := resolver.Address{Addr: "localhost:" + server.port}
isGreetingDone := &atomic.Bool{}
hangConn := make(chan struct{})
defer close(hangConn)
dialer := func(_ context.Context, addr string) (net.Conn, error) {
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
return &hangingConn{Conn: conn, hangConn: hangConn, startHanging: isGreetingDone}, nil
}
copts := ConnectOptions{Dialer: dialer}
copts.ChannelzParent = channelzSubChannel(t)
// Create client transport with custom dialer
ct, connErr := NewClientTransport(connectCtx, context.Background(), addr, copts, func(GoAwayReason) {})
if connErr != nil {
t.Fatalf("failed to create transport: %v", connErr)
}

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil {
t.Fatalf("Failed to open stream: %v", err)
}

isGreetingDone.Store(true)
ct.Close(errors.New("manually closed by client"))
}

0 comments on commit 35e915e

Please sign in to comment.