diff --git a/proxy/handler.go b/proxy/handler.go index bacf7d8..f66abef 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -9,6 +9,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/transport" + "golang.org/x/net/context" ) var ( @@ -64,21 +65,28 @@ func (s *handler) handler(srv interface{}, serverStream grpc.ServerStream) error return grpc.Errorf(codes.Internal, "lowLevelServerStream not exists in context") } fullMethodName := lowLevelServerStream.Method() + clientCtx, clientCancel := context.WithCancel(serverStream.Context()) backendConn, err := s.director(serverStream.Context(), fullMethodName) if err != nil { return err } // TODO(mwitkow): Add a `forwarded` header to metadata, https://en.wikipedia.org/wiki/X-Forwarded-For. - clientStream, err := grpc.NewClientStream(serverStream.Context(), clientStreamDescForProxying, backendConn, fullMethodName) + clientStream, err := grpc.NewClientStream(clientCtx, clientStreamDescForProxying, backendConn, fullMethodName) if err != nil { return err } - defer clientStream.CloseSend() // always close this! - s2cErr := <-s.forwardServerToClient(serverStream, clientStream) - c2sErr := <-s.forwardClientToServer(clientStream, serverStream) + + s2cErrChan := s.forwardServerToClient(serverStream, clientStream) + c2sErrChan := s.forwardClientToServer(clientStream, serverStream) + s2cErr := <-s2cErrChan if s2cErr != io.EOF { + clientCancel() return grpc.Errorf(codes.Internal, "failed proxying s2c: %v", s2cErr) + } else { + clientStream.CloseSend() } + c2sErr := <-c2sErrChan + serverStream.SetTrailer(clientStream.Trailer()) // c2sErr will contain RPC error from client code. If not io.EOF return the RPC error as server stream error. if c2sErr != io.EOF { diff --git a/proxy/handler_test.go b/proxy/handler_test.go index 96c42ce..bbd51ef 100644 --- a/proxy/handler_test.go +++ b/proxy/handler_test.go @@ -22,6 +22,8 @@ import ( "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" + "fmt" + pb "github.com/mwitkow/grpc-proxy/testservice" ) @@ -72,6 +74,27 @@ func (s *assertingService) PingList(ping *pb.PingRequest, stream pb.TestService_ return nil } +func (s *assertingService) PingStream(stream pb.TestService_PingStreamServer) error { + stream.SendHeader(metadata.Pairs(serverHeaderMdKey, "I like turtles.")) + counter := int32(0) + for { + ping, err := stream.Recv() + if err == io.EOF { + break + } else if err != nil { + require.NoError(s.t, err, "can't fail reading stream") + return err + } + pong := &pb.PingResponse{Value: ping.Value, Counter: counter} + if err := stream.Send(pong); err != nil { + require.NoError(s.t, err, "can't fail sending back a pong") + } + counter += 1 + } + stream.SetTrailer(metadata.Pairs(serverTrailerMdKey, "I like ending turtles.")) + return nil +} + // ProxyHappySuite tests the "happy" path of handling: that everything works in absence of connection issues. type ProxyHappySuite struct { suite.Suite @@ -125,24 +148,28 @@ func (s *ProxyHappySuite) TestDirectorErrorIsPropagated() { assert.Equal(s.T(), "testing rejection", grpc.ErrorDesc(err)) } -func (s *ProxyHappySuite) TestPingListStreamsAll() { - stream, err := s.testClient.PingList(s.ctx(), &pb.PingRequest{Value: "foo"}) - require.NoError(s.T(), err, "PingList request should be successful.") - // Check that the header arrives before all entries. - headerMd, err := stream.Header() - require.NoError(s.T(), err, "PingList headers should not error.") - assert.Len(s.T(), headerMd, 1, "PingList response headers user contain metadata") - count := 0 - for { +func (s *ProxyHappySuite) TestPingStream_FullDuplexWorks() { + stream, err := s.testClient.PingStream(s.ctx()) + require.NoError(s.T(), err, "PingStream request should be successful.") + + for i := 0; i < countListResponses; i++ { + ping := &pb.PingRequest{Value: fmt.Sprintf("foo:%d", i)} + require.NoError(s.T(), stream.Send(ping), "sending to PingStream must not fail") resp, err := stream.Recv() if err == io.EOF { break } - require.NoError(s.T(), err, "PingList stream should not be interrupted.") - require.Equal(s.T(), "foo", resp.Value) - count = count + 1 + if i == 0 { + // Check that the header arrives before all entries. + headerMd, err := stream.Header() + require.NoError(s.T(), err, "PingStream headers should not error.") + assert.Len(s.T(), headerMd, 1, "PingStream response headers user contain metadata") + } + assert.EqualValues(s.T(), i, resp.Counter, "ping roundtrip must succeed with the correct id") } - assert.Equal(s.T(), countListResponses, count, "PingList must successfully return all outputs") + require.NoError(s.T(), stream.CloseSend(), "no error on close send") + _, err = stream.Recv() + require.Equal(s.T(), io.EOF, err, "stream should close with io.EOF, meaining OK") // Check that the trailer headers are here. trailerMd := stream.Trailer() assert.Len(s.T(), trailerMd, 1, "PingList trailer headers user contain metadata") diff --git a/testservice/test.pb.go b/testservice/test.pb.go index acc40a2..1f1f482 100644 --- a/testservice/test.pb.go +++ b/testservice/test.pb.go @@ -60,7 +60,7 @@ func (m *PingRequest) GetValue() string { } type PingResponse struct { - Value string `protobuf:"bytes,1,opt,name=Value,json=value" json:"Value,omitempty"` + Value string `protobuf:"bytes,1,opt,name=Value" json:"Value,omitempty"` Counter int32 `protobuf:"varint,2,opt,name=counter" json:"counter,omitempty"` } @@ -104,6 +104,7 @@ type TestServiceClient interface { Ping(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*PingResponse, error) PingError(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*Empty, error) PingList(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (TestService_PingListClient, error) + PingStream(ctx context.Context, opts ...grpc.CallOption) (TestService_PingStreamClient, error) } type testServiceClient struct { @@ -173,6 +174,37 @@ func (x *testServicePingListClient) Recv() (*PingResponse, error) { return m, nil } +func (c *testServiceClient) PingStream(ctx context.Context, opts ...grpc.CallOption) (TestService_PingStreamClient, error) { + stream, err := grpc.NewClientStream(ctx, &_TestService_serviceDesc.Streams[1], c.cc, "/mwitkow.testproto.TestService/PingStream", opts...) + if err != nil { + return nil, err + } + x := &testServicePingStreamClient{stream} + return x, nil +} + +type TestService_PingStreamClient interface { + Send(*PingRequest) error + Recv() (*PingResponse, error) + grpc.ClientStream +} + +type testServicePingStreamClient struct { + grpc.ClientStream +} + +func (x *testServicePingStreamClient) Send(m *PingRequest) error { + return x.ClientStream.SendMsg(m) +} + +func (x *testServicePingStreamClient) Recv() (*PingResponse, error) { + m := new(PingResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + // Server API for TestService service type TestServiceServer interface { @@ -180,6 +212,7 @@ type TestServiceServer interface { Ping(context.Context, *PingRequest) (*PingResponse, error) PingError(context.Context, *PingRequest) (*Empty, error) PingList(*PingRequest, TestService_PingListServer) error + PingStream(TestService_PingStreamServer) error } func RegisterTestServiceServer(s *grpc.Server, srv TestServiceServer) { @@ -261,6 +294,32 @@ func (x *testServicePingListServer) Send(m *PingResponse) error { return x.ServerStream.SendMsg(m) } +func _TestService_PingStream_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(TestServiceServer).PingStream(&testServicePingStreamServer{stream}) +} + +type TestService_PingStreamServer interface { + Send(*PingResponse) error + Recv() (*PingRequest, error) + grpc.ServerStream +} + +type testServicePingStreamServer struct { + grpc.ServerStream +} + +func (x *testServicePingStreamServer) Send(m *PingResponse) error { + return x.ServerStream.SendMsg(m) +} + +func (x *testServicePingStreamServer) Recv() (*PingRequest, error) { + m := new(PingRequest) + if err := x.ServerStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + var _TestService_serviceDesc = grpc.ServiceDesc{ ServiceName: "mwitkow.testproto.TestService", HandlerType: (*TestServiceServer)(nil), @@ -284,6 +343,12 @@ var _TestService_serviceDesc = grpc.ServiceDesc{ Handler: _TestService_PingList_Handler, ServerStreams: true, }, + { + StreamName: "PingStream", + Handler: _TestService_PingStream_Handler, + ServerStreams: true, + ClientStreams: true, + }, }, Metadata: "test.proto", } @@ -291,19 +356,20 @@ var _TestService_serviceDesc = grpc.ServiceDesc{ func init() { proto.RegisterFile("test.proto", fileDescriptor0) } var fileDescriptor0 = []byte{ - // 218 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e, - 0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x12, 0xcc, 0x2d, 0xcf, 0x2c, 0xc9, 0xce, 0x2f, 0xd7, - 0x03, 0x89, 0x81, 0x85, 0x94, 0xd8, 0xb9, 0x58, 0x5d, 0x73, 0x0b, 0x4a, 0x2a, 0x95, 0x94, 0xb9, - 0xb8, 0x03, 0x32, 0xf3, 0xd2, 0x83, 0x52, 0x0b, 0x4b, 0x53, 0x8b, 0x4b, 0x84, 0x44, 0xb8, 0x58, - 0xcb, 0x12, 0x73, 0x4a, 0x53, 0x25, 0x18, 0x15, 0x18, 0x35, 0x38, 0x83, 0x20, 0x1c, 0x25, 0x3b, - 0x2e, 0x1e, 0x88, 0xa2, 0xe2, 0x82, 0xfc, 0xbc, 0xe2, 0x54, 0x90, 0xaa, 0x30, 0x0c, 0x55, 0x42, - 0x12, 0x5c, 0xec, 0xc9, 0xf9, 0xa5, 0x79, 0x25, 0xa9, 0x45, 0x12, 0x4c, 0x0a, 0x8c, 0x1a, 0xac, - 0x41, 0x30, 0xae, 0xd1, 0x1e, 0x26, 0x2e, 0xee, 0x90, 0xd4, 0xe2, 0x92, 0xe0, 0xd4, 0xa2, 0xb2, - 0xcc, 0xe4, 0x54, 0x21, 0x0f, 0x2e, 0x4e, 0x90, 0x79, 0x60, 0x17, 0x08, 0x49, 0xe8, 0x61, 0x38, - 0x4f, 0x0f, 0x2c, 0x23, 0x25, 0x8f, 0x45, 0x06, 0xd9, 0x1d, 0x4a, 0x0c, 0x42, 0x9e, 0x5c, 0x2c, - 0x20, 0x11, 0x21, 0x39, 0x9c, 0x4a, 0xc1, 0xfe, 0x22, 0xc6, 0x28, 0x77, 0xa8, 0xa3, 0x8a, 0x8a, - 0xf2, 0x8b, 0x08, 0x9a, 0x87, 0xd3, 0xd1, 0x4a, 0x0c, 0x42, 0xfe, 0x5c, 0x1c, 0x20, 0xa5, 0x3e, - 0x99, 0xc5, 0x25, 0x54, 0x70, 0x97, 0x01, 0x63, 0x12, 0x1b, 0x58, 0xdc, 0x18, 0x10, 0x00, 0x00, - 0xff, 0xff, 0x7b, 0xc9, 0x16, 0xf1, 0xd4, 0x01, 0x00, 0x00, + // 237 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xac, 0x8f, 0x31, 0x4b, 0xc4, 0x40, + 0x10, 0x85, 0x6f, 0xd5, 0x78, 0xde, 0x9c, 0x8d, 0x83, 0xc5, 0x62, 0xa1, 0xc7, 0xda, 0xa4, 0x5a, + 0x0e, 0xed, 0xed, 0x44, 0x05, 0x41, 0x49, 0xc4, 0xfe, 0x0c, 0x83, 0x2c, 0x9a, 0x6c, 0xdc, 0x9d, + 0x24, 0xf8, 0x33, 0xfc, 0xc7, 0xb2, 0x1b, 0x85, 0x80, 0x06, 0x2d, 0x52, 0xce, 0x7b, 0x1f, 0x8f, + 0x6f, 0x00, 0x98, 0x3c, 0xeb, 0xda, 0x59, 0xb6, 0x78, 0x50, 0x76, 0x86, 0x5f, 0x6c, 0xa7, 0x43, + 0x16, 0x23, 0x35, 0x87, 0xe4, 0xb2, 0xac, 0xf9, 0x5d, 0x9d, 0xc2, 0xf2, 0xde, 0x54, 0xcf, 0x19, + 0xbd, 0x35, 0xe4, 0x19, 0x0f, 0x21, 0x69, 0x37, 0xaf, 0x0d, 0x49, 0xb1, 0x12, 0xe9, 0x22, 0xeb, + 0x0f, 0x75, 0x01, 0xfb, 0x3d, 0xe4, 0x6b, 0x5b, 0x79, 0x0a, 0xd4, 0xe3, 0x90, 0x8a, 0x07, 0x4a, + 0x98, 0x17, 0xb6, 0xa9, 0x98, 0x9c, 0xdc, 0x5a, 0x89, 0x34, 0xc9, 0xbe, 0xcf, 0xb3, 0x8f, 0x6d, + 0x58, 0x3e, 0x90, 0xe7, 0x9c, 0x5c, 0x6b, 0x0a, 0xc2, 0x6b, 0x58, 0x84, 0xbd, 0x68, 0x80, 0x52, + 0xff, 0xd0, 0xd3, 0xb1, 0x39, 0x3a, 0xf9, 0xa5, 0x19, 0x7a, 0xa8, 0x19, 0xde, 0xc0, 0x4e, 0x48, + 0xf0, 0x78, 0x14, 0x8d, 0x7f, 0xfd, 0x67, 0xea, 0xea, 0x4b, 0xca, 0x39, 0xeb, 0xfe, 0xdc, 0x1b, + 0x95, 0x56, 0x33, 0xbc, 0x83, 0xbd, 0x80, 0xde, 0x1a, 0xcf, 0x13, 0x78, 0xad, 0x05, 0xe6, 0x00, + 0x21, 0xcb, 0xd9, 0xd1, 0xa6, 0x9c, 0x60, 0x32, 0x15, 0x6b, 0xf1, 0xb4, 0x1b, 0x9b, 0xf3, 0xcf, + 0x00, 0x00, 0x00, 0xff, 0xff, 0x4a, 0xc0, 0x8e, 0xe7, 0x29, 0x02, 0x00, 0x00, } diff --git a/testservice/test.proto b/testservice/test.proto index 3ee34d0..54e3cf5 100644 --- a/testservice/test.proto +++ b/testservice/test.proto @@ -22,5 +22,8 @@ service TestService { rpc PingError(PingRequest) returns (Empty) {} rpc PingList(PingRequest) returns (stream PingResponse) {} + + rpc PingStream(stream PingRequest) returns (stream PingResponse) {} + }