diff --git a/internal/status/status.go b/internal/status/status.go index 4cf85cad9f81..d1e11739689f 100644 --- a/internal/status/status.go +++ b/internal/status/status.go @@ -43,9 +43,45 @@ type Status struct { s *spb.Status } +// NewWithProto returns a new status including details from statusProto. This +// is meant to be used by the gRPC library only. +func NewWithProto(code codes.Code, message string, statusProto []string) *Status { + if len(statusProto) != 1 { + // No grpc-status-details bin header, or multiple; just ignore. + return &Status{s: &spb.Status{Code: normalizeCode(code), Message: message}} + } + st := &spb.Status{} + if err := proto.Unmarshal([]byte(statusProto[0]), st); err != nil { + // Probably not a google.rpc.Status proto; do not provide details. + return &Status{s: &spb.Status{Code: normalizeCode(code), Message: message}} + } + if st.Code == int32(code) { + // The codes match between the grpc-status header and the + // grpc-status-details-bin header; use the full details proto. + st.Code = normalizeCode(codes.Code(st.Code)) + return &Status{s: st} + } + return &Status{ + s: &spb.Status{ + Code: int32(codes.Internal), + Message: fmt.Sprintf( + "grpc-status-details-bin mismatch: grpc-status=%v, grpc-message=%q, grpc-status-details-bin=%+v", + code, message, st, + ), + }, + } +} + +func normalizeCode(c codes.Code) int32 { + if c > 16 { + return int32(codes.Unknown) + } + return int32(c) +} + // New returns a Status representing c and msg. func New(c codes.Code, msg string) *Status { - return &Status{s: &spb.Status{Code: int32(c), Message: msg}} + return &Status{s: &spb.Status{Code: normalizeCode(c), Message: msg}} } // Newf returns New(c, fmt.Sprintf(format, a...)). diff --git a/internal/stubserver/stubserver.go b/internal/stubserver/stubserver.go index 39c291500547..a83a42fa92c4 100644 --- a/internal/stubserver/stubserver.go +++ b/internal/stubserver/stubserver.go @@ -27,6 +27,7 @@ import ( "testing" "time" + "golang.org/x/net/http2" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials/insecure" @@ -110,8 +111,7 @@ func RegisterServiceServerOption(f func(*grpc.Server)) grpc.ServerOption { return ®isterServiceServerOption{f: f} } -// StartServer only starts the server. It does not create a client to it. -func (ss *StubServer) StartServer(sopts ...grpc.ServerOption) error { +func (ss *StubServer) setupServer(sopts ...grpc.ServerOption) (net.Listener, error) { if ss.Network == "" { ss.Network = "tcp" } @@ -127,24 +127,59 @@ func (ss *StubServer) StartServer(sopts ...grpc.ServerOption) error { var err error lis, err = net.Listen(ss.Network, ss.Address) if err != nil { - return fmt.Errorf("net.Listen(%q, %q) = %v", ss.Network, ss.Address, err) + return nil, fmt.Errorf("net.Listen(%q, %q) = %v", ss.Network, ss.Address, err) } } ss.Address = lis.Addr().String() - ss.cleanups = append(ss.cleanups, func() { lis.Close() }) - s := grpc.NewServer(sopts...) + ss.S = grpc.NewServer(sopts...) for _, so := range sopts { switch x := so.(type) { case *registerServiceServerOption: - x.f(s) + x.f(ss.S) + } + } + + testgrpc.RegisterTestServiceServer(ss.S, ss) + ss.cleanups = append(ss.cleanups, ss.S.Stop) + return lis, nil +} + +// StartHandlerServer only starts an HTTP server with a gRPC server as the +// handler. It does not create a client to it. Cannot be used in a StubServer +// that also used StartServer. +func (ss *StubServer) StartHandlerServer(sopts ...grpc.ServerOption) error { + lis, err := ss.setupServer(sopts...) + if err != nil { + return err + } + + go func() { + hs := &http2.Server{} + opts := &http2.ServeConnOpts{Handler: ss.S} + for { + conn, err := lis.Accept() + if err != nil { + return + } + hs.ServeConn(conn, opts) } + }() + ss.cleanups = append(ss.cleanups, func() { lis.Close() }) + + return nil +} + +// StartServer only starts the server. It does not create a client to it. +// Cannot be used in a StubServer that also used StartHandlerServer. +func (ss *StubServer) StartServer(sopts ...grpc.ServerOption) error { + lis, err := ss.setupServer(sopts...) + if err != nil { + return err } - testgrpc.RegisterTestServiceServer(s, ss) - go s.Serve(lis) - ss.cleanups = append(ss.cleanups, s.Stop) - ss.S = s + go ss.S.Serve(lis) + return nil } diff --git a/internal/transport/handler_server.go b/internal/transport/handler_server.go index 5e14f934d614..17f7a21b5a9f 100644 --- a/internal/transport/handler_server.go +++ b/internal/transport/handler_server.go @@ -220,18 +220,20 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro h.Set("Grpc-Message", encodeGrpcMessage(m)) } + s.hdrMu.Lock() if p := st.Proto(); p != nil && len(p.Details) > 0 { + delete(s.trailer, grpcStatusDetailsBinHeader) stBytes, err := proto.Marshal(p) if err != nil { // TODO: return error instead, when callers are able to handle it. panic(err) } - h.Set("Grpc-Status-Details-Bin", encodeBinHeader(stBytes)) + h.Set(grpcStatusDetailsBinHeader, encodeBinHeader(stBytes)) } - if md := s.Trailer(); len(md) > 0 { - for k, vv := range md { + if len(s.trailer) > 0 { + for k, vv := range s.trailer { // Clients don't tolerate reading restricted headers after some non restricted ones were sent. if isReservedHeader(k) { continue @@ -243,6 +245,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, st *status.Status) erro } } } + s.hdrMu.Unlock() }) if err == nil { // transport has not been closed @@ -287,7 +290,7 @@ func (ht *serverHandlerTransport) writeCommonHeaders(s *Stream) { } // writeCustomHeaders sets custom headers set on the stream via SetHeader -// on the first write call (Write, WriteHeader, or WriteStatus). +// on the first write call (Write, WriteHeader, or WriteStatus) func (ht *serverHandlerTransport) writeCustomHeaders(s *Stream) { h := ht.rw.Header() diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index badab8acf3b1..d6f5c49358b5 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -1399,7 +1399,6 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { mdata = make(map[string][]string) contentTypeErr = "malformed header: missing HTTP content-type" grpcMessage string - statusGen *status.Status recvCompress string httpStatusCode *int httpStatusErr string @@ -1434,12 +1433,6 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { rawStatusCode = codes.Code(uint32(code)) case "grpc-message": grpcMessage = decodeGrpcMessage(hf.Value) - case "grpc-status-details-bin": - var err error - statusGen, err = decodeGRPCStatusDetails(hf.Value) - if err != nil { - headerError = fmt.Sprintf("transport: malformed grpc-status-details-bin: %v", err) - } case ":status": if hf.Value == "200" { httpStatusErr = "" @@ -1548,14 +1541,12 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { return } - if statusGen == nil { - statusGen = status.New(rawStatusCode, grpcMessage) - } + status := istatus.NewWithProto(rawStatusCode, grpcMessage, mdata[grpcStatusDetailsBinHeader]) // If client received END_STREAM from server while stream was still active, // send RST_STREAM. rstStream := s.getState() == streamActive - t.closeStream(s, io.EOF, rstStream, http2.ErrCodeNo, statusGen, mdata, true) + t.closeStream(s, io.EOF, rstStream, http2.ErrCodeNo, status, mdata, true) } // readServerPreface reads and handles the initial settings frame from the diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index cadc64043f47..2b1ce0148e3c 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -1057,12 +1057,15 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error { headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())}) if p := st.Proto(); p != nil && len(p.Details) > 0 { + // Do not use the user's grpc-status-details-bin (if present) if we are + // even attempting to set our own. + delete(s.trailer, grpcStatusDetailsBinHeader) stBytes, err := proto.Marshal(p) if err != nil { // TODO: return error instead, when callers are able to handle it. t.logger.Errorf("Failed to marshal rpc status: %s, error: %v", pretty.ToJSON(p), err) } else { - headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status-details-bin", Value: encodeBinHeader(stBytes)}) + headerFields = append(headerFields, hpack.HeaderField{Name: grpcStatusDetailsBinHeader, Value: encodeBinHeader(stBytes)}) } } diff --git a/internal/transport/http_util.go b/internal/transport/http_util.go index 1958140082b3..dc29d590e91f 100644 --- a/internal/transport/http_util.go +++ b/internal/transport/http_util.go @@ -34,12 +34,9 @@ import ( "time" "unicode/utf8" - "github.com/golang/protobuf/proto" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" - spb "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) const ( @@ -88,6 +85,8 @@ var ( } ) +var grpcStatusDetailsBinHeader = "grpc-status-details-bin" + // isReservedHeader checks whether hdr belongs to HTTP2 headers // reserved by gRPC protocol. Any other headers are classified as the // user-specified metadata. @@ -103,7 +102,6 @@ func isReservedHeader(hdr string) bool { "grpc-message", "grpc-status", "grpc-timeout", - "grpc-status-details-bin", // Intentionally exclude grpc-previous-rpc-attempts and // grpc-retry-pushback-ms, which are "reserved", but their API // intentionally works via metadata. @@ -154,18 +152,6 @@ func decodeMetadataHeader(k, v string) (string, error) { return v, nil } -func decodeGRPCStatusDetails(rawDetails string) (*status.Status, error) { - v, err := decodeBinHeader(rawDetails) - if err != nil { - return nil, err - } - st := &spb.Status{} - if err = proto.Unmarshal(v, st); err != nil { - return nil, err - } - return status.FromProto(st), nil -} - type timeoutUnit uint8 const ( diff --git a/status/status_ext_test.go b/status/status_ext_test.go index 33c8c71a0062..0d712a69b2ab 100644 --- a/status/status_ext_test.go +++ b/status/status_ext_test.go @@ -19,17 +19,27 @@ package status_test import ( + "context" "errors" + "strings" "testing" + "time" "github.com/golang/protobuf/proto" + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" testpb "google.golang.org/grpc/interop/grpc_testing" ) +const defaultTestTimeout = 10 * time.Second + type s struct { grpctest.Tester } @@ -80,3 +90,157 @@ func (s) TestErrorIs(t *testing.T) { } } } + +// TestStatusDetails tests how gRPC handles grpc-status-details-bin, especially +// in cases where it doesn't match the grpc-status trailer or contains arbitrary +// data. +func (s) TestStatusDetails(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + for _, serverType := range []struct { + name string + startServerFunc func(*stubserver.StubServer) error + }{{ + name: "normal server", + startServerFunc: func(ss *stubserver.StubServer) error { + return ss.StartServer() + }, + }, { + name: "handler server", + startServerFunc: func(ss *stubserver.StubServer) error { + return ss.StartHandlerServer() + }, + }} { + t.Run(serverType.name, func(t *testing.T) { + // Convenience function for making a status including details. + detailErr := func(c codes.Code, m string) error { + s, err := status.New(c, m).WithDetails(&testpb.SimpleRequest{ + Payload: &testpb.Payload{Body: []byte("detail msg")}, + }) + if err != nil { + t.Fatalf("Error adding details: %v", err) + } + return s.Err() + } + + serialize := func(err error) string { + buf, _ := proto.Marshal(status.Convert(err).Proto()) + return string(buf) + } + + testCases := []struct { + name string + trailerSent metadata.MD + errSent error + trailerWant []string + errWant error + errContains error + }{{ + name: "basic without details", + trailerSent: metadata.MD{}, + errSent: status.Error(codes.Aborted, "test msg"), + errWant: status.Error(codes.Aborted, "test msg"), + }, { + name: "basic without details passes through trailers", + trailerSent: metadata.MD{"grpc-status-details-bin": []string{"random text"}}, + errSent: status.Error(codes.Aborted, "test msg"), + trailerWant: []string{"random text"}, + errWant: status.Error(codes.Aborted, "test msg"), + }, { + name: "basic without details conflicts with manual details", + trailerSent: metadata.MD{"grpc-status-details-bin": []string{serialize(status.Error(codes.Canceled, "test msg"))}}, + errSent: status.Error(codes.Aborted, "test msg"), + trailerWant: []string{serialize(status.Error(codes.Canceled, "test msg"))}, + errContains: status.Error(codes.Internal, "mismatch"), + }, { + name: "basic with details", + trailerSent: metadata.MD{}, + errSent: detailErr(codes.Aborted, "test msg"), + trailerWant: []string{serialize(detailErr(codes.Aborted, "test msg"))}, + errWant: detailErr(codes.Aborted, "test msg"), + }, { + name: "basic with details discards user's trailers", + trailerSent: metadata.MD{"grpc-status-details-bin": []string{"will be ignored"}}, + errSent: detailErr(codes.Aborted, "test msg"), + trailerWant: []string{serialize(detailErr(codes.Aborted, "test msg"))}, + errWant: detailErr(codes.Aborted, "test msg"), + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Start a simple server that returns the trailer and error it receives from + // channels. + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + grpc.SetTrailer(ctx, tc.trailerSent) + return nil, tc.errSent + }, + } + if err := serverType.startServerFunc(ss); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + if err := ss.StartClient(); err != nil { + t.Fatalf("Error starting endpoint client: %v", err) + } + defer ss.Stop() + + trailerGot := metadata.MD{} + _, errGot := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}, grpc.Trailer(&trailerGot)) + gsdb := trailerGot["grpc-status-details-bin"] + if !cmp.Equal(gsdb, tc.trailerWant) { + t.Errorf("Trailer got: %v; want: %v", gsdb, tc.trailerWant) + } + if tc.errWant != nil && !testutils.StatusErrEqual(errGot, tc.errWant) { + t.Errorf("Err got: %v; want: %v", errGot, tc.errWant) + } + if tc.errContains != nil && (status.Code(errGot) != status.Code(tc.errContains) || !strings.Contains(status.Convert(errGot).Message(), status.Convert(tc.errContains).Message())) { + t.Errorf("Err got: %v; want: (Contains: %v)", errGot, tc.errWant) + } + }) + } + }) + } +} + +// TestStatusCodeCollapse ensures that status codes are collapsed to UNKNOWN +// when they are out of the valid range. +func (s) TestStatusCodeCollapse(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + for _, serverType := range []struct { + name string + startServerFunc func(*stubserver.StubServer) error + }{{ + name: "normal server", + startServerFunc: func(ss *stubserver.StubServer) error { + return ss.StartServer() + }, + }, { + name: "handler server", + startServerFunc: func(ss *stubserver.StubServer) error { + return ss.StartHandlerServer() + }, + }} { + t.Run(serverType.name, func(t *testing.T) { + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return nil, status.Errorf(codes.Code(23), "test msg") + }, + } + if err := serverType.startServerFunc(ss); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + if err := ss.StartClient(); err != nil { + t.Fatalf("Error starting endpoint client: %v", err) + } + defer ss.Stop() + + errWant := status.Errorf(codes.Unknown, "test msg") + if _, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{}); !testutils.StatusErrEqual(err, errWant) { + t.Errorf("Err got: %v; want: %v", err, errWant) + } + }) + } +} diff --git a/status/status_test.go b/status/status_test.go index d21a862f3637..10243e285a48 100644 --- a/status/status_test.go +++ b/status/status_test.go @@ -268,6 +268,14 @@ func (s) TestCodeUnknownError(t *testing.T) { } } +func (s) TestCodeOutOfRangeBecomesUnknown(t *testing.T) { + const code, codeWant = codes.Code(20), codes.Unknown + err := fmt.Errorf("wrapped: %w", Error(code, "test description")) + if s := Code(err); s != codeWant { + t.Fatalf("Code(%v) = %v; want ", err, s, codeWant) + } +} + func (s) TestCodeWrapped(t *testing.T) { const code = codes.Internal err := fmt.Errorf("wrapped: %w", Error(code, "test description")) diff --git a/test/balancer_test.go b/test/balancer_test.go index 0c71da7146d6..f02a9d93ac46 100644 --- a/test/balancer_test.go +++ b/test/balancer_test.go @@ -246,7 +246,7 @@ func testDoneInfo(t *testing.T, e env) { defer cancel() wantErr := detailedError if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !testutils.StatusErrEqual(err, wantErr) { - t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr) + t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", status.Convert(err).Proto(), status.Convert(wantErr).Proto()) } if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil { t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err)