client: fix ClientStream.Header() behavior #6557

Merged: 5 commits, Aug 18, 2023

Changes from all commits
38 changes: 38 additions & 0 deletions binarylog/binarylog_end2end_test.go
@@ -31,10 +31,12 @@ import (
"github.com/golang/protobuf/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/binarylog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/grpclog"
iblog "google.golang.org/grpc/internal/binarylog"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

@@ -1059,3 +1061,39 @@ func (s) TestServerBinaryLogFullDuplexError(t *testing.T) {
t.Fatal(err)
}
}

// TestCanceledStatus ensures a server that responds with a Canceled status has
// its trailers logged appropriately and is not treated as a canceled RPC.
func (s) TestCanceledStatus(t *testing.T) {
defer testSink.clear()

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

const statusMsgWant = "server returned Canceled"
ss := &stubserver.StubServer{
UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
grpc.SetTrailer(ctx, metadata.Pairs("key", "value"))
return nil, status.Error(codes.Canceled, statusMsgWant)
},
}
if err := ss.Start(nil); err != nil {
t.Fatalf("Error starting endpoint server: %v", err)
}
defer ss.Stop()

if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); status.Code(err) != codes.Canceled {
t.Fatalf("Received unexpected error from UnaryCall: %v; want Canceled", err)
}

got := testSink.logEntries(true)
last := got[len(got)-1]
if last.Type != binlogpb.GrpcLogEntry_EVENT_TYPE_SERVER_TRAILER ||
last.GetTrailer().GetStatusCode() != uint32(codes.Canceled) ||
last.GetTrailer().GetStatusMessage() != statusMsgWant ||
len(last.GetTrailer().GetMetadata().GetEntry()) != 1 ||
last.GetTrailer().GetMetadata().GetEntry()[0].GetKey() != "key" ||
string(last.GetTrailer().GetMetadata().GetEntry()[0].GetValue()) != "value" {
t.Fatalf("Got binary log: %+v; want last entry is server trailing with status Canceled", got)
}
}

Review thread on lines +1091 to +1096 (the trailer assertions above):

Contributor: Just a thought: could you use the eventual proto.Equal call in this helper, as other parts of the test do? func equalLogEntry(entries ...*binlogpb.GrpcLogEntry) (equal bool)

Author (Member): The purpose of this test is to check the trailer log entry. I don't want to have to build and validate everything just for that, which equalLogEntry requires. I think this is fine, and I'm not a big fan of the style of tests in this file that have a lot of innate knowledge about different parts of the code spread throughout.
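For reference, a minimal sketch of the helper the reviewer points to. The signature is quoted from the comment above, but the body here is an assumption for illustration, not the test file's actual implementation (proto is github.com/golang/protobuf/proto, already imported by this file):

    // equalLogEntry (hypothetical body): report whether all given entries
    // are equal to one another using proto.Equal.
    func equalLogEntry(entries ...*binlogpb.GrpcLogEntry) (equal bool) {
        for i := 1; i < len(entries); i++ {
            if !proto.Equal(entries[i-1], entries[i]) {
                return false
            }
        }
        return true
    }

The author's pushback stands either way: proto.Equal requires building a complete expected GrpcLogEntry, while this test only cares about the trailer fields.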
31 changes: 15 additions & 16 deletions internal/transport/http2_client.go
@@ -1505,30 +1505,28 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
return
}

-    isHeader := false
-
-    // If headerChan hasn't been closed yet
-    if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
-        s.headerValid = true
-        if !endStream {
-            // HEADERS frame block carries a Response-Headers.
-            isHeader = true
+    // For headers, set them in s.header and close headerChan. For trailers or
+    // trailers-only, closeStream will set the trailers and close headerChan as
+    // needed.
+    if !endStream {
+        // If headerChan hasn't been closed yet (expected, given we checked it
+        // above, but something else could have potentially closed the whole
+        // stream).
+        if atomic.CompareAndSwapUint32(&s.headerChanClosed, 0, 1) {
+            s.headerValid = true
            // These values can be set without any synchronization because
            // stream goroutine will read it only after seeing a closed
            // headerChan which we'll close after setting this.
            s.recvCompress = recvCompress
            if len(mdata) > 0 {
                s.header = mdata
            }
-        } else {
-            // HEADERS frame block carries a Trailers-Only.
-            s.noHeaders = true
+            close(s.headerChan)
        }
-        close(s.headerChan)
    }

    for _, sh := range t.statsHandlers {
-        if isHeader {
+        if !endStream {
            inHeader := &stats.InHeader{
                Client:     true,
                WireLength: int(frame.Header().Length),

Review comment on the removed "HEADERS frame block carries a Response-Headers." line:

Contributor: Optional: keep this comment?

Review thread on the new "if !endStream" block:

Contributor: Optional (use some/all/none): mention this is the headers flow, and that for the trailers flow, the header chan gets closed in closeStream(), which writes the status before unblocking the header chan, letting the client stream read the status after it blocks on the header chan (in csAttempt).

Author (Member): Added a comment.

@@ -1554,9 +1552,10 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
statusGen = status.New(rawStatusCode, grpcMessage)
}

-    // if client received END_STREAM from server while stream was still active, send RST_STREAM
-    rst := s.getState() == streamActive
-    t.closeStream(s, io.EOF, rst, http2.ErrCodeNo, statusGen, mdata, true)
+    // If client received END_STREAM from server while stream was still active,
+    // send RST_STREAM.
+    rstStream := s.getState() == streamActive
+    t.closeStream(s, io.EOF, rstStream, http2.ErrCodeNo, statusGen, mdata, true)
}

// readServerPreface reads and handles the initial settings frame from the
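The ordering described in that review thread leans on a standard Go guarantee: a close(ch) happens before any receive that observes the closed channel, so values written before the close are visible afterwards. A self-contained sketch of the pattern (illustrative only, not the transport's actual types):

    package main

    import "fmt"

    func main() {
        type stream struct {
            status     string
            headerChan chan struct{}
        }
        s := &stream{headerChan: make(chan struct{})}
        go func() {
            s.status = "Canceled" // closeStream analog: record the status first...
            close(s.headerChan)   // ...then unblock readers of the header channel
        }()
        <-s.headerChan        // Header()/waitOnHeader analog
        fmt.Println(s.status) // guaranteed to observe "Canceled"
    }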
10 changes: 1 addition & 9 deletions internal/transport/transport.go
@@ -43,10 +43,6 @@ import (
"google.golang.org/grpc/tap"
)

-// ErrNoHeaders is used as a signal that a trailers only response was received,
-// and is not a real error.
-var ErrNoHeaders = errors.New("stream has no headers")

const logLevel = 2

type bufferPool struct {
@@ -390,14 +386,10 @@ func (s *Stream) Header() (metadata.MD, error) {
}
s.waitOnHeader()

-    if !s.headerValid {
+    if !s.headerValid || s.noHeaders {
        return nil, s.status.Err()
    }

-    if s.noHeaders {
-        return nil, ErrNoHeaders
-    }
-
return s.header.Copy(), nil
}

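With ErrNoHeaders gone, a trailers-only response surfaces through the same return as any stream that ended without headers: nil metadata plus s.status.Err(), which is nil when the status is OK. A hedged caller-side sketch of the three outcomes (the helper name and classification are illustrative, not part of the API):

    // classifyHeader interprets the (md, err) pair from Stream.Header()
    // under the new behavior.
    func classifyHeader(md metadata.MD, err error) string {
        switch {
        case err != nil:
            return "no headers; err carries the stream's final status"
        case md == nil:
            return "trailers-only with an OK status: no headers, no error"
        default:
            return "headers received"
        }
    }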
7 changes: 5 additions & 2 deletions rpc_util.go
@@ -867,15 +867,18 @@ func Errorf(c codes.Code, format string, a ...any) error {
return status.Errorf(c, format, a...)
}

+var errContextCanceled = status.Error(codes.Canceled, context.Canceled.Error())
+var errContextDeadline = status.Error(codes.DeadlineExceeded, context.DeadlineExceeded.Error())

// toRPCErr converts an error into an error from the status package.
func toRPCErr(err error) error {
switch err {
case nil, io.EOF:
return err
    case context.DeadlineExceeded:
-        return status.Error(codes.DeadlineExceeded, err.Error())
+        return errContextDeadline
    case context.Canceled:
-        return status.Error(codes.Canceled, err.Error())
+        return errContextCanceled
case io.ErrUnexpectedEOF:
return status.Error(codes.Internal, err.Error())
}
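Hoisting these into package-level sentinels is what makes the identity comparison in clientStream.finish (below) possible: an error that toRPCErr produced for local context expiry is the same value every time, while a server-returned status carrying the same code is a different value. A small sketch, assuming these exact declarations:

    // isLocalExpiry (hypothetical helper): true only for the sentinels
    // toRPCErr returns when the client's own context ends.
    func isLocalExpiry(err error) bool {
        return err == errContextCanceled || err == errContextDeadline
    }

    // toRPCErr(context.Canceled) == errContextCanceled                    -> true
    // status.Error(codes.Canceled, "sent by server") == errContextCanceled -> false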
65 changes: 30 additions & 35 deletions stream.go
@@ -789,23 +789,23 @@ func (cs *clientStream) withRetry(op func(a *csAttempt) error, onSuccess func())

func (cs *clientStream) Header() (metadata.MD, error) {
var m metadata.MD
-    noHeader := false
    err := cs.withRetry(func(a *csAttempt) error {
        var err error
        m, err = a.s.Header()
-        if err == transport.ErrNoHeaders {
-            noHeader = true
-            return nil
-        }
        return toRPCErr(err)
    }, cs.commitAttemptLocked)

+    if m == nil && err == nil {
+        // The stream ended with success. Finish the clientStream.
+        err = io.EOF
+    }
+
    if err != nil {
        cs.finish(err)
        return nil, err
    }

-    if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged && !noHeader {
+    if len(cs.binlogs) != 0 && !cs.serverHeaderBinlogged && m != nil {
// Only log if binary log is on and header has not been logged, and
// there is actually headers to log.
logEntry := &binarylog.ServerHeader{
@@ -821,6 +821,7 @@ func (cs *clientStream) Header() (metadata.MD, error) {
binlog.Log(cs.ctx, logEntry)
}
}

return m, nil
}
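Caller-visible effect of the changes above, sketched under the new behavior: when the RPC has already ended with an OK status and no headers were received (for example a trailers-only response), Header() now finishes the stream itself and reports io.EOF, which the retry test below converts back to nil:

    md, err := stream.Header()
    switch {
    case err == io.EOF:
        // The RPC already completed successfully with no headers.
    case err != nil:
        // The RPC failed; err is its status error.
    default:
        _ = md // headers arrived normally
    }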

Expand Down Expand Up @@ -929,24 +930,6 @@ func (cs *clientStream) RecvMsg(m any) error {
if err != nil || !cs.desc.ServerStreams {
// err != nil or non-server-streaming indicates end of stream.
cs.finish(err)

-        if len(cs.binlogs) != 0 {
-            // finish will not log Trailer. Log Trailer here.
-            logEntry := &binarylog.ServerTrailer{
-                OnClientSide: true,
-                Trailer:      cs.Trailer(),
-                Err:          err,
-            }
-            if logEntry.Err == io.EOF {
-                logEntry.Err = nil
-            }
-            if peer, ok := peer.FromContext(cs.Context()); ok {
-                logEntry.PeerAddr = peer.Addr
-            }
-            for _, binlog := range cs.binlogs {
-                binlog.Log(cs.ctx, logEntry)
-            }
-        }
}
return err
}
@@ -1002,18 +985,30 @@ func (cs *clientStream) finish(err error) {
}
}
}

cs.mu.Unlock()
-    // For binary logging. only log cancel in finish (could be caused by RPC ctx
-    // canceled or ClientConn closed). Trailer will be logged in RecvMsg.
-    //
-    // Only one of cancel or trailer needs to be logged. In the cases where
-    // users don't call RecvMsg, users must have already canceled the RPC.
-    if len(cs.binlogs) != 0 && status.Code(err) == codes.Canceled {
-        c := &binarylog.Cancel{
-            OnClientSide: true,
-        }
-        for _, binlog := range cs.binlogs {
-            binlog.Log(cs.ctx, c)
-        }
-    }
+    // Only one of cancel or trailer needs to be logged.
+    if len(cs.binlogs) != 0 {
+        switch err {
+        case errContextCanceled, errContextDeadline, ErrClientConnClosing:
+            c := &binarylog.Cancel{
+                OnClientSide: true,
+            }
+            for _, binlog := range cs.binlogs {
+                binlog.Log(cs.ctx, c)
+            }
+        default:
+            logEntry := &binarylog.ServerTrailer{
+                OnClientSide: true,
+                Trailer:      cs.Trailer(),
+                Err:          err,
+            }
+            if peer, ok := peer.FromContext(cs.Context()); ok {
+                logEntry.PeerAddr = peer.Addr
+            }
+            for _, binlog := range cs.binlogs {
+                binlog.Log(cs.ctx, logEntry)
+            }
+        }
+    }
if err == nil {
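Note the distinction this switch encodes: only errors identical to the local sentinels (context cancellation, deadline expiry, or a closing ClientConn) produce a Cancel log entry; any other terminal error, including a server-returned status with codes.Canceled, falls through to the ServerTrailer entry, which is exactly what TestCanceledStatus above asserts.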
9 changes: 4 additions & 5 deletions test/end2end_test.go
@@ -6339,12 +6339,11 @@ func (s) TestGlobalBinaryLoggingOptions(t *testing.T) {
return &testpb.SimpleResponse{}, nil
},
FullDuplexCallF: func(stream testgrpc.TestService_FullDuplexCallServer) error {
-        for {
-            _, err := stream.Recv()
-            if err == io.EOF {
-                return nil
-            }
+        _, err := stream.Recv()
+        if err == io.EOF {
+            return nil
        }
+        return status.Errorf(codes.Unknown, "expected client to call CloseSend")
},
}

Expand Down
20 changes: 17 additions & 3 deletions test/retry_test.go
@@ -211,6 +211,11 @@ func (s) TestRetryStreaming(t *testing.T) {
return nil
}
}
+    sHdr := func() serverOp {
+        return func(stream testgrpc.TestService_FullDuplexCallServer) error {
+            return stream.SendHeader(metadata.Pairs("test_header", "test_value"))
+        }
+    }
sRes := func(b byte) serverOp {
return func(stream testgrpc.TestService_FullDuplexCallServer) error {
msg := res(b)
@@ -222,7 +227,7 @@
}
sErr := func(c codes.Code) serverOp {
return func(stream testgrpc.TestService_FullDuplexCallServer) error {
return status.New(c, "").Err()
return status.New(c, "this is a test error").Err()
}
}
sCloseSend := func() serverOp {
@@ -270,7 +275,7 @@
}
cErr := func(c codes.Code) clientOp {
return func(stream testgrpc.TestService_FullDuplexCallClient) error {
want := status.New(c, "").Err()
want := status.New(c, "this is a test error").Err()
if c == codes.OK {
want = io.EOF
}
@@ -309,6 +314,11 @@
cHdr := func() clientOp {
return func(stream testgrpc.TestService_FullDuplexCallClient) error {
_, err := stream.Header()
+        if err == io.EOF {
+            // The stream ended successfully; convert to nil to avoid
+            // erroring the test case.
+            err = nil
+        }
return err
}
}
@@ -362,9 +372,13 @@
sReq(1), sRes(3), sErr(codes.Unavailable),
},
clientOps: []clientOp{cReq(1), cRes(3), cErr(codes.Unavailable)},
+    }, {
+        desc:      "Retry via ClientStream.Header()",
+        serverOps: []serverOp{sReq(1), sErr(codes.Unavailable), sReq(1), sAttempts(1)},
+        clientOps: []clientOp{cReq(1), cHdr() /* this should cause a retry */, cErr(codes.OK)},
}, {
desc: "No retry after header",
-        serverOps: []serverOp{sReq(1), sErr(codes.Unavailable)},
+        serverOps: []serverOp{sReq(1), sHdr(), sErr(codes.Unavailable)},
clientOps: []clientOp{cReq(1), cHdr(), cErr(codes.Unavailable)},
}, {
desc: "No retry after context",