From 05db80f1185060521de1b0c24f06faf9c9003354 Mon Sep 17 00:00:00 2001 From: Doug Fawley Date: Mon, 12 Feb 2024 08:38:58 -0800 Subject: [PATCH] server: wait to close connection until incoming socket is drained (with timeout) (#6977) --- internal/transport/controlbuf.go | 5 ++--- internal/transport/http2_client.go | 8 +++++++- internal/transport/http2_server.go | 23 ++++++++++++++++++++--- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/internal/transport/controlbuf.go b/internal/transport/controlbuf.go index b330ccedc8ab..83c3829826ae 100644 --- a/internal/transport/controlbuf.go +++ b/internal/transport/controlbuf.go @@ -535,8 +535,8 @@ const minBatchSize = 1000 // size is too low to give stream goroutines a chance to fill it up. // // Upon exiting, if the error causing the exit is not an I/O error, run() -// flushes and closes the underlying connection. Otherwise, the connection is -// left open to allow the I/O error to be encountered by the reader instead. +// flushes the underlying connection. The connection is always left open to +// allow different closing behavior on the client and server. func (l *loopyWriter) run() (err error) { defer func() { if l.logger.V(logLevel) { @@ -544,7 +544,6 @@ func (l *loopyWriter) run() (err error) { } if !isIOError(err) { l.framer.writer.Flush() - l.conn.Close() } l.cbuf.finish() }() diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index c33ac5961b2f..eff8799640c6 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -451,7 +451,13 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts } go func() { t.loopy = newLoopyWriter(clientSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger) - t.loopy.run() + if err := t.loopy.run(); !isIOError(err) { + // Immediately close the connection, as the loopy writer returns + // when there are no more active streams and we were draining (the + // server sent a GOAWAY). For I/O errors, the reader will hit it + // after draining any remaining incoming data. + t.conn.Close() + } close(t.writerDone) }() return t, nil diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 9bba55422156..3839c1ade27b 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -322,8 +322,24 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, go func() { t.loopy = newLoopyWriter(serverSide, t.framer, t.controlBuf, t.bdpEst, t.conn, t.logger) t.loopy.ssGoAwayHandler = t.outgoingGoAwayHandler - t.loopy.run() + err := t.loopy.run() close(t.loopyWriterDone) + if !isIOError(err) { + // Close the connection if a non-I/O error occurs (for I/O errors + // the reader will also encounter the error and close). Wait 1 + // second before closing the connection, or when the reader is done + // (i.e. the client already closed the connection or a connection + // error occurred). This avoids the potential problem where there + // is unread data on the receive side of the connection, which, if + // closed, would lead to a TCP RST instead of FIN, and the client + // encountering errors. For more info: + // https://github.com/grpc/grpc-go/issues/5358 + select { + case <-t.readerDone: + case <-time.After(time.Second): + } + t.conn.Close() + } }() go t.keepalive() return t, nil @@ -609,8 +625,8 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade // traceCtx attaches trace to ctx and returns the new context. func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) { defer func() { - <-t.loopyWriterDone close(t.readerDone) + <-t.loopyWriterDone }() for { t.controlBuf.throttle() @@ -1325,6 +1341,7 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { if err := t.framer.fr.WriteGoAway(sid, g.code, g.debugData); err != nil { return false, err } + t.framer.writer.Flush() if retErr != nil { return false, retErr } @@ -1345,7 +1362,7 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { return false, err } go func() { - timer := time.NewTimer(time.Minute) + timer := time.NewTimer(5 * time.Second) defer timer.Stop() select { case <-t.drainEvent.Done():