diff --git a/client.go b/client.go index e40592dd7..a90608e8b 100644 --- a/client.go +++ b/client.go @@ -24,6 +24,7 @@ import ( "strings" "sync" "syscall" + "time" "github.com/gogo/protobuf/proto" "github.com/pkg/errors" @@ -86,6 +87,10 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int cresp = &Response{} ) + if dl, ok := ctx.Deadline(); ok { + creq.TimeoutNano = dl.Sub(time.Now()).Nanoseconds() + } + if err := c.dispatch(ctx, creq, cresp); err != nil { return err } @@ -104,6 +109,7 @@ func (c *Client) Call(ctx context.Context, service, method string, req, resp int func (c *Client) dispatch(ctx context.Context, req *Request, resp *Response) error { errs := make(chan error, 1) call := &callRequest{ + ctx: ctx, req: req, resp: resp, errs: errs, diff --git a/server.go b/server.go index 263cb4583..40804eac0 100644 --- a/server.go +++ b/server.go @@ -414,6 +414,9 @@ func (c *serverConn) run(sctx context.Context) { case request := <-requests: active++ go func(id uint32) { + ctx, cancel := getRequestContext(ctx, request.req) + defer cancel() + p, status := c.server.services.call(ctx, request.req.Service, request.req.Method, request.req.Payload) resp := &Response{ Status: status.Proto(), @@ -454,3 +457,15 @@ func (c *serverConn) run(sctx context.Context) { } } } + +var noopFunc = func() {} + +func getRequestContext(ctx context.Context, req *Request) (retCtx context.Context, cancel func()) { + cancel = noopFunc + if req.TimeoutNano == 0 { + return ctx, cancel + } + + ctx, cancel = context.WithTimeout(ctx, time.Duration(req.TimeoutNano)) + return ctx, cancel +} diff --git a/server_test.go b/server_test.go index ec07f0ef2..4f31feba7 100644 --- a/server_test.go +++ b/server_test.go @@ -24,6 +24,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/gogo/protobuf/proto" "github.com/pkg/errors" @@ -57,7 +58,8 @@ func (tc *testingClient) Test(ctx context.Context, req *testPayload) (*testPaylo } type testPayload struct { - Foo string `protobuf:"bytes,1,opt,name=foo,proto3"` + Foo string `protobuf:"bytes,1,opt,name=foo,proto3"` + Deadline int64 `protobuf:"varint,2,opt,name=deadline,proto3"` } func (r *testPayload) Reset() { *r = testPayload{} } @@ -68,7 +70,11 @@ func (r *testPayload) ProtoMessage() {} type testingServer struct{} func (s *testingServer) Test(ctx context.Context, req *testPayload) (*testPayload, error) { - return &testPayload{Foo: strings.Repeat(req.Foo, 2)}, nil + tp := &testPayload{Foo: strings.Repeat(req.Foo, 2)} + if dl, ok := ctx.Deadline(); ok { + tp.Deadline = dl.UnixNano() + } + return tp, nil } // registerTestingService mocks more of what is generated code. Unlike grpc, we @@ -376,6 +382,34 @@ func TestUnixSocketHandshake(t *testing.T) { } } +func TestServerRequestTimeout(t *testing.T) { + var ( + ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(10*time.Minute)) + server = mustServer(t)(NewServer()) + addr, listener = newTestListener(t) + testImpl = &testingServer{} + client, cleanup = newTestClient(t, addr) + result testPayload + ) + defer cancel() + defer cleanup() + defer listener.Close() + + registerTestingService(server, testImpl) + + go server.Serve(ctx, listener) + defer server.Shutdown(ctx) + + if err := client.Call(ctx, serviceName, "Test", &testPayload{}, &result); err != nil { + t.Fatalf("unexpected error making call: %v", err) + } + + dl, _ := ctx.Deadline() + if result.Deadline != dl.UnixNano() { + t.Fatalf("expected deadline %v, actual: %v", dl, result.Deadline) + } +} + func BenchmarkRoundTrip(b *testing.B) { var ( ctx = context.Background() diff --git a/types.go b/types.go index 1f7969e5c..a6b3b818e 100644 --- a/types.go +++ b/types.go @@ -23,9 +23,10 @@ import ( ) type Request struct { - Service string `protobuf:"bytes,1,opt,name=service,proto3"` - Method string `protobuf:"bytes,2,opt,name=method,proto3"` - Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"` + Service string `protobuf:"bytes,1,opt,name=service,proto3"` + Method string `protobuf:"bytes,2,opt,name=method,proto3"` + Payload []byte `protobuf:"bytes,3,opt,name=payload,proto3"` + TimeoutNano int64 `protobuf:"varint,4,opt,name=timeout_nano,proto3"` } func (r *Request) Reset() { *r = Request{} }