Skip to content

Commit

Permalink
client: fix ClientStream.Header() behavior (#6557)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfawley authored Aug 18, 2023
1 parent 8a2c220 commit fe1519e
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 70 deletions.
38 changes: 38 additions & 0 deletions binarylog/binarylog_end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ import (
"github.com/golang/protobuf/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/binarylog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/grpclog"
iblog "google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

Expand Down Expand Up @@ -1059,3 +1061,39 @@ func (s) TestServerBinaryLogFullDuplexError(t *testing.T) {
t.Fatal(err)
}
}

// TestCanceledStatus ensures a server that responds with a Canceled status has
// its trailers logged appropriately and is not treated as a canceled RPC.
func (s) TestCanceledStatus(t *testing.T) {
defer testSink.clear()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

const statusMsgWant = "server returned Canceled"
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
grpc.SetTrailer(ctx, metadata.Pairs("key", "value"))
return nil, status.Error(codes.Canceled, statusMsgWant)
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Canceled {
t.Fatalf("Received unexpected error from UnaryCall: %v; want Canceled", err)
}

got := testSink.logEntries(true)
last := got[len(got)-1]
if last.Type != binlogpb.GrpcLogEntry_EVENT_TYPE_SERVER_TRAILER ||
last.GetTrailer().GetStatusCode() != uint32(codes.Canceled) ||
last.GetTrailer().GetStatusMessage() != statusMsgWant ||
len(last.GetTrailer().GetMetadata().GetEntry()) != 1 ||
last.GetTrailer().GetMetadata().GetEntry()[0].GetKey() != "key" ||
string(last.GetTrailer().GetMetadata().GetEntry()[0].GetValue()) != "value" {
t.Fatalf("Got binary log: %+v; want last entry is server trailing with status Canceled", got)
}
}
31 changes: 15 additions & 16 deletions internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1505,30 +1505,28 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
return
}

isHeader := false

// If headerChan hasn't been closed yet
if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
s.headerValid = true
if !endStream {
// HEADERS frame block carries a Response-Headers.
isHeader = true
// For headers, set them in s.header and close headerChan. For trailers or
// trailers-only, closeStream will set the trailers and close headerChan as
// needed.
if !endStream {
// If headerChan hasn't been closed yet (expected, given we checked it
// above, but something else could have potentially closed the whole
// stream).
if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
s.headerValid = true
// These values can be set without any synchronization because
// stream goroutine will read it only after seeing a closed
// headerChan which we'll close after setting this.
s.recvCompress = recvCompress
if len(mdata) > 0 {
s.header = mdata
}
} else {
// HEADERS frame block carries a Trailers-Only.
s.noHeaders = true
close(s.headerChan)
}
close(s.headerChan)
}

