From 0c1d39df28cc4e0369b185df2b33b1094cd543fb Mon Sep 17 00:00:00 2001 From: dfawley Date: Fri, 7 Apr 2017 11:54:56 -0700 Subject: [PATCH] Separate incoming and outgoing metadata in context This will prevent the incoming RPCs' metadata from appearing in outgoing RPCs unless it is explicitly copied, e.g.: incomingMD, ok := metadata.FromContext(ctx) if ok { ctx = metadata.NewContext(ctx, incomingMD) } Fixes #1148 --- Documentation/grpc-metadata.md | 13 ++- grpclb/grpclb_test.go | 2 +- interop/test_utils.go | 10 +- metadata/metadata.go | 38 +++++-- stats/stats_test.go | 8 +- test/end2end_test.go | 201 ++++++++++++++++++++++++++++++--- transport/handler_server.go | 2 +- transport/http2_client.go | 2 +- transport/http2_server.go | 2 +- 9 files changed, 235 insertions(+), 43 deletions(-) diff --git a/Documentation/grpc-metadata.md b/Documentation/grpc-metadata.md index f36ef72af689..c26435269e65 100644 --- a/Documentation/grpc-metadata.md +++ b/Documentation/grpc-metadata.md @@ -66,11 +66,11 @@ md := metadata.Pairs( ## Retrieving metadata from context -Metadata can be retrieved from context using `FromContext`: +Metadata can be retrieved from context using `FromIncomingContext`: ```go func (s *server) SomeRPC(ctx context.Context, in *pb.SomeRequest) (*pb.SomeResponse, err) { - md, ok := metadata.FromContext(ctx) + md, ok := metadata.FromIncomingContext(ctx) // do something with metadata } ``` @@ -88,7 +88,7 @@ To send metadata to server, the client can wrap the metadata into a context usin md := metadata.Pairs("key", "val") // create a new context with this metadata -ctx := metadata.NewContext(context.Background(), md) +ctx := metadata.NewOutgoingContext(context.Background(), md) // make unary RPC response, err := client.SomeRPC(ctx, someRequest) @@ -96,6 +96,9 @@ response, err := client.SomeRPC(ctx, someRequest) // or make streaming RPC stream, err := client.SomeStreamingRPC(ctx) ``` + +To read this back from the context on the client (e.g. in an interceptor) before the RPC is sent, use `FromOutgoingContext`. + ### Receiving metadata Metadata that a client can receive includes header and trailer. @@ -152,7 +155,7 @@ For streaming calls, the server needs to get context from the stream. ```go func (s *server) SomeRPC(ctx context.Context, in *pb.someRequest) (*pb.someResponse, error) { - md, ok := metadata.FromContext(ctx) + md, ok := metadata.FromIncomingContext(ctx) // do something with metadata } ``` @@ -161,7 +164,7 @@ func (s *server) SomeRPC(ctx context.Context, in *pb.someRequest) (*pb.someRespo ```go func (s *server) SomeStreamingRPC(stream pb.Service_SomeStreamingRPCServer) error { - md, ok := metadata.FromContext(stream.Context()) // get context from stream + md, ok := metadata.FromIncomingContext(stream.Context()) // get context from stream // do something with metadata } ``` diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go index 2a62a87f5264..ba7824c2e6c9 100644 --- a/grpclb/grpclb_test.go +++ b/grpclb/grpclb_test.go @@ -215,7 +215,7 @@ type helloServer struct { } func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwpb.HelloReply, error) { - md, ok := metadata.FromContext(ctx) + md, ok := metadata.FromIncomingContext(ctx) if !ok { return nil, grpc.Errorf(codes.Internal, "failed to receive metadata") } diff --git a/interop/test_utils.go b/interop/test_utils.go index e4e427c75e5a..15ec00839565 100644 --- a/interop/test_utils.go +++ b/interop/test_utils.go @@ -392,7 +392,7 @@ func DoPerRPCCreds(tc testpb.TestServiceClient, serviceAccountKeyFile, oauthScop } token := GetToken(serviceAccountKeyFile, oauthScope) kv := map[string]string{"authorization": token.TokenType + " " + token.AccessToken} - ctx := metadata.NewContext(context.Background(), metadata.MD{"authorization": []string{kv["authorization"]}}) + ctx := metadata.NewOutgoingContext(context.Background(), metadata.MD{"authorization": []string{kv["authorization"]}}) reply, err := tc.UnaryCall(ctx, req) if err != nil { grpclog.Fatal("/TestService/UnaryCall RPC failed: ", err) @@ -416,7 +416,7 @@ var ( // DoCancelAfterBegin cancels the RPC after metadata has been sent but before payloads are sent. func DoCancelAfterBegin(tc testpb.TestServiceClient, args ...grpc.CallOption) { - ctx, cancel := context.WithCancel(metadata.NewContext(context.Background(), testMetadata)) + ctx, cancel := context.WithCancel(metadata.NewOutgoingContext(context.Background(), testMetadata)) stream, err := tc.StreamingInputCall(ctx, args...) if err != nil { grpclog.Fatalf("%v.StreamingInputCall(_) = _, %v", tc, err) @@ -491,7 +491,7 @@ func DoCustomMetadata(tc testpb.TestServiceClient, args ...grpc.CallOption) { ResponseSize: proto.Int32(int32(1)), Payload: pl, } - ctx := metadata.NewContext(context.Background(), customMetadata) + ctx := metadata.NewOutgoingContext(context.Background(), customMetadata) var header, trailer metadata.MD args = append(args, grpc.Header(&header), grpc.Trailer(&trailer)) reply, err := tc.UnaryCall( @@ -627,7 +627,7 @@ func serverNewPayload(t testpb.PayloadType, size int32) (*testpb.Payload, error) func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { status := in.GetResponseStatus() - if md, ok := metadata.FromContext(ctx); ok { + if md, ok := metadata.FromIncomingContext(ctx); ok { if initialMetadata, ok := md[initialMetadataKey]; ok { header := metadata.Pairs(initialMetadataKey, initialMetadata[0]) grpc.SendHeader(ctx, header) @@ -686,7 +686,7 @@ func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInput } func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { - if md, ok := metadata.FromContext(stream.Context()); ok { + if md, ok := metadata.FromIncomingContext(stream.Context()); ok { if initialMetadata, ok := md[initialMetadataKey]; ok { header := metadata.Pairs(initialMetadataKey, initialMetadata[0]) stream.SendHeader(header) diff --git a/metadata/metadata.go b/metadata/metadata.go index 733239502865..7ca44182612f 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -136,17 +136,41 @@ func Join(mds ...MD) MD { return out } -type mdKey struct{} +type mdIncomingKey struct{} +type mdOutgoingKey struct{} -// NewContext creates a new context with md attached. +// NewContext is a wrapper for NewOutgoingContext(ctx, md). Deprecated. func NewContext(ctx context.Context, md MD) context.Context { - return context.WithValue(ctx, mdKey{}, md) + return NewOutgoingContext(ctx, md) } -// FromContext returns the MD in ctx if it exists. -// The returned md should be immutable, writing to it may cause races. -// Modification should be made to the copies of the returned md. +// NewIncomingContext creates a new context with incoming md attached. +func NewIncomingContext(ctx context.Context, md MD) context.Context { + return context.WithValue(ctx, mdIncomingKey{}, md) +} + +// NewOutgoingContext creates a new context with outgoing md attached. +func NewOutgoingContext(ctx context.Context, md MD) context.Context { + return context.WithValue(ctx, mdOutgoingKey{}, md) +} + +// FromContext is a wrapper for FromIncomingContext(ctx). Deprecated. func FromContext(ctx context.Context) (md MD, ok bool) { - md, ok = ctx.Value(mdKey{}).(MD) + return FromIncomingContext(ctx) +} + +// FromIncomingContext returns the incoming MD in ctx if it exists. The +// returned md should be immutable, writing to it may cause races. +// Modification should be made to the copies of the returned md. +func FromIncomingContext(ctx context.Context) (md MD, ok bool) { + md, ok = ctx.Value(mdIncomingKey{}).(MD) + return +} + +// FromOutgoingContext returns the outgoing MD in ctx if it exists. The +// returned md should be immutable, writing to it may cause races. +// Modification should be made to the copies of the returned md. +func FromOutgoingContext(ctx context.Context) (md MD, ok bool) { + md, ok = ctx.Value(mdOutgoingKey{}).(MD) return } diff --git a/stats/stats_test.go b/stats/stats_test.go index 3e5424beb4b7..6121b432b1a2 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -75,7 +75,7 @@ var ( type testServer struct{} func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { - md, ok := metadata.FromContext(ctx) + md, ok := metadata.FromIncomingContext(ctx) if ok { if err := grpc.SendHeader(ctx, md); err != nil { return nil, grpc.Errorf(grpc.Code(err), "grpc.SendHeader(_, %v) = %v, want ", md, err) @@ -93,7 +93,7 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* } func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { - md, ok := metadata.FromContext(stream.Context()) + md, ok := metadata.FromIncomingContext(stream.Context()) if ok { if err := stream.SendHeader(md); err != nil { return grpc.Errorf(grpc.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) @@ -237,7 +237,7 @@ func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.Simple } else { req = &testpb.SimpleRequest{Id: errorID} } - ctx := metadata.NewContext(context.Background(), testMetadata) + ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) resp, err = tc.UnaryCall(ctx, req, grpc.FailFast(c.failfast)) return req, resp, err @@ -250,7 +250,7 @@ func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest err error ) tc := testpb.NewTestServiceClient(te.clientConn()) - stream, err := tc.FullDuplexCall(metadata.NewContext(context.Background(), testMetadata), grpc.FailFast(c.failfast)) + stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.FailFast(c.failfast)) if err != nil { return reqs, resps, err } diff --git a/test/end2end_test.go b/test/end2end_test.go index ae053d5df77e..3acfaea827b2 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -118,7 +118,7 @@ type testServer struct { } func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { - if md, ok := metadata.FromContext(ctx); ok { + if md, ok := metadata.FromIncomingContext(ctx); ok { // For testing purpose, returns an error if user-agent is failAppUA. // To test that client gets the correct error. if ua, ok := md["user-agent"]; !ok || strings.HasPrefix(ua[0], failAppUA) { @@ -152,7 +152,7 @@ func newPayload(t testpb.PayloadType, size int32) (*testpb.Payload, error) { } func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { - md, ok := metadata.FromContext(ctx) + md, ok := metadata.FromIncomingContext(ctx) if ok { if _, exists := md[":authority"]; !exists { return nil, grpc.Errorf(codes.DataLoss, "expected an :authority metadata: %v", md) @@ -223,7 +223,7 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* } func (s *testServer) StreamingOutputCall(args *testpb.StreamingOutputCallRequest, stream testpb.TestService_StreamingOutputCallServer) error { - if md, ok := metadata.FromContext(stream.Context()); ok { + if md, ok := metadata.FromIncomingContext(stream.Context()); ok { if _, exists := md[":authority"]; !exists { return grpc.Errorf(codes.DataLoss, "expected an :authority metadata: %v", md) } @@ -274,7 +274,7 @@ func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInput } func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { - md, ok := metadata.FromContext(stream.Context()) + md, ok := metadata.FromIncomingContext(stream.Context()) if ok { if s.setAndSendHeader { if err := stream.SetHeader(md); err != nil { @@ -1385,7 +1385,7 @@ func testFailedEmptyUnary(t *testing.T, e env) { defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) - ctx := metadata.NewContext(context.Background(), testMetadata) + ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) wantErr := detailedError if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(err, wantErr) { t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr) @@ -1602,7 +1602,7 @@ func testMetadataUnaryRPC(t *testing.T, e env) { Payload: payload, } var header, trailer metadata.MD - ctx := metadata.NewContext(context.Background(), testMetadata) + ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.Trailer(&trailer)); err != nil { t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) } @@ -1648,7 +1648,7 @@ func testMultipleSetTrailerUnaryRPC(t *testing.T, e env) { Payload: payload, } var trailer metadata.MD - ctx := metadata.NewContext(context.Background(), testMetadata) + ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) if _, err := tc.UnaryCall(ctx, req, grpc.Trailer(&trailer), grpc.FailFast(false)); err != nil { t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) } @@ -1671,7 +1671,7 @@ func testMultipleSetTrailerStreamingRPC(t *testing.T, e env) { defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) - ctx := metadata.NewContext(context.Background(), testMetadata) + ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) stream, err := tc.FullDuplexCall(ctx, grpc.FailFast(false)) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) @@ -1722,7 +1722,7 @@ func testSetAndSendHeaderUnaryRPC(t *testing.T, e env) { Payload: payload, } var header metadata.MD - ctx := metadata.NewContext(context.Background(), testMetadata) + ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err != nil { t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) } @@ -1766,7 +1766,7 @@ func testMultipleSetHeaderUnaryRPC(t *testing.T, e env) { } var header metadata.MD - ctx := metadata.NewContext(context.Background(), testMetadata) + ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err != nil { t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) } @@ -1809,7 +1809,7 @@ func testMultipleSetHeaderUnaryRPCError(t *testing.T, e env) { Payload: payload, } var header metadata.MD - ctx := metadata.NewContext(context.Background(), testMetadata) + ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) if _, err := tc.UnaryCall(ctx, req, grpc.Header(&header), grpc.FailFast(false)); err == nil { t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, ", ctx, err) } @@ -1841,7 +1841,7 @@ func testSetAndSendHeaderStreamingRPC(t *testing.T, e env) { argSize = 1 respSize = 1 ) - ctx := metadata.NewContext(context.Background(), testMetadata) + ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) stream, err := tc.FullDuplexCall(ctx) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) @@ -1885,7 +1885,7 @@ func testMultipleSetHeaderStreamingRPC(t *testing.T, e env) { argSize = 1 respSize = 1 ) - ctx := metadata.NewContext(context.Background(), testMetadata) + ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) stream, err := tc.FullDuplexCall(ctx) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) @@ -1949,7 +1949,7 @@ func testMultipleSetHeaderStreamingRPCError(t *testing.T, e env) { argSize = 1 respSize = -1 ) - ctx := metadata.NewContext(context.Background(), testMetadata) + ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) stream, err := tc.FullDuplexCall(ctx) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) @@ -2014,7 +2014,7 @@ func testMalformedHTTP2Metadata(t *testing.T, e env) { ResponseSize: proto.Int32(314), Payload: payload, } - ctx := metadata.NewContext(context.Background(), malformedHTTP2Metadata) + ctx := metadata.NewOutgoingContext(context.Background(), malformedHTTP2Metadata) if _, err := tc.UnaryCall(ctx, req); grpc.Code(err) != codes.Internal { t.Fatalf("TestService.UnaryCall(%v, _) = _, %v; want _, %s", ctx, err, codes.Internal) } @@ -2344,7 +2344,7 @@ func testMetadataStreamingRPC(t *testing.T, e env) { defer te.tearDown() tc := testpb.NewTestServiceClient(te.clientConn()) - ctx := metadata.NewContext(te.ctx, testMetadata) + ctx := metadata.NewOutgoingContext(te.ctx, testMetadata) stream, err := tc.FullDuplexCall(ctx) if err != nil { t.Fatalf("%v.FullDuplexCall(_) = _, %v, want ", tc, err) @@ -2483,7 +2483,7 @@ func testFailedServerStreaming(t *testing.T, e env) { ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(), ResponseParameters: respParam, } - ctx := metadata.NewContext(te.ctx, testMetadata) + ctx := metadata.NewOutgoingContext(te.ctx, testMetadata) stream, err := tc.StreamingOutputCall(ctx, req) if err != nil { t.Fatalf("%v.StreamingOutputCall(_) = _, %v, want ", tc, err) @@ -2887,7 +2887,7 @@ func testCompressOK(t *testing.T, e env) { ResponseSize: proto.Int32(respSize), Payload: payload, } - ctx := metadata.NewContext(context.Background(), metadata.Pairs("something", "something")) + ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("something", "something")) if _, err := tc.UnaryCall(ctx, req); err != nil { t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, ", err) } @@ -3679,3 +3679,168 @@ func (fw *filterWriter) Write(p []byte) (n int, err error) { } return fw.dst.Write(p) } + +// stubServer is a server that is easy to customize within individual test +// cases. +type stubServer struct { + // Guarantees we satisfy this interface; panics if unimplemented methods are called. + testpb.TestServiceServer + + // Customizable implementations of server handlers. + emptyCall func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) + fullDuplexCall func(stream testpb.TestService_FullDuplexCallServer) error + + // A client connected to this service the test may use. Created in Start(). + client testpb.TestServiceClient + + cleanups []func() // Lambdas executed in Stop(); populated by Start(). +} + +func (ss *stubServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + return ss.emptyCall(ctx, in) +} + +func (ss *stubServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { + return ss.fullDuplexCall(stream) +} + +// Start starts the server and creates a client connected to it. +func (ss *stubServer) Start() error { + lis, err := net.Listen("tcp", ":0") + if err != nil { + return fmt.Errorf(`net.Listen("tcp", ":0") = %v`, err) + } + ss.cleanups = append(ss.cleanups, func() { lis.Close() }) + + s := grpc.NewServer() + testpb.RegisterTestServiceServer(s, ss) + go s.Serve(lis) + ss.cleanups = append(ss.cleanups, s.Stop) + + cc, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithBlock()) + if err != nil { + return fmt.Errorf("grpc.Dial(%q) = %v", lis.Addr().String(), err) + } + ss.cleanups = append(ss.cleanups, func() { cc.Close() }) + + ss.client = testpb.NewTestServiceClient(cc) + return nil +} + +func (ss *stubServer) Stop() { + for i := len(ss.cleanups) - 1; i >= 0; i-- { + ss.cleanups[i]() + } +} + +func TestUnaryProxyDoesNotForwardMetadata(t *testing.T) { + const mdkey = "somedata" + + // endpoint ensures mdkey is NOT in metadata and returns an error if it is. + endpoint := &stubServer{ + emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + if md, ok := metadata.FromIncomingContext(ctx); !ok || md[mdkey] != nil { + return nil, status.Errorf(codes.Internal, "endpoint: md=%v; want !contains(%q)", md, mdkey) + } + return &testpb.Empty{}, nil + }, + } + if err := endpoint.Start(); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer endpoint.Stop() + + // proxy ensures mdkey IS in metadata, then forwards the RPC to endpoint + // without explicitly copying the metadata. + proxy := &stubServer{ + emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + if md, ok := metadata.FromIncomingContext(ctx); !ok || md[mdkey] == nil { + return nil, status.Errorf(codes.Internal, "proxy: md=%v; want contains(%q)", md, mdkey) + } + return endpoint.client.EmptyCall(ctx, in) + }, + } + if err := proxy.Start(); err != nil { + t.Fatalf("Error starting proxy server: %v", err) + } + defer proxy.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + md := metadata.Pairs(mdkey, "val") + ctx = metadata.NewOutgoingContext(ctx, md) + + // Sanity check that endpoint properly errors when it sees mdkey. + _, err := endpoint.client.EmptyCall(ctx, &testpb.Empty{}) + if s, ok := status.FromError(err); !ok || s.Code() != codes.Internal { + t.Fatalf("endpoint.client.EmptyCall(_, _) = _, %v; want _, ", err) + } + + if _, err := proxy.client.EmptyCall(ctx, &testpb.Empty{}); err != nil { + t.Fatal(err.Error()) + } +} + +func TestStreamingProxyDoesNotForwardMetadata(t *testing.T) { + const mdkey = "somedata" + + // doFDC performs a FullDuplexCall with client and returns the error from the + // first stream.Recv call, or nil if that error is io.EOF. Calls t.Fatal if + // the stream cannot be established. + doFDC := func(ctx context.Context, client testpb.TestServiceClient) error { + stream, err := client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("Unwanted error: %v", err) + } + if _, err := stream.Recv(); err != io.EOF { + return err + } + return nil + } + + // endpoint ensures mdkey is NOT in metadata and returns an error if it is. + endpoint := &stubServer{ + fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error { + ctx := stream.Context() + if md, ok := metadata.FromIncomingContext(ctx); !ok || md[mdkey] != nil { + return status.Errorf(codes.Internal, "endpoint: md=%v; want !contains(%q)", md, mdkey) + } + return nil + }, + } + if err := endpoint.Start(); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer endpoint.Stop() + + // proxy ensures mdkey IS in metadata, then forwards the RPC to endpoint + // without explicitly copying the metadata. + proxy := &stubServer{ + fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error { + ctx := stream.Context() + if md, ok := metadata.FromIncomingContext(ctx); !ok || md[mdkey] == nil { + return status.Errorf(codes.Internal, "endpoint: md=%v; want !contains(%q)", md, mdkey) + } + return doFDC(ctx, endpoint.client) + }, + } + if err := proxy.Start(); err != nil { + t.Fatalf("Error starting proxy server: %v", err) + } + defer proxy.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + md := metadata.Pairs(mdkey, "val") + ctx = metadata.NewOutgoingContext(ctx, md) + + // Sanity check that endpoint properly errors when it sees mdkey in ctx. + err := doFDC(ctx, endpoint.client) + if s, ok := status.FromError(err); !ok || s.Code() != codes.Internal { + t.Fatalf("stream.Recv() = _, %v; want _, ", err) + } + + if err := doFDC(ctx, proxy.client); err != nil { + t.Fatalf("doFDC(_, proxy.client) = %v; want nil", err) + } +} diff --git a/transport/handler_server.go b/transport/handler_server.go index e1c43f68ef7e..28c9ce03658a 100644 --- a/transport/handler_server.go +++ b/transport/handler_server.go @@ -319,7 +319,7 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream), trace if req.TLS != nil { pr.AuthInfo = credentials.TLSInfo{State: *req.TLS} } - ctx = metadata.NewContext(ctx, ht.headerMD) + ctx = metadata.NewIncomingContext(ctx, ht.headerMD) ctx = peer.NewContext(ctx, pr) s.ctx = newContextWithStream(ctx, s) s.dec = &recvBufferReader{ctx: s.ctx, recv: s.buf} diff --git a/transport/http2_client.go b/transport/http2_client.go index 7d7269890b94..5fc6b75f119a 100644 --- a/transport/http2_client.go +++ b/transport/http2_client.go @@ -432,7 +432,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea hasMD bool endHeaders bool ) - if md, ok := metadata.FromContext(ctx); ok { + if md, ok := metadata.FromOutgoingContext(ctx); ok { hasMD = true for k, v := range md { // HTTP doesn't allow you to set pseudoheaders after non pseudoheaders were set. diff --git a/transport/http2_server.go b/transport/http2_server.go index db72e9403a52..31fefc7bb7cc 100644 --- a/transport/http2_server.go +++ b/transport/http2_server.go @@ -261,7 +261,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func( s.ctx = newContextWithStream(s.ctx, s) // Attach the received metadata to the context. if len(state.mdata) > 0 { - s.ctx = metadata.NewContext(s.ctx, state.mdata) + s.ctx = metadata.NewIncomingContext(s.ctx, state.mdata) } s.dec = &recvBufferReader{