diff --git a/transport/handler_server.go b/transport/handler_server.go index 0489fada52e1..f1f6caf89c26 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -173,7 +173,6 @@ func (ht *serverHandlerTransport) do(fn func()) error { case <-ht.closedCh: return ErrConnClosing } - } } @@ -183,6 +182,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro ht.mu.Unlock() return nil } + ht.streamDone = true ht.mu.Unlock() err := ht.do(func() { ht.writeCommonHeaders(s) @@ -223,9 +223,6 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro } }) close(ht.writes) - ht.mu.Lock() - ht.streamDone = true - ht.mu.Unlock() return err } diff --git a/transport/handler_server_test.go b/transport/handler_server_test.go index 262e6016bd10..06fe813ca912 100644 --- a/transport/handler_server_test.go +++ b/transport/handler_server_test.go @@ -26,6 +26,7 @@ import ( "net/http/httptest" "net/url" "reflect" + "sync" "testing" "time" @@ -390,6 +391,30 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) { } } +func TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) { + st := newHandleStreamTest(t) + handleStream := func(s *Stream) { + if want := "/service/foo.bar"; s.method != want { + t.Errorf("stream method = %q; want %q", s.method, want) + } + st.bodyw.Close() // no body + + var wg sync.WaitGroup + wg.Add(5) + for i := 0; i < 5; i++ { + go func() { + defer wg.Done() + st.ht.WriteStatus(s, status.New(codes.OK, "")) + }() + } + wg.Wait() + } + st.ht.HandleStreams( + func(s *Stream) { go handleStream(s) }, + func(ctx context.Context, method string) context.Context { return ctx }, + ) +} + func TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) { errDetails := []proto.Message{ &epb.RetryInfo{