for _, sh := range t.statsHandlers {
if isHeader {
if !endStream {
inHeader := &stats.InHeader{
Client: true,
WireLength: int(frame.Header().Length),
Expand All @@ -1554,9 +1552,10 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
statusGen = status.New(rawStatusCode, grpcMessage)
}

// if client received END_STREAM from server while stream was still active, send RST_STREAM
rst := s.getState() == streamActive
t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, statusGen, mdata, true)
// If client received END_STREAM from server while stream was still active,
// send RST_STREAM.
rstStream := s.getState() == streamActive
t.closeStream(s, io.EOF, rstStream, http2.ErrCodeNo, statusGen, mdata, true)
}

// readServerPreface reads and handles the initial settings frame from the
Expand Down
10 changes: 1 addition & 9 deletions internal/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@ import (
"google.golang.org/grpc/tap"
)

// ErrNoHeaders is used as a signal that a trailers only response was received,
// and is not a real error.
var ErrNoHeaders = errors.New("stream has no headers")

const logLevel = 2

type bufferPool struct {
Expand Down Expand Up @@ -390,14 +386,10 @@ func (s *Stream) Header() (metadata.MD, error) {
}
s.waitOnHeader()

if !s.headerValid {
if !s.headerValid || s.noHeaders {
return nil, s.status.Err()
}

if s.noHeaders {
return nil, ErrNoHeaders
}

return s.header.Copy(), nil
}

Expand Down
7 changes: 5 additions & 2 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -867,15 +867,18 @@ func Errorf(c codes.Code, format string, a ...any) error {
return status.Errorf(c, format, a...)
}

var errContextCanceled = status.Error(codes.Canceled, context.Canceled.Error())
var errContextDeadline = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())

// toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error {
switch err {
case nil, io.EOF:
return err
case context.DeadlineExceeded:
return status.Error(codes.DeadlineExceeded, err.Error())
return errContextDeadline
case context.Canceled:
return status.Error(codes.Canceled, err.Error())
return errContextCanceled
case io.ErrUnexpectedEOF:
return status.Error(codes.Internal, err.Error())
}
Expand Down
65 changes: 30 additions & 35 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -789,23 +789,23 @@ func (cs *clientStream) withRetry(op func(a *csAttempt) error, onSuccess func())

func (cs *clientStream) Header() (metadata.MD, error) {
var m metadata.MD
noHeader := false
err := cs.withRetry(func(a *csAttempt) error {
var err error
m, err = a.s.Header()
if err == transport.ErrNoHeaders {
noHeader = true
return nil
}
return toRPCErr(err)
}, cs.commitAttemptLocked)

if m == nil && err == nil {
// The stream ended with success. Finish the clientStream.
err = io.EOF
}

if err != nil {
cs.finish(err)
return nil, err
}

if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged && !noHeader {
if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged && m != nil {
// Only log if binary log is on and header has not been logged, and
// there is actually headers to log.
logEntry := &binarylog.ServerHeader{
Expand All @@ -821,6 +821,7 @@ func (cs *clientStream) Header() (metadata.MD, error) {
binlog.Log(cs.ctx, logEntry)
}
}

return m, nil
}

Expand Down Expand Up @@ -929,24 +930,6 @@ func (cs *clientStream) RecvMsg(m any) error {
if err != nil || !cs.desc.ServerStreams {
// err != nil or non-server-streaming indicates end of stream.
cs.finish(err)

if len(cs.binlogs) != 0 {
// finish will not log Trailer. Log Trailer here.
logEntry := &binarylog.ServerTrailer{
OnClientSide: true,
Trailer: cs.Trailer(),
Err: err,
}
if logEntry.Err == io.EOF {
logEntry.Err = nil
}
if peer, ok := peer.FromContext(cs.Context()); ok {
logEntry.PeerAddr = peer.Addr
}
for _, binlog := range cs.binlogs {
binlog.Log(cs.ctx, logEntry)
}
}
}
return err
}
Expand Down Expand Up @@ -1002,18 +985,30 @@ func (cs *clientStream) finish(err error) {
}
}
}

cs.mu.Unlock()
// For binary logging. only log cancel in finish (could be caused by RPC ctx
// canceled or ClientConn closed). Trailer will be logged in RecvMsg.
//
// Only one of cancel or trailer needs to be logged. In the cases where
// users don't call RecvMsg, users must have already canceled the RPC.
if len(cs.binlogs) != 0 && status.Code(err) == codes.Canceled {
c := &binarylog.Cancel{
OnClientSide: true,
}
for _, binlog := range cs.binlogs {
binlog.Log(cs.ctx, c)
// Only one of cancel or trailer needs to be logged.
if len(cs.binlogs) != 0 {
switch err {
case errContextCanceled, errContextDeadline, ErrClientConnClosing:
c := &binarylog.Cancel{
OnClientSide: true,
}
for _, binlog := range cs.binlogs {
binlog.Log(cs.ctx, c)
}
default:
logEntry := &binarylog.ServerTrailer{
OnClientSide: true,
Trailer: cs.Trailer(),
Err: err,
}
if peer, ok := peer.FromContext(cs.Context()); ok {
logEntry.PeerAddr = peer.Addr
}
for _, binlog := range cs.binlogs {
binlog.Log(cs.ctx, logEntry)
}
}
}
if err == nil {
Expand Down
9 changes: 4 additions & 5 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6328,12 +6328,11 @@ func (s) TestGlobalBinaryLoggingOptions(t *testing.T) {
return &testpb.SimpleResponse{}, nil
},
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
for {
_, err := stream.Recv()
if err == io.EOF {
return nil
}
_, err := stream.Recv()
if err == io.EOF {
return nil
}
return status.Errorf(codes.Unknown, "expected client to call CloseSend")
},
}

Expand Down
20 changes: 17 additions & 3 deletions test/retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ func (s) TestRetryStreaming(t *testing.T) {
return nil
}
}
sHdr := func() serverOp {
return func(stream testgrpc.TestService_FullDuplexCallServer) error {
return stream.SendHeader(metadata.Pairs("test_header", "test_value"))
}
}
sRes := func(b byte) serverOp {
return func(stream testgrpc.TestService_FullDuplexCallServer) error {
msg := res(b)
Expand All @@ -222,7 +227,7 @@ func (s) TestRetryStreaming(t *testing.T) {
}
sErr := func(c codes.Code) serverOp {
return func(stream testgrpc.TestService_FullDuplexCallServer) error {
return status.New(c, "").Err()
return status.New(c, "this is a test error").Err()
}
}
sCloseSend := func() serverOp {
Expand Down Expand Up @@ -270,7 +275,7 @@ func (s) TestRetryStreaming(t *testing.T) {
}
cErr := func(c codes.Code) clientOp {
return func(stream testgrpc.TestService_FullDuplexCallClient) error {
want := status.New(c, "").Err()
want := status.New(c, "this is a test error").Err()
if c == codes.OK {
want = io.EOF
}
Expand Down Expand Up @@ -309,6 +314,11 @@ func (s) TestRetryStreaming(t *testing.T) {
cHdr := func() clientOp {
return func(stream testgrpc.TestService_FullDuplexCallClient) error {
_, err := stream.Header()
if err == io.EOF {
// The stream ended successfully; convert to nil to avoid
// erroring the test case.
err = nil
}
return err
}
}
Expand Down Expand Up @@ -362,9 +372,13 @@ func (s) TestRetryStreaming(t *testing.T) {
sReq(1), sRes(3), sErr(codes.Unavailable),
},
clientOps: []clientOp{cReq(1), cRes(3), cErr(codes.Unavailable)},
}, {
desc: "Retry via ClientStream.Header()",
serverOps: []serverOp{sReq(1), sErr(codes.Unavailable), sReq(1), sAttempts(1)},
clientOps: []clientOp{cReq(1), cHdr() /* this should cause a retry */, cErr(codes.OK)},
}, {
desc: "No retry after header",
serverOps: []serverOp{sReq(1), sErr(codes.Unavailable)},
serverOps: []serverOp{sReq(1), sHdr(), sErr(codes.Unavailable)},
clientOps: []clientOp{cReq(1), cHdr(), cErr(codes.Unavailable)},
}, {
desc: "No retry after context",
Expand Down

0 comments on commit fe1519e

Please sign in to comment.