diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 7f1b08f628ef..37e089bc8433 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -42,6 +42,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" @@ -102,13 +103,13 @@ type http2Server struct { mu sync.Mutex // guard the following - // drainChan is initialized when Drain() is called the first time. - // After which the server writes out the first GoAway(with ID 2^31-1) frame. - // Then an independent goroutine will be launched to later send the second GoAway. - // During this time we don't want to write another first GoAway(with ID 2^31 -1) frame. - // Thus call to Drain() will be a no-op if drainChan is already initialized since draining is - // already underway. - drainChan chan struct{} + // drainEvent is initialized when Drain() is called the first time. After + // which the server writes out the first GoAway(with ID 2^31-1) frame. Then + // an independent goroutine will be launched to later send the second + // GoAway. During this time we don't want to write another first GoAway(with + // ID 2^31 -1) frame. Thus call to Drain() will be a no-op if drainEvent is + // already initialized since draining is already underway. + drainEvent *grpcsync.Event state transportState activeStreams map[uint32]*Stream // idle is the time instant when the connection went idle. @@ -838,8 +839,8 @@ const ( func (t *http2Server) handlePing(f *http2.PingFrame) { if f.IsAck() { - if f.Data == goAwayPing.data && t.drainChan != nil { - close(t.drainChan) + if f.Data == goAwayPing.data && t.drainEvent != nil { + t.drainEvent.Fire() return } // Maybe it's a BDP ping. @@ -1287,10 +1288,10 @@ func (t *http2Server) RemoteAddr() net.Addr { func (t *http2Server) Drain() { t.mu.Lock() defer t.mu.Unlock() - if t.drainChan != nil { + if t.drainEvent != nil { return } - t.drainChan = make(chan struct{}) + t.drainEvent = grpcsync.NewEvent() t.controlBuf.put(&goAway{code: http2.ErrCodeNo, debugData: []byte{}, headsUp: true}) } @@ -1346,7 +1347,7 @@ func (t *http2Server) outgoingGoAwayHandler(g *goAway) (bool, error) { timer := time.NewTimer(time.Minute) defer timer.Stop() select { - case <-t.drainChan: + case <-t.drainEvent.Done(): case <-timer.C: case <-t.done: return diff --git a/test/goaway_test.go b/test/goaway_test.go index bcd13ae7da66..48b7f0f7c7ac 100644 --- a/test/goaway_test.go +++ b/test/goaway_test.go @@ -363,6 +363,7 @@ func testServerMultipleGoAwayPendingRPC(t *testing.T, e env) { close(ch2) }() // Loop until the server side GoAway signal is propagated to the client. + for { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) if _, err := tc.EmptyCall(ctx, &testpb.Empty{}, grpc.WaitForReady(true)); err != nil { @@ -402,6 +403,7 @@ func testServerMultipleGoAwayPendingRPC(t *testing.T, e env) { if err := stream.CloseSend(); err != nil { t.Fatalf("%v.CloseSend() = %v, want ", stream, err) } + <-ch1 <-ch2 cancel() @@ -707,3 +709,59 @@ func (s) TestGoAwayStreamIDSmallerThanCreatedStreams(t *testing.T) { ct.writeGoAway(1, http2.ErrCodeNo, []byte{}) goAwayWritten.Fire() } + +// TestTwoGoAwayPingFrames tests the scenario where you get two go away ping +// frames from the client during graceful shutdown. This should not crash the +// server. +func (s) TestTwoGoAwayPingFrames(t *testing.T) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + defer lis.Close() + s := grpc.NewServer() + defer s.Stop() + go s.Serve(lis) + + conn, err := net.DialTimeout("tcp", lis.Addr().String(), defaultTestTimeout) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + + st := newServerTesterFromConn(t, conn) + st.greet() + pingReceivedClientSide := testutils.NewChannel() + go func() { + for { + f, err := st.readFrame() + if err != nil { + return + } + switch f.(type) { + case *http2.GoAwayFrame: + case *http2.PingFrame: + pingReceivedClientSide.Send(nil) + default: + t.Errorf("server tester received unexpected frame type %T", f) + } + } + }() + gsDone := testutils.NewChannel() + go func() { + s.GracefulStop() + gsDone.Send(nil) + }() + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := pingReceivedClientSide.Receive(ctx); err != nil { + t.Fatalf("Error waiting for ping frame client side from graceful shutdown: %v", err) + } + // Write two goaway pings here. + st.writePing(true, [8]byte{1, 6, 1, 8, 0, 3, 3, 9}) + st.writePing(true, [8]byte{1, 6, 1, 8, 0, 3, 3, 9}) + // Close the conn to finish up the Graceful Shutdown process. + conn.Close() + if _, err := gsDone.Receive(ctx); err != nil { + t.Fatalf("Error waiting for graceful shutdown of the server: %v", err) + } +} diff --git a/test/servertester.go b/test/servertester.go index 9758e8eb6cf8..bf7bd8b214e6 100644 --- a/test/servertester.go +++ b/test/servertester.go @@ -273,3 +273,9 @@ func (st *serverTester) writeRSTStream(streamID uint32, code http2.ErrCode) { st.t.Fatalf("Error writing RST_STREAM: %v", err) } } + +func (st *serverTester) writePing(ack bool, data [8]byte) { + if err := st.fr.WritePing(ack, data); err != nil { + st.t.Fatalf("Error writing PING: %v", err) + } +}