Skip to content

Commit

Permalink
transport : wait for goroutines to exit before transport closes (#7666)
Browse files Browse the repository at this point in the history
  • Loading branch information
eshitachandwani authored Oct 10, 2024
1 parent 00b9e14 commit b850ea5
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 17 deletions.
9 changes: 7 additions & 2 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1140,10 +1140,15 @@ func (cc *ClientConn) Close() error {

<-cc.resolverWrapper.serializer.Done()
<-cc.balancerWrapper.serializer.Done()

var wg sync.WaitGroup
for ac := range conns {
ac.tearDown(ErrClientConnClosing)
wg.Add(1)
go func(ac *addrConn) {
defer wg.Done()
ac.tearDown(ErrClientConnClosing)
}(ac)
}
wg.Wait()
cc.addTraceEvent("deleted")
// TraceEvent needs to be called before RemoveEntry, as TraceEvent may add
// trace reference to the entity being deleted, and thus prevent it from being
Expand Down
48 changes: 33 additions & 15 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ type http2Client struct {
writerDone chan struct{} // sync point to enable testing.
// goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor)
// that the server sent GoAway on this transport.
goAway chan struct{}

framer *framer
goAway chan struct{}
keepaliveDone chan struct{} // Closed when the keepalive goroutine exits.
framer *framer
// controlBuf delivers all the control related tasks (e.g., window
// updates, reset streams, and various settings) to the controller.
// Do not access controlBuf with mu held.
Expand Down Expand Up @@ -335,6 +335,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
goAway: make(chan struct{}),
keepaliveDone: make(chan struct{}),
framer: newFramer(conn, writeBufSize, readBufSize, opts.SharedWriteBuffer, maxHeaderListSize),
fc: &trInFlow{limit: uint32(icwz)},
scheme: scheme,
Expand Down Expand Up @@ -1029,6 +1030,12 @@ func (t *http2Client) Close(err error) {
}
t.cancel()
t.conn.Close()
// Waits for the reader and keepalive goroutines to exit before returning to
// ensure all resources are cleaned up before Close can return.
<-t.readerDone
if t.keepaliveEnabled {
<-t.keepaliveDone
}
channelz.RemoveEntry(t.channelz.ID)
var st *status.Status
if len(goAwayDebugMessage) > 0 {
Expand Down Expand Up @@ -1316,11 +1323,11 @@ func (t *http2Client) handlePing(f *http2.PingFrame) {
t.controlBuf.put(pingAck)
}

func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) error {
t.mu.Lock()
if t.state == closing {
t.mu.Unlock()
return
return nil
}
if f.ErrCode == http2.ErrCodeEnhanceYourCalm && string(f.DebugData()) == "too_many_pings" {
// When a client receives a GOAWAY with error code ENHANCE_YOUR_CALM and debug
Expand All @@ -1332,8 +1339,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
id := f.LastStreamID
if id > 0 && id%2 == 0 {
t.mu.Unlock()
t.Close(connectionErrorf(true, nil, "received goaway with non-zero even-numbered stream id: %v", id))
return
return connectionErrorf(true, nil, "received goaway with non-zero even-numbered stream id: %v", id)
}
// A client can receive multiple GoAways from the server (see
// https://github.com/grpc/grpc-go/issues/1387). The idea is that the first
Expand All @@ -1350,8 +1356,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
// If there are multiple GoAways the first one should always have an ID greater than the following ones.
if id > t.prevGoAwayID {
t.mu.Unlock()
t.Close(connectionErrorf(true, nil, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID))
return
return connectionErrorf(true, nil, "received goaway with stream id: %v, which exceeds stream id of previous goaway: %v", id, t.prevGoAwayID)
}
default:
t.setGoAwayReason(f)
Expand All @@ -1375,8 +1380,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
t.prevGoAwayID = id
if len(t.activeStreams) == 0 {
t.mu.Unlock()
t.Close(connectionErrorf(true, nil, "received goaway and there are no active streams"))
return
return connectionErrorf(true, nil, "received goaway and there are no active streams")
}

streamsToClose := make([]*Stream, 0)
Expand All @@ -1393,6 +1397,7 @@ func (t *http2Client) handleGoAway(f *http2.GoAwayFrame) {
for _, stream := range streamsToClose {
t.closeStream(stream, errStreamDrain, false, http2.ErrCodeNo, statusGoAway, nil, false)
}
return nil
}

// setGoAwayReason sets the value of t.goAwayReason based
Expand Down Expand Up @@ -1628,7 +1633,13 @@ func (t *http2Client) readServerPreface() error {
// network connection. If the server preface is not read successfully, an
// error is pushed to errCh; otherwise errCh is closed with no error.
func (t *http2Client) reader(errCh chan<- error) {
defer close(t.readerDone)
var errClose error
defer func() {
close(t.readerDone)
if errClose != nil {
t.Close(errClose)
}
}()

if err := t.readServerPreface(); err != nil {
errCh <- err
Expand Down Expand Up @@ -1669,7 +1680,7 @@ func (t *http2Client) reader(errCh chan<- error) {
continue
}
// Transport error.
t.Close(connectionErrorf(true, err, "error reading from server: %v", err))
errClose = connectionErrorf(true, err, "error reading from server: %v", err)
return
}
switch frame := frame.(type) {
Expand All @@ -1684,7 +1695,7 @@ func (t *http2Client) reader(errCh chan<- error) {
case *http2.PingFrame:
t.handlePing(frame)
case *http2.GoAwayFrame:
t.handleGoAway(frame)
errClose = t.handleGoAway(frame)
case *http2.WindowUpdateFrame:
t.handleWindowUpdate(frame)
default:
Expand All @@ -1697,6 +1708,13 @@ func (t *http2Client) reader(errCh chan<- error) {

// keepalive running in a separate goroutine makes sure the connection is alive by sending pings.
func (t *http2Client) keepalive() {
var err error
defer func() {
close(t.keepaliveDone)
if err != nil {
t.Close(err)
}
}()
p := &ping{data: [8]byte{}}
// True iff a ping has been sent, and no data has been received since then.
outstandingPing := false
Expand All @@ -1720,7 +1738,7 @@ func (t *http2Client) keepalive() {
continue
}
if outstandingPing && timeoutLeft <= 0 {
t.Close(connectionErrorf(true, nil, "keepalive ping failed to receive ACK within timeout"))
err = connectionErrorf(true, nil, "keepalive ping failed to receive ACK within timeout")
return
}
t.mu.Lock()
Expand Down
1 change: 1 addition & 0 deletions internal/transport/keepalive_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ import (
)

const defaultTestTimeout = 10 * time.Second
const defaultTestShortTimeout = 10 * time.Millisecond

// TestMaxConnectionIdle tests that a server will send GoAway to an idle
// client. An idle client is one who doesn't make any RPC calls for a duration
Expand Down
83 changes: 83 additions & 0 deletions internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2781,6 +2781,89 @@ func (s) TestClientSendsAGoAwayFrame(t *testing.T) {
}
}

// readHangingConn is a wrapper around net.Conn that makes the Read() hang when
// Close() is called.
type readHangingConn struct {
net.Conn
readHangConn chan struct{} // Read() hangs until this channel is closed by Close().
closed *atomic.Bool // Set to true when Close() is called.
}

func (hc *readHangingConn) Read(b []byte) (n int, err error) {
n, err = hc.Conn.Read(b)
if hc.closed.Load() {
<-hc.readHangConn // hang the read till we want
}
return n, err
}

func (hc *readHangingConn) Close() error {
hc.closed.Store(true)
return hc.Conn.Close()
}

// Tests that closing a client transport does not return until the reader
// goroutine exits.
func (s) TestClientCloseReturnsAfterReaderCompletes(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

server := setUpServerOnly(t, 0, &ServerConfig{}, normal)
defer server.stop()
addr := resolver.Address{Addr: "localhost:" + server.port}

isReaderHanging := &atomic.Bool{}
readHangConn := make(chan struct{})
copts := ConnectOptions{
Dialer: func(_ context.Context, addr string) (net.Conn, error) {
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
return &readHangingConn{Conn: conn, readHangConn: readHangConn, closed: isReaderHanging}, nil
},
ChannelzParent: channelzSubChannel(t),
}

// Create a client transport with a custom dialer that hangs the Read()
// after Close().
ct, err := NewClientTransport(ctx, context.Background(), addr, copts, func(GoAwayReason) {})
if err != nil {
t.Fatalf("Failed to create transport: %v", err)
}

if _, err := ct.NewStream(ctx, &CallHdr{}); err != nil {
t.Fatalf("Failed to open stream: %v", err)
}

// Closing the client transport will result in the underlying net.Conn being
// closed, which will result in readHangingConn.Read() to hang. This will
// stall the exit of the reader goroutine, and will stall client
// transport's Close from returning.
transportClosed := make(chan struct{})
go func() {
ct.Close(errors.New("manually closed by client"))
close(transportClosed)
}()

// Wait for a short duration and ensure that the client transport's Close()
// does not return.
select {
case <-transportClosed:
t.Fatal("Transport closed before reader completed")
case <-time.After(defaultTestShortTimeout):
}

// Closing the channel will unblock the reader goroutine and will ensure
// that the client transport's Close() returns.
close(readHangConn)
select {
case <-transportClosed:
case <-time.After(defaultTestTimeout):
t.Fatal("Timeout when waiting for transport to close")
}
}

// hangingConn is a net.Conn wrapper for testing, simulating hanging connections
// after a GOAWAY frame is sent, of which Write operations pause until explicitly
// signaled or a timeout occurs.
Expand Down

0 comments on commit b850ea5

Please sign in to comment.