Skip to content

Commit

Permalink
internal/transport: convert ConnectionError to Unavailable status…
Browse files Browse the repository at this point in the history
… when writing headers (#6891)
  • Loading branch information
mustafasen81 authored Jan 10, 2024
1 parent e7e400b commit 6ce73bf
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
7 changes: 6 additions & 1 deletion internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,12 @@ func (t *http2Server) WriteHeader(s *Stream, md metadata.MD) error {
}
}
if err := t.writeHeaderLocked(s); err != nil {
return status.Convert(err).Err()
switch e := err.(type) {
case ConnectionError:
return status.Error(codes.Unavailable, e.Desc)
default:
return status.Convert(err).Err()
}
}
return nil
}
Expand Down
65 changes: 65 additions & 0 deletions internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/status"
)
Expand Down Expand Up @@ -2136,6 +2137,70 @@ func (s) TestHeadersHTTPStatusGRPCStatus(t *testing.T) {
}
}

func (s) TestWriteHeaderConnectionError(t *testing.T) {
server, client, cancel := setUp(t, 0, notifyCall)
defer cancel()
defer server.stop()

waitWhileTrue(t, func() (bool, error) {
server.mu.Lock()
defer server.mu.Unlock()

if len(server.conns) == 0 {
return true, fmt.Errorf("timed-out while waiting for connection to be created on the server")
}
return false, nil
})

server.mu.Lock()

if len(server.conns) != 1 {
t.Fatalf("Server has %d connections from the client, want 1", len(server.conns))
}

// Get the server transport for the connecton to the client.
var serverTransport *http2Server
for k := range server.conns {
serverTransport = k.(*http2Server)
}
notifyChan := make(chan struct{})
server.h.notify = notifyChan
server.mu.Unlock()

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
cstream, err := client.NewStream(ctx, &CallHdr{})
if err != nil {
t.Fatalf("Client failed to create first stream. Err: %v", err)
}

<-notifyChan // Wait for server stream to be established.
var sstream *Stream
// Access stream on the server.
serverTransport.mu.Lock()
for _, v := range serverTransport.activeStreams {
if v.id == cstream.id {
sstream = v
}
}
serverTransport.mu.Unlock()
if sstream == nil {
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream.id)
}

client.Close(fmt.Errorf("closed manually by test"))

// Wait for server transport to be closed.
<-serverTransport.done

// Write header on a closed server transport.
err = serverTransport.WriteHeader(sstream, metadata.MD{})
st := status.Convert(err)
if st.Code() != codes.Unavailable {
t.Fatalf("WriteHeader() failed with status code %s, want %s", st.Code(), codes.Unavailable)
}
}

func (s) TestPingPong1B(t *testing.T) {
runPingPongTest(t, 1)
}
Expand Down

0 comments on commit 6ce73bf

Please sign in to comment.