From 3dc9fd8b90b164f8f38aa64776d49f6e6b79e913 Mon Sep 17 00:00:00 2001 From: bufdev Date: Thu, 17 Sep 2020 18:19:24 -0400 Subject: [PATCH 1/6] Add server interceptor logic --- Makefile | 2 +- .../clientcompat/clientcompat.twirp.go | 102 +++++++++++++++++- example/service.twirp.go | 52 ++++++++- .../empty_service/empty_service.twirp.go | 2 + .../twirptest/gogo_compat/service.twirp.go | 52 ++++++++- .../google_protobuf_imports/service.twirp.go | 52 ++++++++- .../twirptest/importable/importable.twirp.go | 52 ++++++++- internal/twirptest/importer/importer.twirp.go | 52 ++++++++- .../importer_local/importer_local.twirp.go | 52 ++++++++- internal/twirptest/importmapping/x/x.twirp.go | 52 ++++++++- .../json_serialization.twirp.go | 52 ++++++++- .../twirptest/multiple/multiple1.twirp.go | 52 ++++++++- .../twirptest/multiple/multiple2.twirp.go | 102 +++++++++++++++++- .../no_package_name/no_package_name.twirp.go | 52 ++++++++- .../no_package_name_importer.twirp.go | 52 ++++++++- internal/twirptest/proto/proto.twirp.go | 52 ++++++++- internal/twirptest/service.twirp.go | 52 ++++++++- .../service_method_same_name.twirp.go | 52 ++++++++- internal/twirptest/service_test.go | 43 ++++++++ .../snake_case_names.twirp.go | 52 ++++++++- .../source_relative/source_relative.twirp.go | 52 ++++++++- middleware.go | 44 ++++++++ middleware_test.go | 71 ++++++++++++ protoc-gen-twirp/generator.go | 42 +++++++- server_options.go | 8 ++ 25 files changed, 1205 insertions(+), 43 deletions(-) create mode 100644 middleware.go create mode 100644 middleware_test.go diff --git a/Makefile b/Makefile index 1e0e1d82..ce105ff9 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,7 @@ test_python_client: generate build/clientcompat build/pycompat setup: ./install_proto.bash - GOPATH=$(CURDIR)/_tools go install github.com/twitchtv/retool/... + GO111MODULE=off GOPATH=$(CURDIR)/_tools GOBIN=$(CURDIR)/_tools/bin go get github.com/twitchtv/retool $(RETOOL) build release_gen: diff --git a/clientcompat/internal/clientcompat/clientcompat.twirp.go b/clientcompat/internal/clientcompat/clientcompat.twirp.go index e17add52..30744583 100644 --- a/clientcompat/internal/clientcompat/clientcompat.twirp.go +++ b/clientcompat/internal/clientcompat/clientcompat.twirp.go @@ -205,6 +205,7 @@ func (c *compatServiceJSONClient) NoopMethod(ctx context.Context, in *Empty) (*E type compatServiceServer struct { CompatService + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -231,6 +232,7 @@ func NewCompatServiceServer(svc CompatService, opts ...interface{}) TwirpServer return &compatServiceServer{ CompatService: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -328,11 +330,34 @@ func (s *compatServiceServer) serveMethodJSON(ctx context.Context, resp http.Res return } + handler := s.CompatService.Method + if s.interceptor != nil { + handler = func(ctx context.Context, req *Req) (*Resp, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Req) + if !ok { + return nil, twirp.InternalError("could not convert to a *Req") + } + return s.CompatService.Method(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Resp) + if !ok { + return nil, twirp.InternalError("could not convert to a *Resp") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Resp func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.CompatService.Method(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -387,11 +412,34 @@ func (s *compatServiceServer) serveMethodProtobuf(ctx context.Context, resp http return } + handler := s.CompatService.Method + if s.interceptor != nil { + handler = func(ctx context.Context, req *Req) (*Resp, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Req) + if !ok { + return nil, twirp.InternalError("could not convert to a *Req") + } + return s.CompatService.Method(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Resp) + if !ok { + return nil, twirp.InternalError("could not convert to a *Resp") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Resp func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.CompatService.Method(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -457,11 +505,34 @@ func (s *compatServiceServer) serveNoopMethodJSON(ctx context.Context, resp http return } + handler := s.CompatService.NoopMethod + if s.interceptor != nil { + handler = func(ctx context.Context, req *Empty) (*Empty, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Empty) + if !ok { + return nil, twirp.InternalError("could not convert to a *Empty") + } + return s.CompatService.NoopMethod(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Empty) + if !ok { + return nil, twirp.InternalError("could not convert to a *Empty") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Empty func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.CompatService.NoopMethod(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -516,11 +587,34 @@ func (s *compatServiceServer) serveNoopMethodProtobuf(ctx context.Context, resp return } + handler := s.CompatService.NoopMethod + if s.interceptor != nil { + handler = func(ctx context.Context, req *Empty) (*Empty, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Empty) + if !ok { + return nil, twirp.InternalError("could not convert to a *Empty") + } + return s.CompatService.NoopMethod(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Empty) + if !ok { + return nil, twirp.InternalError("could not convert to a *Empty") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Empty func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.CompatService.NoopMethod(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/example/service.twirp.go b/example/service.twirp.go index 64db0f50..477295b3 100644 --- a/example/service.twirp.go +++ b/example/service.twirp.go @@ -163,6 +163,7 @@ func (c *haberdasherJSONClient) MakeHat(ctx context.Context, in *Size) (*Hat, er type haberdasherServer struct { Haberdasher + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -189,6 +190,7 @@ func NewHaberdasherServer(svc Haberdasher, opts ...interface{}) TwirpServer { return &haberdasherServer{ Haberdasher: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -283,11 +285,34 @@ func (s *haberdasherServer) serveMakeHatJSON(ctx context.Context, resp http.Resp return } + handler := s.Haberdasher.MakeHat + if s.interceptor != nil { + handler = func(ctx context.Context, req *Size) (*Hat, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Size) + if !ok { + return nil, twirp.InternalError("could not convert to a *Size") + } + return s.Haberdasher.MakeHat(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Hat) + if !ok { + return nil, twirp.InternalError("could not convert to a *Hat") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Hat func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Haberdasher.MakeHat(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -342,11 +367,34 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. return } + handler := s.Haberdasher.MakeHat + if s.interceptor != nil { + handler = func(ctx context.Context, req *Size) (*Hat, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Size) + if !ok { + return nil, twirp.InternalError("could not convert to a *Size") + } + return s.Haberdasher.MakeHat(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Hat) + if !ok { + return nil, twirp.InternalError("could not convert to a *Hat") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Hat func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Haberdasher.MakeHat(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/empty_service/empty_service.twirp.go b/internal/twirptest/empty_service/empty_service.twirp.go index de311129..034ca951 100644 --- a/internal/twirptest/empty_service/empty_service.twirp.go +++ b/internal/twirptest/empty_service/empty_service.twirp.go @@ -110,6 +110,7 @@ func NewEmptyJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientO type emptyServer struct { Empty + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -136,6 +137,7 @@ func NewEmptyServer(svc Empty, opts ...interface{}) TwirpServer { return &emptyServer{ Empty: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } diff --git a/internal/twirptest/gogo_compat/service.twirp.go b/internal/twirptest/gogo_compat/service.twirp.go index 1186e775..8be65da9 100644 --- a/internal/twirptest/gogo_compat/service.twirp.go +++ b/internal/twirptest/gogo_compat/service.twirp.go @@ -165,6 +165,7 @@ func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { type svcServer struct { Svc + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -191,6 +192,7 @@ func NewSvcServer(svc Svc, opts ...interface{}) TwirpServer { return &svcServer{ Svc: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -285,11 +287,34 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, return } + handler := s.Svc.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.Svc.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -344,11 +369,34 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri return } + handler := s.Svc.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.Svc.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/google_protobuf_imports/service.twirp.go b/internal/twirptest/google_protobuf_imports/service.twirp.go index 0d95abb7..fc86467d 100644 --- a/internal/twirptest/google_protobuf_imports/service.twirp.go +++ b/internal/twirptest/google_protobuf_imports/service.twirp.go @@ -164,6 +164,7 @@ func (c *svcJSONClient) Send(ctx context.Context, in *google_protobuf1.StringVal type svcServer struct { Svc + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -190,6 +191,7 @@ func NewSvcServer(svc Svc, opts ...interface{}) TwirpServer { return &svcServer{ Svc: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -284,11 +286,34 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, return } + handler := s.Svc.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *google_protobuf1.StringValue) (*google_protobuf.Empty, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*google_protobuf1.StringValue) + if !ok { + return nil, twirp.InternalError("could not convert to a *google_protobuf1.StringValue") + } + return s.Svc.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*google_protobuf.Empty) + if !ok { + return nil, twirp.InternalError("could not convert to a *google_protobuf.Empty") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *google_protobuf.Empty func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -343,11 +368,34 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri return } + handler := s.Svc.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *google_protobuf1.StringValue) (*google_protobuf.Empty, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*google_protobuf1.StringValue) + if !ok { + return nil, twirp.InternalError("could not convert to a *google_protobuf1.StringValue") + } + return s.Svc.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*google_protobuf.Empty) + if !ok { + return nil, twirp.InternalError("could not convert to a *google_protobuf.Empty") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *google_protobuf.Empty func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/importable/importable.twirp.go b/internal/twirptest/importable/importable.twirp.go index 2e6f8b9a..83584dec 100644 --- a/internal/twirptest/importable/importable.twirp.go +++ b/internal/twirptest/importable/importable.twirp.go @@ -164,6 +164,7 @@ func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { type svcServer struct { Svc + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -190,6 +191,7 @@ func NewSvcServer(svc Svc, opts ...interface{}) TwirpServer { return &svcServer{ Svc: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -284,11 +286,34 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, return } + handler := s.Svc.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.Svc.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -343,11 +368,34 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri return } + handler := s.Svc.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.Svc.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/importer/importer.twirp.go b/internal/twirptest/importer/importer.twirp.go index f6dadbca..940d4eea 100644 --- a/internal/twirptest/importer/importer.twirp.go +++ b/internal/twirptest/importer/importer.twirp.go @@ -166,6 +166,7 @@ func (c *svc2JSONClient) Send(ctx context.Context, in *twirp_internal_twirptest_ type svc2Server struct { Svc2 + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -192,6 +193,7 @@ func NewSvc2Server(svc Svc2, opts ...interface{}) TwirpServer { return &svc2Server{ Svc2: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -286,11 +288,34 @@ func (s *svc2Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter return } + handler := s.Svc2.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *twirp_internal_twirptest_importable.Msg) (*twirp_internal_twirptest_importable.Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*twirp_internal_twirptest_importable.Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importable.Msg") + } + return s.Svc2.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*twirp_internal_twirptest_importable.Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importable.Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *twirp_internal_twirptest_importable.Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc2.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -345,11 +370,34 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr return } + handler := s.Svc2.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *twirp_internal_twirptest_importable.Msg) (*twirp_internal_twirptest_importable.Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*twirp_internal_twirptest_importable.Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importable.Msg") + } + return s.Svc2.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*twirp_internal_twirptest_importable.Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importable.Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *twirp_internal_twirptest_importable.Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc2.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/importer_local/importer_local.twirp.go b/internal/twirptest/importer_local/importer_local.twirp.go index 8c528ebe..55a627f8 100644 --- a/internal/twirptest/importer_local/importer_local.twirp.go +++ b/internal/twirptest/importer_local/importer_local.twirp.go @@ -161,6 +161,7 @@ func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { type svcServer struct { Svc + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -187,6 +188,7 @@ func NewSvcServer(svc Svc, opts ...interface{}) TwirpServer { return &svcServer{ Svc: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -281,11 +283,34 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, return } + handler := s.Svc.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.Svc.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -340,11 +365,34 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri return } + handler := s.Svc.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.Svc.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/importmapping/x/x.twirp.go b/internal/twirptest/importmapping/x/x.twirp.go index 04200d6f..023006b5 100644 --- a/internal/twirptest/importmapping/x/x.twirp.go +++ b/internal/twirptest/importmapping/x/x.twirp.go @@ -163,6 +163,7 @@ func (c *svc1JSONClient) Send(ctx context.Context, in *twirp_internal_twirptest_ type svc1Server struct { Svc1 + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -189,6 +190,7 @@ func NewSvc1Server(svc Svc1, opts ...interface{}) TwirpServer { return &svc1Server{ Svc1: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -283,11 +285,34 @@ func (s *svc1Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter return } + handler := s.Svc1.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *twirp_internal_twirptest_importmapping_y.MsgY) (*twirp_internal_twirptest_importmapping_y.MsgY, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*twirp_internal_twirptest_importmapping_y.MsgY) + if !ok { + return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importmapping_y.MsgY") + } + return s.Svc1.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*twirp_internal_twirptest_importmapping_y.MsgY) + if !ok { + return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importmapping_y.MsgY") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *twirp_internal_twirptest_importmapping_y.MsgY func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc1.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -342,11 +367,34 @@ func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr return } + handler := s.Svc1.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *twirp_internal_twirptest_importmapping_y.MsgY) (*twirp_internal_twirptest_importmapping_y.MsgY, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*twirp_internal_twirptest_importmapping_y.MsgY) + if !ok { + return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importmapping_y.MsgY") + } + return s.Svc1.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*twirp_internal_twirptest_importmapping_y.MsgY) + if !ok { + return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importmapping_y.MsgY") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *twirp_internal_twirptest_importmapping_y.MsgY func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc1.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/json_serialization/json_serialization.twirp.go b/internal/twirptest/json_serialization/json_serialization.twirp.go index 06ca541c..e1940138 100644 --- a/internal/twirptest/json_serialization/json_serialization.twirp.go +++ b/internal/twirptest/json_serialization/json_serialization.twirp.go @@ -161,6 +161,7 @@ func (c *jSONSerializationJSONClient) EchoJSON(ctx context.Context, in *Msg) (*M type jSONSerializationServer struct { JSONSerialization + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -187,6 +188,7 @@ func NewJSONSerializationServer(svc JSONSerialization, opts ...interface{}) Twir return &jSONSerializationServer{ JSONSerialization: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -281,11 +283,34 @@ func (s *jSONSerializationServer) serveEchoJSONJSON(ctx context.Context, resp ht return } + handler := s.JSONSerialization.EchoJSON + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.JSONSerialization.EchoJSON(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.JSONSerialization.EchoJSON(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -340,11 +365,34 @@ func (s *jSONSerializationServer) serveEchoJSONProtobuf(ctx context.Context, res return } + handler := s.JSONSerialization.EchoJSON + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.JSONSerialization.EchoJSON(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.JSONSerialization.EchoJSON(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/multiple/multiple1.twirp.go b/internal/twirptest/multiple/multiple1.twirp.go index 836d5026..80912310 100644 --- a/internal/twirptest/multiple/multiple1.twirp.go +++ b/internal/twirptest/multiple/multiple1.twirp.go @@ -165,6 +165,7 @@ func (c *svc1JSONClient) Send(ctx context.Context, in *Msg1) (*Msg1, error) { type svc1Server struct { Svc1 + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -191,6 +192,7 @@ func NewSvc1Server(svc Svc1, opts ...interface{}) TwirpServer { return &svc1Server{ Svc1: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -285,11 +287,34 @@ func (s *svc1Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter return } + handler := s.Svc1.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg1) (*Msg1, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg1) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg1") + } + return s.Svc1.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg1) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg1") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg1 func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc1.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -344,11 +369,34 @@ func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr return } + handler := s.Svc1.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg1) (*Msg1, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg1) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg1") + } + return s.Svc1.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg1) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg1") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg1 func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc1.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/multiple/multiple2.twirp.go b/internal/twirptest/multiple/multiple2.twirp.go index 590b2122..2a7b9570 100644 --- a/internal/twirptest/multiple/multiple2.twirp.go +++ b/internal/twirptest/multiple/multiple2.twirp.go @@ -192,6 +192,7 @@ func (c *svc2JSONClient) SamePackageProtoImport(ctx context.Context, in *Msg1) ( type svc2Server struct { Svc2 + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -218,6 +219,7 @@ func NewSvc2Server(svc Svc2, opts ...interface{}) TwirpServer { return &svc2Server{ Svc2: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -315,11 +317,34 @@ func (s *svc2Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter return } + handler := s.Svc2.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg2) (*Msg2, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg2) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg2") + } + return s.Svc2.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg2) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg2") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg2 func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc2.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -374,11 +399,34 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr return } + handler := s.Svc2.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg2) (*Msg2, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg2) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg2") + } + return s.Svc2.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg2) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg2") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg2 func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc2.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -444,11 +492,34 @@ func (s *svc2Server) serveSamePackageProtoImportJSON(ctx context.Context, resp h return } + handler := s.Svc2.SamePackageProtoImport + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg1) (*Msg1, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg1) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg1") + } + return s.Svc2.SamePackageProtoImport(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg1) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg1") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg1 func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc2.SamePackageProtoImport(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -503,11 +574,34 @@ func (s *svc2Server) serveSamePackageProtoImportProtobuf(ctx context.Context, re return } + handler := s.Svc2.SamePackageProtoImport + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg1) (*Msg1, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg1) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg1") + } + return s.Svc2.SamePackageProtoImport(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg1) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg1") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg1 func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc2.SamePackageProtoImport(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/no_package_name/no_package_name.twirp.go b/internal/twirptest/no_package_name/no_package_name.twirp.go index 0cf6daf7..0b6e8579 100644 --- a/internal/twirptest/no_package_name/no_package_name.twirp.go +++ b/internal/twirptest/no_package_name/no_package_name.twirp.go @@ -161,6 +161,7 @@ func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { type svcServer struct { Svc + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -187,6 +188,7 @@ func NewSvcServer(svc Svc, opts ...interface{}) TwirpServer { return &svcServer{ Svc: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -281,11 +283,34 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, return } + handler := s.Svc.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.Svc.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -340,11 +365,34 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri return } + handler := s.Svc.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.Svc.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go b/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go index 1522e0fb..281293cf 100644 --- a/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go +++ b/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go @@ -163,6 +163,7 @@ func (c *svc2JSONClient) Method(ctx context.Context, in *no_package_name.Msg) (* type svc2Server struct { Svc2 + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -189,6 +190,7 @@ func NewSvc2Server(svc Svc2, opts ...interface{}) TwirpServer { return &svc2Server{ Svc2: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -283,11 +285,34 @@ func (s *svc2Server) serveMethodJSON(ctx context.Context, resp http.ResponseWrit return } + handler := s.Svc2.Method + if s.interceptor != nil { + handler = func(ctx context.Context, req *no_package_name.Msg) (*no_package_name.Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*no_package_name.Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *no_package_name.Msg") + } + return s.Svc2.Method(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*no_package_name.Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *no_package_name.Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *no_package_name.Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc2.Method(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -342,11 +367,34 @@ func (s *svc2Server) serveMethodProtobuf(ctx context.Context, resp http.Response return } + handler := s.Svc2.Method + if s.interceptor != nil { + handler = func(ctx context.Context, req *no_package_name.Msg) (*no_package_name.Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*no_package_name.Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *no_package_name.Msg") + } + return s.Svc2.Method(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*no_package_name.Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *no_package_name.Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *no_package_name.Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc2.Method(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/proto/proto.twirp.go b/internal/twirptest/proto/proto.twirp.go index de26bec2..c0451132 100644 --- a/internal/twirptest/proto/proto.twirp.go +++ b/internal/twirptest/proto/proto.twirp.go @@ -164,6 +164,7 @@ func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { type svcServer struct { Svc + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -190,6 +191,7 @@ func NewSvcServer(svc Svc, opts ...interface{}) TwirpServer { return &svcServer{ Svc: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -284,11 +286,34 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, return } + handler := s.Svc.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.Svc.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -343,11 +368,34 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri return } + handler := s.Svc.Send + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.Svc.Send(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc.Send(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/service.twirp.go b/internal/twirptest/service.twirp.go index ccb9100d..d8d1923e 100644 --- a/internal/twirptest/service.twirp.go +++ b/internal/twirptest/service.twirp.go @@ -163,6 +163,7 @@ func (c *haberdasherJSONClient) MakeHat(ctx context.Context, in *Size) (*Hat, er type haberdasherServer struct { Haberdasher + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -189,6 +190,7 @@ func NewHaberdasherServer(svc Haberdasher, opts ...interface{}) TwirpServer { return &haberdasherServer{ Haberdasher: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -283,11 +285,34 @@ func (s *haberdasherServer) serveMakeHatJSON(ctx context.Context, resp http.Resp return } + handler := s.Haberdasher.MakeHat + if s.interceptor != nil { + handler = func(ctx context.Context, req *Size) (*Hat, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Size) + if !ok { + return nil, twirp.InternalError("could not convert to a *Size") + } + return s.Haberdasher.MakeHat(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Hat) + if !ok { + return nil, twirp.InternalError("could not convert to a *Hat") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Hat func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Haberdasher.MakeHat(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -342,11 +367,34 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. return } + handler := s.Haberdasher.MakeHat + if s.interceptor != nil { + handler = func(ctx context.Context, req *Size) (*Hat, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Size) + if !ok { + return nil, twirp.InternalError("could not convert to a *Size") + } + return s.Haberdasher.MakeHat(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Hat) + if !ok { + return nil, twirp.InternalError("could not convert to a *Hat") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Hat func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Haberdasher.MakeHat(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go b/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go index b5dbf037..e30cdc7a 100644 --- a/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go +++ b/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go @@ -161,6 +161,7 @@ func (c *echoJSONClient) Echo(ctx context.Context, in *Msg) (*Msg, error) { type echoServer struct { Echo + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -187,6 +188,7 @@ func NewEchoServer(svc Echo, opts ...interface{}) TwirpServer { return &echoServer{ Echo: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -281,11 +283,34 @@ func (s *echoServer) serveEchoJSON(ctx context.Context, resp http.ResponseWriter return } + handler := s.Echo.Echo + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.Echo.Echo(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Echo.Echo(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -340,11 +365,34 @@ func (s *echoServer) serveEchoProtobuf(ctx context.Context, resp http.ResponseWr return } + handler := s.Echo.Echo + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.Echo.Echo(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Echo.Echo(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/service_test.go b/internal/twirptest/service_test.go index 1d934de6..66f5d557 100644 --- a/internal/twirptest/service_test.go +++ b/internal/twirptest/service_test.go @@ -18,6 +18,7 @@ import ( "context" "crypto/tls" "errors" + "fmt" "io" "io/ioutil" "net/http" @@ -574,6 +575,48 @@ func TestErroringHooks(t *testing.T) { }) } +func TestInterceptor(t *testing.T) { + interceptor := func(next twirp.Method) twirp.Method { + return func(ctx context.Context, request interface{}) (interface{}, error) { + size, ok := request.(*Size) + if !ok { + return nil, fmt.Errorf("could not cast %T to a *Size", request) + } + size.Inches = size.Inches + 1 + response, err := next(ctx, request) + hat, ok := response.(*Hat) + if ok && hat != nil { + hat.Color = hat.Color + "x" + return hat, err + } + return nil, err + } + } + h := PickyHatmaker(3) + + s := httptest.NewServer( + NewHaberdasherServer( + h, + twirp.WithServerInterceptors( + interceptor, + interceptor, + ), + ), + ) + defer s.Close() + client := NewHaberdasherProtobufClient(s.URL, http.DefaultClient) + hat, clientErr := client.MakeHat(context.Background(), &Size{Inches: 1}) + if clientErr != nil { + t.Fatalf("client err=%q", clientErr) + } + if hat.Size != 3 { + t.Errorf("hat size expected=3 actual=%v", hat.Size) + } + if hat.Color != "bluexx" { + t.Errorf("hat color expected=bluexx actual=%v", hat.Color) + } +} + func TestInternalErrorPassing(t *testing.T) { e := twirp.InternalError("fatal :(") diff --git a/internal/twirptest/snake_case_names/snake_case_names.twirp.go b/internal/twirptest/snake_case_names/snake_case_names.twirp.go index 3afad850..60c8019d 100644 --- a/internal/twirptest/snake_case_names/snake_case_names.twirp.go +++ b/internal/twirptest/snake_case_names/snake_case_names.twirp.go @@ -166,6 +166,7 @@ func (c *haberdasherV1JSONClient) MakeHatV1(ctx context.Context, in *MakeHatArgs type haberdasherV1Server struct { HaberdasherV1 + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -192,6 +193,7 @@ func NewHaberdasherV1Server(svc HaberdasherV1, opts ...interface{}) TwirpServer return &haberdasherV1Server{ HaberdasherV1: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -286,11 +288,34 @@ func (s *haberdasherV1Server) serveMakeHatV1JSON(ctx context.Context, resp http. return } + handler := s.HaberdasherV1.MakeHatV1 + if s.interceptor != nil { + handler = func(ctx context.Context, req *MakeHatArgsV1_SizeV1) (*MakeHatArgsV1_HatV1, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*MakeHatArgsV1_SizeV1) + if !ok { + return nil, twirp.InternalError("could not convert to a *MakeHatArgsV1_SizeV1") + } + return s.HaberdasherV1.MakeHatV1(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*MakeHatArgsV1_HatV1) + if !ok { + return nil, twirp.InternalError("could not convert to a *MakeHatArgsV1_HatV1") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *MakeHatArgsV1_HatV1 func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.HaberdasherV1.MakeHatV1(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -345,11 +370,34 @@ func (s *haberdasherV1Server) serveMakeHatV1Protobuf(ctx context.Context, resp h return } + handler := s.HaberdasherV1.MakeHatV1 + if s.interceptor != nil { + handler = func(ctx context.Context, req *MakeHatArgsV1_SizeV1) (*MakeHatArgsV1_HatV1, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*MakeHatArgsV1_SizeV1) + if !ok { + return nil, twirp.InternalError("could not convert to a *MakeHatArgsV1_SizeV1") + } + return s.HaberdasherV1.MakeHatV1(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*MakeHatArgsV1_HatV1) + if !ok { + return nil, twirp.InternalError("could not convert to a *MakeHatArgsV1_HatV1") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *MakeHatArgsV1_HatV1 func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.HaberdasherV1.MakeHatV1(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/internal/twirptest/source_relative/source_relative.twirp.go b/internal/twirptest/source_relative/source_relative.twirp.go index 65e81b4c..fea3b252 100644 --- a/internal/twirptest/source_relative/source_relative.twirp.go +++ b/internal/twirptest/source_relative/source_relative.twirp.go @@ -161,6 +161,7 @@ func (c *svcJSONClient) Method(ctx context.Context, in *Msg) (*Msg, error) { type svcServer struct { Svc + interceptor twirp.Interceptor hooks *twirp.ServerHooks pathPrefix string // prefix for routing jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response @@ -187,6 +188,7 @@ func NewSvcServer(svc Svc, opts ...interface{}) TwirpServer { return &svcServer{ Svc: svc, pathPrefix: serverOpts.PathPrefix(), + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), hooks: serverOpts.Hooks, jsonSkipDefaults: serverOpts.JSONSkipDefaults, } @@ -281,11 +283,34 @@ func (s *svcServer) serveMethodJSON(ctx context.Context, resp http.ResponseWrite return } + handler := s.Svc.Method + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.Svc.Method(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc.Method(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { @@ -340,11 +365,34 @@ func (s *svcServer) serveMethodProtobuf(ctx context.Context, resp http.ResponseW return } + handler := s.Svc.Method + if s.interceptor != nil { + handler = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return s.Svc.Method(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("could not convert to a *Msg") + } + return typedResp, err + } + return nil, err + } + } + // Call service method var respContent *Msg func() { defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = s.Svc.Method(ctx, reqContent) + respContent, err = handler(ctx, reqContent) }() if err != nil { diff --git a/middleware.go b/middleware.go new file mode 100644 index 00000000..69631b41 --- /dev/null +++ b/middleware.go @@ -0,0 +1,44 @@ +// Copyright 2018 Twitch Interactive, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may not +// use this file except in compliance with the License. A copy of the License is +// located at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// or in the "license" file accompanying this file. This file is distributed on +// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +package twirp + +import ( + "context" +) + +// Method is a method. +type Method func(ctx context.Context, request interface{}) (interface{}, error) + +// Interceptor is a interceptor. +type Interceptor func(Method) Method + +// ChainInterceptors chains the Interceptors. +// +// Returns nil if interceptors is empty. +func ChainInterceptors(interceptors ...Interceptor) Interceptor { + switch n := len(interceptors); n { + case 0: + return nil + case 1: + return interceptors[0] + default: + first := interceptors[0] + return func(next Method) Method { + for i := len(interceptors) - 1; i > 0; i-- { + next = interceptors[i](next) + } + return first(next) + } + } + +} diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 00000000..ee46bf17 --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,71 @@ +// Copyright 2018 Twitch Interactive, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may not +// use this file except in compliance with the License. A copy of the License is +// located at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// or in the "license" file accompanying this file. This file is distributed on +// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +package twirp + +import ( + "context" + "testing" +) + +func TestChainInterceptors(t *testing.T) { + if chain := ChainInterceptors(); chain != nil { + t.Errorf("ChainInterceptors(0) expected to be nil, but was %v", chain) + } + + interceptor1 := func(next Method) Method { + return func(ctx context.Context, request interface{}) (interface{}, error) { + response, err := next(ctx, request.(string)+"a") + return response.(string) + "1", err + } + } + interceptor2 := func(next Method) Method { + return func(ctx context.Context, request interface{}) (interface{}, error) { + response, err := next(ctx, request.(string)+"b") + return response.(string) + "2", err + } + } + interceptor3 := func(next Method) Method { + return func(ctx context.Context, request interface{}) (interface{}, error) { + response, err := next(ctx, request.(string)+"c") + return response.(string) + "3", err + } + } + method := func(ctx context.Context, request interface{}) (interface{}, error) { + return request.(string) + "x", nil + } + for _, testCase := range []struct { + interceptors []Interceptor + want string + }{ + { + interceptors: []Interceptor{interceptor1}, + want: "ax1", + }, + { + interceptors: []Interceptor{interceptor1, interceptor2}, + want: "abx21", + }, + { + interceptors: []Interceptor{interceptor1, interceptor2, interceptor3}, + want: "abcx321", + }, + } { + response, err := ChainInterceptors(testCase.interceptors...)(method)(context.Background(), "") + if err != nil { + t.Fatalf("ChainInterceptors(%d) method has unexpected err %v", len(testCase.interceptors), err) + } + if response != testCase.want { + t.Errorf("ChainInterceptors(%d) has unexpected value, have=%v, want=%v", len(testCase.interceptors), response, testCase.want) + } + } +} diff --git a/protoc-gen-twirp/generator.go b/protoc-gen-twirp/generator.go index 9d758881..8c66777f 100644 --- a/protoc-gen-twirp/generator.go +++ b/protoc-gen-twirp/generator.go @@ -1063,6 +1063,7 @@ func (t *twirp) generateServer(file *descriptor.FileDescriptorProto, service *de servStruct := serviceStruct(service) t.P(`type `, servStruct, ` struct {`) t.P(` `, servName) + t.P(` interceptor `, t.pkgs["twirp"], `.Interceptor`) t.P(` hooks *`, t.pkgs["twirp"], `.ServerHooks`) t.P(` pathPrefix string // prefix for routing`) t.P(` jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response`) @@ -1091,6 +1092,7 @@ func (t *twirp) generateServer(file *descriptor.FileDescriptorProto, service *de t.P(` return &`, servStruct, `{`) t.P(` `, servName, `: svc,`) t.P(` pathPrefix: serverOpts.PathPrefix(),`) + t.P(` interceptor: `, t.pkgs["twirp"], `.ChainInterceptors(serverOpts.Interceptors...),`) t.P(` hooks: serverOpts.Hooks,`) t.P(` jsonSkipDefaults: serverOpts.JSONSkipDefaults,`) t.P(` }`) @@ -1232,11 +1234,16 @@ func (t *twirp) generateServerJSONMethod(service *descriptor.ServiceDescriptorPr t.P(` return`) t.P(` }`) t.P() + t.P(` handler := s.`, servName, `.`, methName) + t.P(` if s.interceptor != nil {`) + t.generateServerInterceptorHandler(service, method) + t.P(` }`) + t.P() t.P(` // Call service method`) t.P(` var respContent *`, t.goTypeName(method.GetOutputType())) t.P(` func() {`) t.P(` defer ensurePanicResponses(ctx, resp, s.hooks)`) - t.P(` respContent, err = s.`, servName, `.`, methName, `(ctx, reqContent)`) + t.P(` respContent, err = handler(ctx, reqContent)`) t.P(` }()`) t.P() t.P(` if err != nil {`) @@ -1297,11 +1304,16 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript t.P(` return`) t.P(` }`) t.P() + t.P(` handler := s.`, servName, `.`, methName) + t.P(` if s.interceptor != nil {`) + t.generateServerInterceptorHandler(service, method) + t.P(` }`) + t.P() t.P(` // Call service method`) t.P(` var respContent *`, t.goTypeName(method.GetOutputType())) t.P(` func() {`) t.P(` defer ensurePanicResponses(ctx, resp, s.hooks)`) - t.P(` respContent, err = s.`, servName, `.`, methName, `(ctx, reqContent)`) + t.P(` respContent, err = handler(ctx, reqContent)`) t.P(` }()`) t.P() t.P(` if err != nil {`) @@ -1335,6 +1347,32 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript t.P() } +func (t *twirp) generateServerInterceptorHandler(service *descriptor.ServiceDescriptorProto, method *descriptor.MethodDescriptorProto) { + methName := methodNameCamelCased(method) + servName := serviceNameCamelCased(service) + inputType := t.goTypeName(method.GetInputType()) + outputType := t.goTypeName(method.GetOutputType()) + t.P(` handler = func(ctx `, t.pkgs["context"], `.Context, req *`, inputType, `) (*`, outputType, `, error) {`) + t.P(` resp, err := s.interceptor(`) + t.P(` func(ctx `, t.pkgs["context"], ` .Context, req interface{}) (interface{}, error) {`) + t.P(` typedReq, ok := req.(*`, inputType, `)`) + t.P(` if !ok {`) + t.P(` return nil, `, t.pkgs["twirp"], `.InternalError("could not convert to a *`, inputType, `")`) + t.P(` }`) + t.P(` return s.`, servName, `.`, methName, `(ctx, typedReq)`) + t.P(` },`) + t.P(` )(ctx, req)`) + t.P(` if resp != nil {`) + t.P(` typedResp, ok := resp.(*`, outputType, `)`) + t.P(` if !ok {`) + t.P(` return nil, `, t.pkgs["twirp"], `.InternalError("could not convert to a *`, outputType, `")`) + t.P(` }`) + t.P(` return typedResp, err`) + t.P(` }`) + t.P(` return nil, err`) + t.P(` }`) +} + // serviceMetadataVarName is the variable name used in generated code to refer // to the compressed bytes of this descriptor. It is not exported, so it is only // valid inside the generated package. diff --git a/server_options.go b/server_options.go index 8f8cf090..e26ee485 100644 --- a/server_options.go +++ b/server_options.go @@ -21,6 +21,7 @@ type ServerOption func(*ServerOptions) // ServerOptions encapsulate the configurable parameters on a Twirp client. type ServerOptions struct { + Interceptors []Interceptor Hooks *ServerHooks pathPrefix *string JSONSkipDefaults bool @@ -33,6 +34,13 @@ func (opts *ServerOptions) PathPrefix() string { return *opts.pathPrefix } +// WithServerInterceptors defines the interceptors for a Twirp server. +func WithServerInterceptors(interceptors ...Interceptor) ServerOption { + return func(o *ServerOptions) { + o.Interceptors = append(o.Interceptors, interceptors...) + } +} + // WithServerHooks defines the hooks for a Twirp server. func WithServerHooks(hooks *ServerHooks) ServerOption { return func(o *ServerOptions) { From f226fd27f9a8fd115dcaabfb5c461b98a8fb017e Mon Sep 17 00:00:00 2001 From: bufdev Date: Fri, 18 Sep 2020 12:00:22 -0400 Subject: [PATCH 2/6] Address comments --- .../clientcompat/clientcompat.twirp.go | 16 ++--- docs/interceptors.md | 32 +++++++++ example/service.twirp.go | 8 +-- interceptors.go | 68 +++++++++++++++++++ middleware_test.go => interceptors_test.go | 14 ++++ .../twirptest/gogo_compat/service.twirp.go | 8 +-- .../google_protobuf_imports/service.twirp.go | 8 +-- .../twirptest/importable/importable.twirp.go | 8 +-- internal/twirptest/importer/importer.twirp.go | 8 +-- .../importer_local/importer_local.twirp.go | 8 +-- internal/twirptest/importmapping/x/x.twirp.go | 8 +-- .../json_serialization.twirp.go | 8 +-- .../twirptest/multiple/multiple1.twirp.go | 8 +-- .../twirptest/multiple/multiple2.twirp.go | 16 ++--- .../no_package_name/no_package_name.twirp.go | 8 +-- .../no_package_name_importer.twirp.go | 8 +-- internal/twirptest/proto/proto.twirp.go | 8 +-- internal/twirptest/service.twirp.go | 8 +-- .../service_method_same_name.twirp.go | 8 +-- internal/twirptest/service_test.go | 38 ++++++++++- .../snake_case_names.twirp.go | 8 +-- .../source_relative/source_relative.twirp.go | 8 +-- middleware.go | 44 ------------ protoc-gen-twirp/generator.go | 4 +- website/sidebars.json | 1 + 25 files changed, 232 insertions(+), 129 deletions(-) create mode 100644 docs/interceptors.md create mode 100644 interceptors.go rename middleware_test.go => interceptors_test.go (82%) delete mode 100644 middleware.go diff --git a/clientcompat/internal/clientcompat/clientcompat.twirp.go b/clientcompat/internal/clientcompat/clientcompat.twirp.go index 30744583..6a914bec 100644 --- a/clientcompat/internal/clientcompat/clientcompat.twirp.go +++ b/clientcompat/internal/clientcompat/clientcompat.twirp.go @@ -337,7 +337,7 @@ func (s *compatServiceServer) serveMethodJSON(ctx context.Context, resp http.Res func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Req) if !ok { - return nil, twirp.InternalError("could not convert to a *Req") + return nil, twirp.InternalError("failed type assertion req.(*Req) when calling interceptor handler") } return s.CompatService.Method(ctx, typedReq) }, @@ -345,7 +345,7 @@ func (s *compatServiceServer) serveMethodJSON(ctx context.Context, resp http.Res if resp != nil { typedResp, ok := resp.(*Resp) if !ok { - return nil, twirp.InternalError("could not convert to a *Resp") + return nil, twirp.InternalError("failed type assertion resp.(*Resp) when calling interceptor handler") } return typedResp, err } @@ -419,7 +419,7 @@ func (s *compatServiceServer) serveMethodProtobuf(ctx context.Context, resp http func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Req) if !ok { - return nil, twirp.InternalError("could not convert to a *Req") + return nil, twirp.InternalError("failed type assertion req.(*Req) when calling interceptor handler") } return s.CompatService.Method(ctx, typedReq) }, @@ -427,7 +427,7 @@ func (s *compatServiceServer) serveMethodProtobuf(ctx context.Context, resp http if resp != nil { typedResp, ok := resp.(*Resp) if !ok { - return nil, twirp.InternalError("could not convert to a *Resp") + return nil, twirp.InternalError("failed type assertion resp.(*Resp) when calling interceptor handler") } return typedResp, err } @@ -512,7 +512,7 @@ func (s *compatServiceServer) serveNoopMethodJSON(ctx context.Context, resp http func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Empty) if !ok { - return nil, twirp.InternalError("could not convert to a *Empty") + return nil, twirp.InternalError("failed type assertion req.(*Empty) when calling interceptor handler") } return s.CompatService.NoopMethod(ctx, typedReq) }, @@ -520,7 +520,7 @@ func (s *compatServiceServer) serveNoopMethodJSON(ctx context.Context, resp http if resp != nil { typedResp, ok := resp.(*Empty) if !ok { - return nil, twirp.InternalError("could not convert to a *Empty") + return nil, twirp.InternalError("failed type assertion resp.(*Empty) when calling interceptor handler") } return typedResp, err } @@ -594,7 +594,7 @@ func (s *compatServiceServer) serveNoopMethodProtobuf(ctx context.Context, resp func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Empty) if !ok { - return nil, twirp.InternalError("could not convert to a *Empty") + return nil, twirp.InternalError("failed type assertion req.(*Empty) when calling interceptor handler") } return s.CompatService.NoopMethod(ctx, typedReq) }, @@ -602,7 +602,7 @@ func (s *compatServiceServer) serveNoopMethodProtobuf(ctx context.Context, resp if resp != nil { typedResp, ok := resp.(*Empty) if !ok { - return nil, twirp.InternalError("could not convert to a *Empty") + return nil, twirp.InternalError("failed type assertion resp.(*Empty) when calling interceptor handler") } return typedResp, err } diff --git a/docs/interceptors.md b/docs/interceptors.md new file mode 100644 index 00000000..a460c994 --- /dev/null +++ b/docs/interceptors.md @@ -0,0 +1,32 @@ +--- +id: "interceptors" +title: "Interceptors" +sidebar_label: "Interceptor" +--- + +The service constructor can use the option `twirp.WithServerInterceptors(interceptors ...twirp.Interceptor)` +to plug in additional functionality: + +```go +server := NewHaberdasherServer(svcImpl, twirp.WithInterceptor(NewLogInterceptor(logger.New(os.Stderr, "", 0)))) + +// NewLogInterceptor logs various parts of a request using a standard Logger. +func NewLogInterceptor(l *log.Logger) twirp.Interceptor { + return func(next twirp.Method) twirp.Method { + return func(ctx context.Context, req interface{}) (interface{}, error) { + l.Printf("request: %v", request) + resp, err := next(ctx, req) + if err != nil { + l.Printf("error: %v", err) + return nil, err + } + l.Printf("response: %v", resp) + return resp, nil + } + } +} +``` + +Check out +[the godoc for `Interceptor`](http://godoc.org/github.com/twitchtv/twirp#Interceptors) +for more information. diff --git a/example/service.twirp.go b/example/service.twirp.go index 477295b3..196401d1 100644 --- a/example/service.twirp.go +++ b/example/service.twirp.go @@ -292,7 +292,7 @@ func (s *haberdasherServer) serveMakeHatJSON(ctx context.Context, resp http.Resp func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Size) if !ok { - return nil, twirp.InternalError("could not convert to a *Size") + return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor handler") } return s.Haberdasher.MakeHat(ctx, typedReq) }, @@ -300,7 +300,7 @@ func (s *haberdasherServer) serveMakeHatJSON(ctx context.Context, resp http.Resp if resp != nil { typedResp, ok := resp.(*Hat) if !ok { - return nil, twirp.InternalError("could not convert to a *Hat") + return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor handler") } return typedResp, err } @@ -374,7 +374,7 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Size) if !ok { - return nil, twirp.InternalError("could not convert to a *Size") + return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor handler") } return s.Haberdasher.MakeHat(ctx, typedReq) }, @@ -382,7 +382,7 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. if resp != nil { typedResp, ok := resp.(*Hat) if !ok { - return nil, twirp.InternalError("could not convert to a *Hat") + return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor handler") } return typedResp, err } diff --git a/interceptors.go b/interceptors.go new file mode 100644 index 00000000..971154d9 --- /dev/null +++ b/interceptors.go @@ -0,0 +1,68 @@ +// Copyright 2018 Twitch Interactive, Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may not +// use this file except in compliance with the License. A copy of the License is +// located at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// or in the "license" file accompanying this file. This file is distributed on +// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +package twirp + +import ( + "context" +) + +// Method is a method that matches the generic form of a Twirp-generated RPC method. +// +// This is used for Interceptors. +type Method func(ctx context.Context, request interface{}) (interface{}, error) + +// Interceptor is an interceptor that can be installed on a client or server. +// +// Users can use Interceptors to intercept any RPC. +// +// func LogInterceptor(l *log.Logger) twirp.Interceptor { +// return func(next twirp.Method) twirp.Method { +// return func(ctx context.Context, req interface{}) (interface{}, error) { +// l.Printf("request: %v", request) +// resp, err := next(ctx, req) +// if err != nil { +// l.Printf("error: %v", err) +// return nil, err +// } +// l.Printf("response: %v", resp) +// return resp, nil +// } +// } +// } +type Interceptor func(Method) Method + +// ChainInterceptors chains the Interceptors. +// +// Returns nil if interceptors is empty. +func ChainInterceptors(interceptors ...Interceptor) Interceptor { + filtered := make([]Interceptor, 0, len(interceptors)) + for _, interceptor := range interceptors { + if interceptor != nil { + filtered = append(filtered, interceptor) + } + } + switch n := len(filtered); n { + case 0: + return nil + case 1: + return filtered[0] + default: + first := filtered[0] + return func(next Method) Method { + for i := len(filtered) - 1; i > 0; i-- { + next = filtered[i](next) + } + return first(next) + } + } +} diff --git a/middleware_test.go b/interceptors_test.go similarity index 82% rename from middleware_test.go rename to interceptors_test.go index ee46bf17..8e9c00a1 100644 --- a/middleware_test.go +++ b/interceptors_test.go @@ -21,6 +21,12 @@ func TestChainInterceptors(t *testing.T) { if chain := ChainInterceptors(); chain != nil { t.Errorf("ChainInterceptors(0) expected to be nil, but was %v", chain) } + if chain := ChainInterceptors(nil); chain != nil { + t.Errorf("ChainInterceptors(0) expected to be nil, but was %v", chain) + } + if chain := ChainInterceptors(nil, nil); chain != nil { + t.Errorf("ChainInterceptors(0) expected to be nil, but was %v", chain) + } interceptor1 := func(next Method) Method { return func(ctx context.Context, request interface{}) (interface{}, error) { @@ -59,6 +65,14 @@ func TestChainInterceptors(t *testing.T) { interceptors: []Interceptor{interceptor1, interceptor2, interceptor3}, want: "abcx321", }, + { + interceptors: []Interceptor{interceptor1, interceptor2, nil, interceptor3}, + want: "abcx321", + }, + { + interceptors: []Interceptor{interceptor1, interceptor1, interceptor1}, + want: "aaax111", + }, } { response, err := ChainInterceptors(testCase.interceptors...)(method)(context.Background(), "") if err != nil { diff --git a/internal/twirptest/gogo_compat/service.twirp.go b/internal/twirptest/gogo_compat/service.twirp.go index 8be65da9..2e5b86fb 100644 --- a/internal/twirptest/gogo_compat/service.twirp.go +++ b/internal/twirptest/gogo_compat/service.twirp.go @@ -294,7 +294,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.Svc.Send(ctx, typedReq) }, @@ -302,7 +302,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } @@ -376,7 +376,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.Svc.Send(ctx, typedReq) }, @@ -384,7 +384,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/google_protobuf_imports/service.twirp.go b/internal/twirptest/google_protobuf_imports/service.twirp.go index fc86467d..8f3ffc45 100644 --- a/internal/twirptest/google_protobuf_imports/service.twirp.go +++ b/internal/twirptest/google_protobuf_imports/service.twirp.go @@ -293,7 +293,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*google_protobuf1.StringValue) if !ok { - return nil, twirp.InternalError("could not convert to a *google_protobuf1.StringValue") + return nil, twirp.InternalError("failed type assertion req.(*google_protobuf1.StringValue) when calling interceptor handler") } return s.Svc.Send(ctx, typedReq) }, @@ -301,7 +301,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, if resp != nil { typedResp, ok := resp.(*google_protobuf.Empty) if !ok { - return nil, twirp.InternalError("could not convert to a *google_protobuf.Empty") + return nil, twirp.InternalError("failed type assertion resp.(*google_protobuf.Empty) when calling interceptor handler") } return typedResp, err } @@ -375,7 +375,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*google_protobuf1.StringValue) if !ok { - return nil, twirp.InternalError("could not convert to a *google_protobuf1.StringValue") + return nil, twirp.InternalError("failed type assertion req.(*google_protobuf1.StringValue) when calling interceptor handler") } return s.Svc.Send(ctx, typedReq) }, @@ -383,7 +383,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri if resp != nil { typedResp, ok := resp.(*google_protobuf.Empty) if !ok { - return nil, twirp.InternalError("could not convert to a *google_protobuf.Empty") + return nil, twirp.InternalError("failed type assertion resp.(*google_protobuf.Empty) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/importable/importable.twirp.go b/internal/twirptest/importable/importable.twirp.go index 83584dec..08009e90 100644 --- a/internal/twirptest/importable/importable.twirp.go +++ b/internal/twirptest/importable/importable.twirp.go @@ -293,7 +293,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.Svc.Send(ctx, typedReq) }, @@ -301,7 +301,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } @@ -375,7 +375,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.Svc.Send(ctx, typedReq) }, @@ -383,7 +383,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/importer/importer.twirp.go b/internal/twirptest/importer/importer.twirp.go index 940d4eea..e26cd58c 100644 --- a/internal/twirptest/importer/importer.twirp.go +++ b/internal/twirptest/importer/importer.twirp.go @@ -295,7 +295,7 @@ func (s *svc2Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*twirp_internal_twirptest_importable.Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importable.Msg") + return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importable.Msg) when calling interceptor handler") } return s.Svc2.Send(ctx, typedReq) }, @@ -303,7 +303,7 @@ func (s *svc2Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter if resp != nil { typedResp, ok := resp.(*twirp_internal_twirptest_importable.Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importable.Msg") + return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importable.Msg) when calling interceptor handler") } return typedResp, err } @@ -377,7 +377,7 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*twirp_internal_twirptest_importable.Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importable.Msg") + return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importable.Msg) when calling interceptor handler") } return s.Svc2.Send(ctx, typedReq) }, @@ -385,7 +385,7 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr if resp != nil { typedResp, ok := resp.(*twirp_internal_twirptest_importable.Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importable.Msg") + return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importable.Msg) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/importer_local/importer_local.twirp.go b/internal/twirptest/importer_local/importer_local.twirp.go index 55a627f8..42da05af 100644 --- a/internal/twirptest/importer_local/importer_local.twirp.go +++ b/internal/twirptest/importer_local/importer_local.twirp.go @@ -290,7 +290,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.Svc.Send(ctx, typedReq) }, @@ -298,7 +298,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } @@ -372,7 +372,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.Svc.Send(ctx, typedReq) }, @@ -380,7 +380,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/importmapping/x/x.twirp.go b/internal/twirptest/importmapping/x/x.twirp.go index 023006b5..567aa039 100644 --- a/internal/twirptest/importmapping/x/x.twirp.go +++ b/internal/twirptest/importmapping/x/x.twirp.go @@ -292,7 +292,7 @@ func (s *svc1Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*twirp_internal_twirptest_importmapping_y.MsgY) if !ok { - return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importmapping_y.MsgY") + return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor handler") } return s.Svc1.Send(ctx, typedReq) }, @@ -300,7 +300,7 @@ func (s *svc1Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter if resp != nil { typedResp, ok := resp.(*twirp_internal_twirptest_importmapping_y.MsgY) if !ok { - return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importmapping_y.MsgY") + return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor handler") } return typedResp, err } @@ -374,7 +374,7 @@ func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*twirp_internal_twirptest_importmapping_y.MsgY) if !ok { - return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importmapping_y.MsgY") + return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor handler") } return s.Svc1.Send(ctx, typedReq) }, @@ -382,7 +382,7 @@ func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr if resp != nil { typedResp, ok := resp.(*twirp_internal_twirptest_importmapping_y.MsgY) if !ok { - return nil, twirp.InternalError("could not convert to a *twirp_internal_twirptest_importmapping_y.MsgY") + return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/json_serialization/json_serialization.twirp.go b/internal/twirptest/json_serialization/json_serialization.twirp.go index e1940138..10090dfd 100644 --- a/internal/twirptest/json_serialization/json_serialization.twirp.go +++ b/internal/twirptest/json_serialization/json_serialization.twirp.go @@ -290,7 +290,7 @@ func (s *jSONSerializationServer) serveEchoJSONJSON(ctx context.Context, resp ht func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.JSONSerialization.EchoJSON(ctx, typedReq) }, @@ -298,7 +298,7 @@ func (s *jSONSerializationServer) serveEchoJSONJSON(ctx context.Context, resp ht if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } @@ -372,7 +372,7 @@ func (s *jSONSerializationServer) serveEchoJSONProtobuf(ctx context.Context, res func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.JSONSerialization.EchoJSON(ctx, typedReq) }, @@ -380,7 +380,7 @@ func (s *jSONSerializationServer) serveEchoJSONProtobuf(ctx context.Context, res if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/multiple/multiple1.twirp.go b/internal/twirptest/multiple/multiple1.twirp.go index 80912310..b5590cdf 100644 --- a/internal/twirptest/multiple/multiple1.twirp.go +++ b/internal/twirptest/multiple/multiple1.twirp.go @@ -294,7 +294,7 @@ func (s *svc1Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg1) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg1") + return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor handler") } return s.Svc1.Send(ctx, typedReq) }, @@ -302,7 +302,7 @@ func (s *svc1Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter if resp != nil { typedResp, ok := resp.(*Msg1) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg1") + return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor handler") } return typedResp, err } @@ -376,7 +376,7 @@ func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg1) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg1") + return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor handler") } return s.Svc1.Send(ctx, typedReq) }, @@ -384,7 +384,7 @@ func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr if resp != nil { typedResp, ok := resp.(*Msg1) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg1") + return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/multiple/multiple2.twirp.go b/internal/twirptest/multiple/multiple2.twirp.go index 2a7b9570..ed49dd0e 100644 --- a/internal/twirptest/multiple/multiple2.twirp.go +++ b/internal/twirptest/multiple/multiple2.twirp.go @@ -324,7 +324,7 @@ func (s *svc2Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg2) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg2") + return nil, twirp.InternalError("failed type assertion req.(*Msg2) when calling interceptor handler") } return s.Svc2.Send(ctx, typedReq) }, @@ -332,7 +332,7 @@ func (s *svc2Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter if resp != nil { typedResp, ok := resp.(*Msg2) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg2") + return nil, twirp.InternalError("failed type assertion resp.(*Msg2) when calling interceptor handler") } return typedResp, err } @@ -406,7 +406,7 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg2) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg2") + return nil, twirp.InternalError("failed type assertion req.(*Msg2) when calling interceptor handler") } return s.Svc2.Send(ctx, typedReq) }, @@ -414,7 +414,7 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr if resp != nil { typedResp, ok := resp.(*Msg2) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg2") + return nil, twirp.InternalError("failed type assertion resp.(*Msg2) when calling interceptor handler") } return typedResp, err } @@ -499,7 +499,7 @@ func (s *svc2Server) serveSamePackageProtoImportJSON(ctx context.Context, resp h func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg1) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg1") + return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor handler") } return s.Svc2.SamePackageProtoImport(ctx, typedReq) }, @@ -507,7 +507,7 @@ func (s *svc2Server) serveSamePackageProtoImportJSON(ctx context.Context, resp h if resp != nil { typedResp, ok := resp.(*Msg1) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg1") + return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor handler") } return typedResp, err } @@ -581,7 +581,7 @@ func (s *svc2Server) serveSamePackageProtoImportProtobuf(ctx context.Context, re func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg1) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg1") + return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor handler") } return s.Svc2.SamePackageProtoImport(ctx, typedReq) }, @@ -589,7 +589,7 @@ func (s *svc2Server) serveSamePackageProtoImportProtobuf(ctx context.Context, re if resp != nil { typedResp, ok := resp.(*Msg1) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg1") + return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/no_package_name/no_package_name.twirp.go b/internal/twirptest/no_package_name/no_package_name.twirp.go index 0b6e8579..476538fa 100644 --- a/internal/twirptest/no_package_name/no_package_name.twirp.go +++ b/internal/twirptest/no_package_name/no_package_name.twirp.go @@ -290,7 +290,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.Svc.Send(ctx, typedReq) }, @@ -298,7 +298,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } @@ -372,7 +372,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.Svc.Send(ctx, typedReq) }, @@ -380,7 +380,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go b/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go index 281293cf..323a88ca 100644 --- a/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go +++ b/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go @@ -292,7 +292,7 @@ func (s *svc2Server) serveMethodJSON(ctx context.Context, resp http.ResponseWrit func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*no_package_name.Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *no_package_name.Msg") + return nil, twirp.InternalError("failed type assertion req.(*no_package_name.Msg) when calling interceptor handler") } return s.Svc2.Method(ctx, typedReq) }, @@ -300,7 +300,7 @@ func (s *svc2Server) serveMethodJSON(ctx context.Context, resp http.ResponseWrit if resp != nil { typedResp, ok := resp.(*no_package_name.Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *no_package_name.Msg") + return nil, twirp.InternalError("failed type assertion resp.(*no_package_name.Msg) when calling interceptor handler") } return typedResp, err } @@ -374,7 +374,7 @@ func (s *svc2Server) serveMethodProtobuf(ctx context.Context, resp http.Response func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*no_package_name.Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *no_package_name.Msg") + return nil, twirp.InternalError("failed type assertion req.(*no_package_name.Msg) when calling interceptor handler") } return s.Svc2.Method(ctx, typedReq) }, @@ -382,7 +382,7 @@ func (s *svc2Server) serveMethodProtobuf(ctx context.Context, resp http.Response if resp != nil { typedResp, ok := resp.(*no_package_name.Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *no_package_name.Msg") + return nil, twirp.InternalError("failed type assertion resp.(*no_package_name.Msg) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/proto/proto.twirp.go b/internal/twirptest/proto/proto.twirp.go index c0451132..f1b6cf1f 100644 --- a/internal/twirptest/proto/proto.twirp.go +++ b/internal/twirptest/proto/proto.twirp.go @@ -293,7 +293,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.Svc.Send(ctx, typedReq) }, @@ -301,7 +301,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } @@ -375,7 +375,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.Svc.Send(ctx, typedReq) }, @@ -383,7 +383,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/service.twirp.go b/internal/twirptest/service.twirp.go index d8d1923e..58ce4491 100644 --- a/internal/twirptest/service.twirp.go +++ b/internal/twirptest/service.twirp.go @@ -292,7 +292,7 @@ func (s *haberdasherServer) serveMakeHatJSON(ctx context.Context, resp http.Resp func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Size) if !ok { - return nil, twirp.InternalError("could not convert to a *Size") + return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor handler") } return s.Haberdasher.MakeHat(ctx, typedReq) }, @@ -300,7 +300,7 @@ func (s *haberdasherServer) serveMakeHatJSON(ctx context.Context, resp http.Resp if resp != nil { typedResp, ok := resp.(*Hat) if !ok { - return nil, twirp.InternalError("could not convert to a *Hat") + return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor handler") } return typedResp, err } @@ -374,7 +374,7 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Size) if !ok { - return nil, twirp.InternalError("could not convert to a *Size") + return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor handler") } return s.Haberdasher.MakeHat(ctx, typedReq) }, @@ -382,7 +382,7 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. if resp != nil { typedResp, ok := resp.(*Hat) if !ok { - return nil, twirp.InternalError("could not convert to a *Hat") + return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go b/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go index e30cdc7a..b0dc7391 100644 --- a/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go +++ b/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go @@ -290,7 +290,7 @@ func (s *echoServer) serveEchoJSON(ctx context.Context, resp http.ResponseWriter func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.Echo.Echo(ctx, typedReq) }, @@ -298,7 +298,7 @@ func (s *echoServer) serveEchoJSON(ctx context.Context, resp http.ResponseWriter if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } @@ -372,7 +372,7 @@ func (s *echoServer) serveEchoProtobuf(ctx context.Context, resp http.ResponseWr func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.Echo.Echo(ctx, typedReq) }, @@ -380,7 +380,7 @@ func (s *echoServer) serveEchoProtobuf(ctx context.Context, resp http.ResponseWr if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/service_test.go b/internal/twirptest/service_test.go index 66f5d557..0c7e888a 100644 --- a/internal/twirptest/service_test.go +++ b/internal/twirptest/service_test.go @@ -578,6 +578,18 @@ func TestErroringHooks(t *testing.T) { func TestInterceptor(t *testing.T) { interceptor := func(next twirp.Method) twirp.Method { return func(ctx context.Context, request interface{}) (interface{}, error) { + methodName, _ := twirp.MethodName(ctx) + if methodName != "MakeHat" { + return nil, fmt.Errorf("unexpected methodName: %q", methodName) + } + serviceName, _ := twirp.ServiceName(ctx) + if serviceName != "Haberdasher" { + return nil, fmt.Errorf("unexpected serviceName: %q", serviceName) + } + packageName, _ := twirp.PackageName(ctx) + if packageName != "twirp.internal.twirptest" { + return nil, fmt.Errorf("unexpected packageName: %q", packageName) + } size, ok := request.(*Size) if !ok { return nil, fmt.Errorf("could not cast %T to a *Size", request) @@ -594,7 +606,18 @@ func TestInterceptor(t *testing.T) { } h := PickyHatmaker(3) - s := httptest.NewServer( + s := httptest.NewServer(NewHaberdasherServer(h)) + defer s.Close() + client := NewHaberdasherProtobufClient(s.URL, http.DefaultClient) + hat, clientErr := client.MakeHat(context.Background(), &Size{Inches: 3}) + if clientErr != nil { + t.Fatalf("client err=%q", clientErr) + } + if hat.Size != 3 { + t.Errorf("hat size expected=3 actual=%v", hat.Size) + } + + s = httptest.NewServer( NewHaberdasherServer( h, twirp.WithServerInterceptors( @@ -604,8 +627,8 @@ func TestInterceptor(t *testing.T) { ), ) defer s.Close() - client := NewHaberdasherProtobufClient(s.URL, http.DefaultClient) - hat, clientErr := client.MakeHat(context.Background(), &Size{Inches: 1}) + client = NewHaberdasherProtobufClient(s.URL, http.DefaultClient) + hat, clientErr = client.MakeHat(context.Background(), &Size{Inches: 1}) if clientErr != nil { t.Fatalf("client err=%q", clientErr) } @@ -615,6 +638,15 @@ func TestInterceptor(t *testing.T) { if hat.Color != "bluexx" { t.Errorf("hat color expected=bluexx actual=%v", hat.Color) } + _, clientErr = client.MakeHat(context.Background(), &Size{Inches: 3}) + twerr, ok := clientErr.(twirp.Error) + if !ok { + t.Fatalf("expected twirp.Error type error, have %T", clientErr) + } + + if twerr.Code() != twirp.InvalidArgument { + t.Errorf("expected error type to be InvalidArgument, buf found %q", twerr.Code()) + } } func TestInternalErrorPassing(t *testing.T) { diff --git a/internal/twirptest/snake_case_names/snake_case_names.twirp.go b/internal/twirptest/snake_case_names/snake_case_names.twirp.go index 60c8019d..9b866322 100644 --- a/internal/twirptest/snake_case_names/snake_case_names.twirp.go +++ b/internal/twirptest/snake_case_names/snake_case_names.twirp.go @@ -295,7 +295,7 @@ func (s *haberdasherV1Server) serveMakeHatV1JSON(ctx context.Context, resp http. func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*MakeHatArgsV1_SizeV1) if !ok { - return nil, twirp.InternalError("could not convert to a *MakeHatArgsV1_SizeV1") + return nil, twirp.InternalError("failed type assertion req.(*MakeHatArgsV1_SizeV1) when calling interceptor handler") } return s.HaberdasherV1.MakeHatV1(ctx, typedReq) }, @@ -303,7 +303,7 @@ func (s *haberdasherV1Server) serveMakeHatV1JSON(ctx context.Context, resp http. if resp != nil { typedResp, ok := resp.(*MakeHatArgsV1_HatV1) if !ok { - return nil, twirp.InternalError("could not convert to a *MakeHatArgsV1_HatV1") + return nil, twirp.InternalError("failed type assertion resp.(*MakeHatArgsV1_HatV1) when calling interceptor handler") } return typedResp, err } @@ -377,7 +377,7 @@ func (s *haberdasherV1Server) serveMakeHatV1Protobuf(ctx context.Context, resp h func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*MakeHatArgsV1_SizeV1) if !ok { - return nil, twirp.InternalError("could not convert to a *MakeHatArgsV1_SizeV1") + return nil, twirp.InternalError("failed type assertion req.(*MakeHatArgsV1_SizeV1) when calling interceptor handler") } return s.HaberdasherV1.MakeHatV1(ctx, typedReq) }, @@ -385,7 +385,7 @@ func (s *haberdasherV1Server) serveMakeHatV1Protobuf(ctx context.Context, resp h if resp != nil { typedResp, ok := resp.(*MakeHatArgsV1_HatV1) if !ok { - return nil, twirp.InternalError("could not convert to a *MakeHatArgsV1_HatV1") + return nil, twirp.InternalError("failed type assertion resp.(*MakeHatArgsV1_HatV1) when calling interceptor handler") } return typedResp, err } diff --git a/internal/twirptest/source_relative/source_relative.twirp.go b/internal/twirptest/source_relative/source_relative.twirp.go index fea3b252..fd6de8c7 100644 --- a/internal/twirptest/source_relative/source_relative.twirp.go +++ b/internal/twirptest/source_relative/source_relative.twirp.go @@ -290,7 +290,7 @@ func (s *svcServer) serveMethodJSON(ctx context.Context, resp http.ResponseWrite func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.Svc.Method(ctx, typedReq) }, @@ -298,7 +298,7 @@ func (s *svcServer) serveMethodJSON(ctx context.Context, resp http.ResponseWrite if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } @@ -372,7 +372,7 @@ func (s *svcServer) serveMethodProtobuf(ctx context.Context, resp http.ResponseW func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") } return s.Svc.Method(ctx, typedReq) }, @@ -380,7 +380,7 @@ func (s *svcServer) serveMethodProtobuf(ctx context.Context, resp http.ResponseW if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("could not convert to a *Msg") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") } return typedResp, err } diff --git a/middleware.go b/middleware.go deleted file mode 100644 index 69631b41..00000000 --- a/middleware.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2018 Twitch Interactive, Inc. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may not -// use this file except in compliance with the License. A copy of the License is -// located at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// or in the "license" file accompanying this file. This file is distributed on -// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. -package twirp - -import ( - "context" -) - -// Method is a method. -type Method func(ctx context.Context, request interface{}) (interface{}, error) - -// Interceptor is a interceptor. -type Interceptor func(Method) Method - -// ChainInterceptors chains the Interceptors. -// -// Returns nil if interceptors is empty. -func ChainInterceptors(interceptors ...Interceptor) Interceptor { - switch n := len(interceptors); n { - case 0: - return nil - case 1: - return interceptors[0] - default: - first := interceptors[0] - return func(next Method) Method { - for i := len(interceptors) - 1; i > 0; i-- { - next = interceptors[i](next) - } - return first(next) - } - } - -} diff --git a/protoc-gen-twirp/generator.go b/protoc-gen-twirp/generator.go index 8c66777f..938e2585 100644 --- a/protoc-gen-twirp/generator.go +++ b/protoc-gen-twirp/generator.go @@ -1357,7 +1357,7 @@ func (t *twirp) generateServerInterceptorHandler(service *descriptor.ServiceDesc t.P(` func(ctx `, t.pkgs["context"], ` .Context, req interface{}) (interface{}, error) {`) t.P(` typedReq, ok := req.(*`, inputType, `)`) t.P(` if !ok {`) - t.P(` return nil, `, t.pkgs["twirp"], `.InternalError("could not convert to a *`, inputType, `")`) + t.P(` return nil, `, t.pkgs["twirp"], `.InternalError("failed type assertion req.(*`, inputType, `) when calling interceptor handler")`) t.P(` }`) t.P(` return s.`, servName, `.`, methName, `(ctx, typedReq)`) t.P(` },`) @@ -1365,7 +1365,7 @@ func (t *twirp) generateServerInterceptorHandler(service *descriptor.ServiceDesc t.P(` if resp != nil {`) t.P(` typedResp, ok := resp.(*`, outputType, `)`) t.P(` if !ok {`) - t.P(` return nil, `, t.pkgs["twirp"], `.InternalError("could not convert to a *`, outputType, `")`) + t.P(` return nil, `, t.pkgs["twirp"], `.InternalError("failed type assertion resp.(*`, outputType, `) when calling interceptor handler")`) t.P(` }`) t.P(` return typedResp, err`) t.P(` }`) diff --git a/website/sidebars.json b/website/sidebars.json index 6d526022..90e65eb4 100644 --- a/website/sidebars.json +++ b/website/sidebars.json @@ -12,6 +12,7 @@ "routing", "errors", "hooks", + "interceptors", "command_line", "proto_and_json", "curl" From 6f1c4c34442b1a9fda9b4a35edd4675b6e99f2ad Mon Sep 17 00:00:00 2001 From: bufdev Date: Fri, 18 Sep 2020 12:04:27 -0400 Subject: [PATCH 3/6] Fix docs --- docs/interceptors.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/interceptors.md b/docs/interceptors.md index a460c994..872a7903 100644 --- a/docs/interceptors.md +++ b/docs/interceptors.md @@ -28,5 +28,5 @@ func NewLogInterceptor(l *log.Logger) twirp.Interceptor { ``` Check out -[the godoc for `Interceptor`](http://godoc.org/github.com/twitchtv/twirp#Interceptors) +[the godoc for `Interceptor`](http://godoc.org/github.com/twitchtv/twirp#Interceptor) for more information. From 9a12f1177d6f5fd8b469f472b67acd16c0ee7e2e Mon Sep 17 00:00:00 2001 From: bufdev Date: Fri, 18 Sep 2020 12:05:04 -0400 Subject: [PATCH 4/6] Fix docs --- docs/interceptors.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/interceptors.md b/docs/interceptors.md index 872a7903..895ec544 100644 --- a/docs/interceptors.md +++ b/docs/interceptors.md @@ -1,7 +1,7 @@ --- id: "interceptors" title: "Interceptors" -sidebar_label: "Interceptor" +sidebar_label: "Interceptors" --- The service constructor can use the option `twirp.WithServerInterceptors(interceptors ...twirp.Interceptor)` From 2688b1f5c3be3ef93dc26468741b07bcff67ce44 Mon Sep 17 00:00:00 2001 From: bufdev Date: Fri, 18 Sep 2020 12:51:08 -0400 Subject: [PATCH 5/6] Add client interceptors --- client_options.go | 12 +- .../clientcompat/clientcompat.twirp.go | 148 +++++++++++++++--- example/service.twirp.go | 88 +++++++++-- .../empty_service/empty_service.twirp.go | 28 ++-- .../twirptest/gogo_compat/service.twirp.go | 88 +++++++++-- .../google_protobuf_imports/service.twirp.go | 88 +++++++++-- .../twirptest/importable/importable.twirp.go | 88 +++++++++-- internal/twirptest/importer/importer.twirp.go | 88 +++++++++-- .../importer_local/importer_local.twirp.go | 88 +++++++++-- internal/twirptest/importmapping/x/x.twirp.go | 88 +++++++++-- .../json_serialization.twirp.go | 88 +++++++++-- .../twirptest/multiple/multiple1.twirp.go | 88 +++++++++-- .../twirptest/multiple/multiple2.twirp.go | 148 +++++++++++++++--- .../no_package_name/no_package_name.twirp.go | 88 +++++++++-- .../no_package_name_importer.twirp.go | 88 +++++++++-- internal/twirptest/proto/proto.twirp.go | 88 +++++++++-- internal/twirptest/service.twirp.go | 88 +++++++++-- .../service_method_same_name.twirp.go | 88 +++++++++-- .../snake_case_names.twirp.go | 88 +++++++++-- .../source_relative/source_relative.twirp.go | 88 +++++++++-- protoc-gen-twirp/generator.go | 35 ++++- 21 files changed, 1463 insertions(+), 316 deletions(-) diff --git a/client_options.go b/client_options.go index 37f817d4..66cfc2f5 100644 --- a/client_options.go +++ b/client_options.go @@ -22,8 +22,9 @@ type ClientOption func(*ClientOptions) // ClientOptions encapsulate the configurable parameters on a Twirp client. type ClientOptions struct { - Hooks *ClientHooks - pathPrefix *string + Interceptors []Interceptor + Hooks *ClientHooks + pathPrefix *string } func (opts *ClientOptions) PathPrefix() string { @@ -33,6 +34,13 @@ func (opts *ClientOptions) PathPrefix() string { return *opts.pathPrefix } +// WithClientInterceptors defines the interceptors for a Twirp client. +func WithClientInterceptors(interceptors ...Interceptor) ClientOption { + return func(o *ClientOptions) { + o.Interceptors = append(o.Interceptors, interceptors...) + } +} + // WithClientHooks defines the hooks for a Twirp client. func WithClientHooks(hooks *ClientHooks) ClientOption { return func(o *ClientOptions) { diff --git a/clientcompat/internal/clientcompat/clientcompat.twirp.go b/clientcompat/internal/clientcompat/clientcompat.twirp.go index 6a914bec..e403bb94 100644 --- a/clientcompat/internal/clientcompat/clientcompat.twirp.go +++ b/clientcompat/internal/clientcompat/clientcompat.twirp.go @@ -50,9 +50,10 @@ type CompatService interface { // ============================= type compatServiceProtobufClient struct { - client HTTPClient - urls [2]string - opts twirp.ClientOptions + client HTTPClient + urls [2]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewCompatServiceProtobufClient creates a Protobuf client that implements the CompatService interface. @@ -76,13 +77,40 @@ func NewCompatServiceProtobufClient(baseURL string, client HTTPClient, opts ...t } return &compatServiceProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *compatServiceProtobufClient) Method(ctx context.Context, in *Req) (*Resp, error) { + caller := c.callMethod + if c.interceptor != nil { + caller = func(ctx context.Context, req *Req) (*Resp, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Req) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Req) when calling interceptor") + } + return c.callMethod(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Resp) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Resp) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *compatServiceProtobufClient) callMethod(ctx context.Context, in *Req) (*Resp, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.clientcompat") ctx = ctxsetters.WithServiceName(ctx, "CompatService") ctx = ctxsetters.WithMethodName(ctx, "Method") @@ -103,6 +131,32 @@ func (c *compatServiceProtobufClient) Method(ctx context.Context, in *Req) (*Res } func (c *compatServiceProtobufClient) NoopMethod(ctx context.Context, in *Empty) (*Empty, error) { + caller := c.callNoopMethod + if c.interceptor != nil { + caller = func(ctx context.Context, req *Empty) (*Empty, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Empty) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Empty) when calling interceptor") + } + return c.callNoopMethod(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Empty) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Empty) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *compatServiceProtobufClient) callNoopMethod(ctx context.Context, in *Empty) (*Empty, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.clientcompat") ctx = ctxsetters.WithServiceName(ctx, "CompatService") ctx = ctxsetters.WithMethodName(ctx, "NoopMethod") @@ -127,9 +181,10 @@ func (c *compatServiceProtobufClient) NoopMethod(ctx context.Context, in *Empty) // ========================= type compatServiceJSONClient struct { - client HTTPClient - urls [2]string - opts twirp.ClientOptions + client HTTPClient + urls [2]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewCompatServiceJSONClient creates a JSON client that implements the CompatService interface. @@ -153,13 +208,40 @@ func NewCompatServiceJSONClient(baseURL string, client HTTPClient, opts ...twirp } return &compatServiceJSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *compatServiceJSONClient) Method(ctx context.Context, in *Req) (*Resp, error) { + caller := c.callMethod + if c.interceptor != nil { + caller = func(ctx context.Context, req *Req) (*Resp, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Req) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Req) when calling interceptor") + } + return c.callMethod(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Resp) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Resp) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *compatServiceJSONClient) callMethod(ctx context.Context, in *Req) (*Resp, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.clientcompat") ctx = ctxsetters.WithServiceName(ctx, "CompatService") ctx = ctxsetters.WithMethodName(ctx, "Method") @@ -180,6 +262,32 @@ func (c *compatServiceJSONClient) Method(ctx context.Context, in *Req) (*Resp, e } func (c *compatServiceJSONClient) NoopMethod(ctx context.Context, in *Empty) (*Empty, error) { + caller := c.callNoopMethod + if c.interceptor != nil { + caller = func(ctx context.Context, req *Empty) (*Empty, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Empty) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Empty) when calling interceptor") + } + return c.callNoopMethod(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Empty) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Empty) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *compatServiceJSONClient) callNoopMethod(ctx context.Context, in *Empty) (*Empty, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.clientcompat") ctx = ctxsetters.WithServiceName(ctx, "CompatService") ctx = ctxsetters.WithMethodName(ctx, "NoopMethod") @@ -337,7 +445,7 @@ func (s *compatServiceServer) serveMethodJSON(ctx context.Context, resp http.Res func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Req) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Req) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Req) when calling interceptor") } return s.CompatService.Method(ctx, typedReq) }, @@ -345,7 +453,7 @@ func (s *compatServiceServer) serveMethodJSON(ctx context.Context, resp http.Res if resp != nil { typedResp, ok := resp.(*Resp) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Resp) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Resp) when calling interceptor") } return typedResp, err } @@ -419,7 +527,7 @@ func (s *compatServiceServer) serveMethodProtobuf(ctx context.Context, resp http func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Req) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Req) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Req) when calling interceptor") } return s.CompatService.Method(ctx, typedReq) }, @@ -427,7 +535,7 @@ func (s *compatServiceServer) serveMethodProtobuf(ctx context.Context, resp http if resp != nil { typedResp, ok := resp.(*Resp) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Resp) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Resp) when calling interceptor") } return typedResp, err } @@ -512,7 +620,7 @@ func (s *compatServiceServer) serveNoopMethodJSON(ctx context.Context, resp http func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Empty) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Empty) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Empty) when calling interceptor") } return s.CompatService.NoopMethod(ctx, typedReq) }, @@ -520,7 +628,7 @@ func (s *compatServiceServer) serveNoopMethodJSON(ctx context.Context, resp http if resp != nil { typedResp, ok := resp.(*Empty) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Empty) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Empty) when calling interceptor") } return typedResp, err } @@ -594,7 +702,7 @@ func (s *compatServiceServer) serveNoopMethodProtobuf(ctx context.Context, resp func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Empty) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Empty) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Empty) when calling interceptor") } return s.CompatService.NoopMethod(ctx, typedReq) }, @@ -602,7 +710,7 @@ func (s *compatServiceServer) serveNoopMethodProtobuf(ctx context.Context, resp if resp != nil { typedResp, ok := resp.(*Empty) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Empty) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Empty) when calling interceptor") } return typedResp, err } diff --git a/example/service.twirp.go b/example/service.twirp.go index 196401d1..bd1497b2 100644 --- a/example/service.twirp.go +++ b/example/service.twirp.go @@ -50,9 +50,10 @@ type Haberdasher interface { // =========================== type haberdasherProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewHaberdasherProtobufClient creates a Protobuf client that implements the Haberdasher interface. @@ -75,13 +76,40 @@ func NewHaberdasherProtobufClient(baseURL string, client HTTPClient, opts ...twi } return &haberdasherProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *haberdasherProtobufClient) MakeHat(ctx context.Context, in *Size) (*Hat, error) { + caller := c.callMakeHat + if c.interceptor != nil { + caller = func(ctx context.Context, req *Size) (*Hat, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Size) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor") + } + return c.callMakeHat(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Hat) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *haberdasherProtobufClient) callMakeHat(ctx context.Context, in *Size) (*Hat, error) { ctx = ctxsetters.WithPackageName(ctx, "twitch.twirp.example") ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") ctx = ctxsetters.WithMethodName(ctx, "MakeHat") @@ -106,9 +134,10 @@ func (c *haberdasherProtobufClient) MakeHat(ctx context.Context, in *Size) (*Hat // ======================= type haberdasherJSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewHaberdasherJSONClient creates a JSON client that implements the Haberdasher interface. @@ -131,13 +160,40 @@ func NewHaberdasherJSONClient(baseURL string, client HTTPClient, opts ...twirp.C } return &haberdasherJSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *haberdasherJSONClient) MakeHat(ctx context.Context, in *Size) (*Hat, error) { + caller := c.callMakeHat + if c.interceptor != nil { + caller = func(ctx context.Context, req *Size) (*Hat, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Size) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor") + } + return c.callMakeHat(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Hat) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *haberdasherJSONClient) callMakeHat(ctx context.Context, in *Size) (*Hat, error) { ctx = ctxsetters.WithPackageName(ctx, "twitch.twirp.example") ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") ctx = ctxsetters.WithMethodName(ctx, "MakeHat") @@ -292,7 +348,7 @@ func (s *haberdasherServer) serveMakeHatJSON(ctx context.Context, resp http.Resp func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Size) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor") } return s.Haberdasher.MakeHat(ctx, typedReq) }, @@ -300,7 +356,7 @@ func (s *haberdasherServer) serveMakeHatJSON(ctx context.Context, resp http.Resp if resp != nil { typedResp, ok := resp.(*Hat) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor") } return typedResp, err } @@ -374,7 +430,7 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Size) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor") } return s.Haberdasher.MakeHat(ctx, typedReq) }, @@ -382,7 +438,7 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. if resp != nil { typedResp, ok := resp.(*Hat) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/empty_service/empty_service.twirp.go b/internal/twirptest/empty_service/empty_service.twirp.go index 034ca951..7d4baff0 100644 --- a/internal/twirptest/empty_service/empty_service.twirp.go +++ b/internal/twirptest/empty_service/empty_service.twirp.go @@ -47,9 +47,10 @@ type Empty interface { // ===================== type emptyProtobufClient struct { - client HTTPClient - urls [0]string - opts twirp.ClientOptions + client HTTPClient + urls [0]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewEmptyProtobufClient creates a Protobuf client that implements the Empty interface. @@ -67,9 +68,10 @@ func NewEmptyProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Cli urls := [0]string{} return &emptyProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } @@ -78,9 +80,10 @@ func NewEmptyProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Cli // ================= type emptyJSONClient struct { - client HTTPClient - urls [0]string - opts twirp.ClientOptions + client HTTPClient + urls [0]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewEmptyJSONClient creates a JSON client that implements the Empty interface. @@ -98,9 +101,10 @@ func NewEmptyJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientO urls := [0]string{} return &emptyJSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } diff --git a/internal/twirptest/gogo_compat/service.twirp.go b/internal/twirptest/gogo_compat/service.twirp.go index 2e5b86fb..1e742bc9 100644 --- a/internal/twirptest/gogo_compat/service.twirp.go +++ b/internal/twirptest/gogo_compat/service.twirp.go @@ -52,9 +52,10 @@ type Svc interface { // =================== type svcProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvcProtobufClient creates a Protobuf client that implements the Svc interface. @@ -77,13 +78,40 @@ func NewSvcProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clien } return &svcProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svcProtobufClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.gogo_compat") ctx = ctxsetters.WithServiceName(ctx, "Svc") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -108,9 +136,10 @@ func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { // =============== type svcJSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvcJSONClient creates a JSON client that implements the Svc interface. @@ -133,13 +162,40 @@ func NewSvcJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOpt } return &svcJSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svcJSONClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.gogo_compat") ctx = ctxsetters.WithServiceName(ctx, "Svc") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -294,7 +350,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.Svc.Send(ctx, typedReq) }, @@ -302,7 +358,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } @@ -376,7 +432,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.Svc.Send(ctx, typedReq) }, @@ -384,7 +440,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/google_protobuf_imports/service.twirp.go b/internal/twirptest/google_protobuf_imports/service.twirp.go index 8f3ffc45..2c7472b0 100644 --- a/internal/twirptest/google_protobuf_imports/service.twirp.go +++ b/internal/twirptest/google_protobuf_imports/service.twirp.go @@ -51,9 +51,10 @@ type Svc interface { // =================== type svcProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvcProtobufClient creates a Protobuf client that implements the Svc interface. @@ -76,13 +77,40 @@ func NewSvcProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clien } return &svcProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svcProtobufClient) Send(ctx context.Context, in *google_protobuf1.StringValue) (*google_protobuf.Empty, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *google_protobuf1.StringValue) (*google_protobuf.Empty, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*google_protobuf1.StringValue) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*google_protobuf1.StringValue) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*google_protobuf.Empty) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*google_protobuf.Empty) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svcProtobufClient) callSend(ctx context.Context, in *google_protobuf1.StringValue) (*google_protobuf.Empty, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.use_empty") ctx = ctxsetters.WithServiceName(ctx, "Svc") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -107,9 +135,10 @@ func (c *svcProtobufClient) Send(ctx context.Context, in *google_protobuf1.Strin // =============== type svcJSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvcJSONClient creates a JSON client that implements the Svc interface. @@ -132,13 +161,40 @@ func NewSvcJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOpt } return &svcJSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svcJSONClient) Send(ctx context.Context, in *google_protobuf1.StringValue) (*google_protobuf.Empty, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *google_protobuf1.StringValue) (*google_protobuf.Empty, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*google_protobuf1.StringValue) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*google_protobuf1.StringValue) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*google_protobuf.Empty) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*google_protobuf.Empty) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svcJSONClient) callSend(ctx context.Context, in *google_protobuf1.StringValue) (*google_protobuf.Empty, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.use_empty") ctx = ctxsetters.WithServiceName(ctx, "Svc") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -293,7 +349,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*google_protobuf1.StringValue) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*google_protobuf1.StringValue) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*google_protobuf1.StringValue) when calling interceptor") } return s.Svc.Send(ctx, typedReq) }, @@ -301,7 +357,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, if resp != nil { typedResp, ok := resp.(*google_protobuf.Empty) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*google_protobuf.Empty) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*google_protobuf.Empty) when calling interceptor") } return typedResp, err } @@ -375,7 +431,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*google_protobuf1.StringValue) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*google_protobuf1.StringValue) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*google_protobuf1.StringValue) when calling interceptor") } return s.Svc.Send(ctx, typedReq) }, @@ -383,7 +439,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri if resp != nil { typedResp, ok := resp.(*google_protobuf.Empty) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*google_protobuf.Empty) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*google_protobuf.Empty) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/importable/importable.twirp.go b/internal/twirptest/importable/importable.twirp.go index 08009e90..e30ccb47 100644 --- a/internal/twirptest/importable/importable.twirp.go +++ b/internal/twirptest/importable/importable.twirp.go @@ -51,9 +51,10 @@ type Svc interface { // =================== type svcProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvcProtobufClient creates a Protobuf client that implements the Svc interface. @@ -76,13 +77,40 @@ func NewSvcProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clien } return &svcProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svcProtobufClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importable") ctx = ctxsetters.WithServiceName(ctx, "Svc") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -107,9 +135,10 @@ func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { // =============== type svcJSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvcJSONClient creates a JSON client that implements the Svc interface. @@ -132,13 +161,40 @@ func NewSvcJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOpt } return &svcJSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svcJSONClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importable") ctx = ctxsetters.WithServiceName(ctx, "Svc") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -293,7 +349,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.Svc.Send(ctx, typedReq) }, @@ -301,7 +357,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } @@ -375,7 +431,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.Svc.Send(ctx, typedReq) }, @@ -383,7 +439,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/importer/importer.twirp.go b/internal/twirptest/importer/importer.twirp.go index e26cd58c..bc887ccf 100644 --- a/internal/twirptest/importer/importer.twirp.go +++ b/internal/twirptest/importer/importer.twirp.go @@ -53,9 +53,10 @@ type Svc2 interface { // ==================== type svc2ProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvc2ProtobufClient creates a Protobuf client that implements the Svc2 interface. @@ -78,13 +79,40 @@ func NewSvc2ProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clie } return &svc2ProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svc2ProtobufClient) Send(ctx context.Context, in *twirp_internal_twirptest_importable.Msg) (*twirp_internal_twirptest_importable.Msg, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *twirp_internal_twirptest_importable.Msg) (*twirp_internal_twirptest_importable.Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*twirp_internal_twirptest_importable.Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importable.Msg) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*twirp_internal_twirptest_importable.Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importable.Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svc2ProtobufClient) callSend(ctx context.Context, in *twirp_internal_twirptest_importable.Msg) (*twirp_internal_twirptest_importable.Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importer") ctx = ctxsetters.WithServiceName(ctx, "Svc2") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -109,9 +137,10 @@ func (c *svc2ProtobufClient) Send(ctx context.Context, in *twirp_internal_twirpt // ================ type svc2JSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvc2JSONClient creates a JSON client that implements the Svc2 interface. @@ -134,13 +163,40 @@ func NewSvc2JSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOp } return &svc2JSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svc2JSONClient) Send(ctx context.Context, in *twirp_internal_twirptest_importable.Msg) (*twirp_internal_twirptest_importable.Msg, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *twirp_internal_twirptest_importable.Msg) (*twirp_internal_twirptest_importable.Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*twirp_internal_twirptest_importable.Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importable.Msg) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*twirp_internal_twirptest_importable.Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importable.Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svc2JSONClient) callSend(ctx context.Context, in *twirp_internal_twirptest_importable.Msg) (*twirp_internal_twirptest_importable.Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importer") ctx = ctxsetters.WithServiceName(ctx, "Svc2") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -295,7 +351,7 @@ func (s *svc2Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*twirp_internal_twirptest_importable.Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importable.Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importable.Msg) when calling interceptor") } return s.Svc2.Send(ctx, typedReq) }, @@ -303,7 +359,7 @@ func (s *svc2Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter if resp != nil { typedResp, ok := resp.(*twirp_internal_twirptest_importable.Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importable.Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importable.Msg) when calling interceptor") } return typedResp, err } @@ -377,7 +433,7 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*twirp_internal_twirptest_importable.Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importable.Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importable.Msg) when calling interceptor") } return s.Svc2.Send(ctx, typedReq) }, @@ -385,7 +441,7 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr if resp != nil { typedResp, ok := resp.(*twirp_internal_twirptest_importable.Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importable.Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importable.Msg) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/importer_local/importer_local.twirp.go b/internal/twirptest/importer_local/importer_local.twirp.go index 42da05af..148681b8 100644 --- a/internal/twirptest/importer_local/importer_local.twirp.go +++ b/internal/twirptest/importer_local/importer_local.twirp.go @@ -48,9 +48,10 @@ type Svc interface { // =================== type svcProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvcProtobufClient creates a Protobuf client that implements the Svc interface. @@ -73,13 +74,40 @@ func NewSvcProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clien } return &svcProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svcProtobufClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importer_local") ctx = ctxsetters.WithServiceName(ctx, "Svc") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -104,9 +132,10 @@ func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { // =============== type svcJSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvcJSONClient creates a JSON client that implements the Svc interface. @@ -129,13 +158,40 @@ func NewSvcJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOpt } return &svcJSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svcJSONClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importer_local") ctx = ctxsetters.WithServiceName(ctx, "Svc") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -290,7 +346,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.Svc.Send(ctx, typedReq) }, @@ -298,7 +354,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } @@ -372,7 +428,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.Svc.Send(ctx, typedReq) }, @@ -380,7 +436,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/importmapping/x/x.twirp.go b/internal/twirptest/importmapping/x/x.twirp.go index 567aa039..ce1ac527 100644 --- a/internal/twirptest/importmapping/x/x.twirp.go +++ b/internal/twirptest/importmapping/x/x.twirp.go @@ -50,9 +50,10 @@ type Svc1 interface { // ==================== type svc1ProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvc1ProtobufClient creates a Protobuf client that implements the Svc1 interface. @@ -75,13 +76,40 @@ func NewSvc1ProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clie } return &svc1ProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svc1ProtobufClient) Send(ctx context.Context, in *twirp_internal_twirptest_importmapping_y.MsgY) (*twirp_internal_twirptest_importmapping_y.MsgY, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *twirp_internal_twirptest_importmapping_y.MsgY) (*twirp_internal_twirptest_importmapping_y.MsgY, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*twirp_internal_twirptest_importmapping_y.MsgY) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*twirp_internal_twirptest_importmapping_y.MsgY) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svc1ProtobufClient) callSend(ctx context.Context, in *twirp_internal_twirptest_importmapping_y.MsgY) (*twirp_internal_twirptest_importmapping_y.MsgY, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importmapping.x") ctx = ctxsetters.WithServiceName(ctx, "Svc1") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -106,9 +134,10 @@ func (c *svc1ProtobufClient) Send(ctx context.Context, in *twirp_internal_twirpt // ================ type svc1JSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvc1JSONClient creates a JSON client that implements the Svc1 interface. @@ -131,13 +160,40 @@ func NewSvc1JSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOp } return &svc1JSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svc1JSONClient) Send(ctx context.Context, in *twirp_internal_twirptest_importmapping_y.MsgY) (*twirp_internal_twirptest_importmapping_y.MsgY, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *twirp_internal_twirptest_importmapping_y.MsgY) (*twirp_internal_twirptest_importmapping_y.MsgY, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*twirp_internal_twirptest_importmapping_y.MsgY) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*twirp_internal_twirptest_importmapping_y.MsgY) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svc1JSONClient) callSend(ctx context.Context, in *twirp_internal_twirptest_importmapping_y.MsgY) (*twirp_internal_twirptest_importmapping_y.MsgY, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importmapping.x") ctx = ctxsetters.WithServiceName(ctx, "Svc1") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -292,7 +348,7 @@ func (s *svc1Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*twirp_internal_twirptest_importmapping_y.MsgY) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor") } return s.Svc1.Send(ctx, typedReq) }, @@ -300,7 +356,7 @@ func (s *svc1Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter if resp != nil { typedResp, ok := resp.(*twirp_internal_twirptest_importmapping_y.MsgY) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor") } return typedResp, err } @@ -374,7 +430,7 @@ func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*twirp_internal_twirptest_importmapping_y.MsgY) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor") } return s.Svc1.Send(ctx, typedReq) }, @@ -382,7 +438,7 @@ func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr if resp != nil { typedResp, ok := resp.(*twirp_internal_twirptest_importmapping_y.MsgY) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*twirp_internal_twirptest_importmapping_y.MsgY) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/json_serialization/json_serialization.twirp.go b/internal/twirptest/json_serialization/json_serialization.twirp.go index 10090dfd..14aaeec1 100644 --- a/internal/twirptest/json_serialization/json_serialization.twirp.go +++ b/internal/twirptest/json_serialization/json_serialization.twirp.go @@ -48,9 +48,10 @@ type JSONSerialization interface { // ================================= type jSONSerializationProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewJSONSerializationProtobufClient creates a Protobuf client that implements the JSONSerialization interface. @@ -73,13 +74,40 @@ func NewJSONSerializationProtobufClient(baseURL string, client HTTPClient, opts } return &jSONSerializationProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *jSONSerializationProtobufClient) EchoJSON(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callEchoJSON + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callEchoJSON(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *jSONSerializationProtobufClient) callEchoJSON(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "") ctx = ctxsetters.WithServiceName(ctx, "JSONSerialization") ctx = ctxsetters.WithMethodName(ctx, "EchoJSON") @@ -104,9 +132,10 @@ func (c *jSONSerializationProtobufClient) EchoJSON(ctx context.Context, in *Msg) // ============================= type jSONSerializationJSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewJSONSerializationJSONClient creates a JSON client that implements the JSONSerialization interface. @@ -129,13 +158,40 @@ func NewJSONSerializationJSONClient(baseURL string, client HTTPClient, opts ...t } return &jSONSerializationJSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *jSONSerializationJSONClient) EchoJSON(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callEchoJSON + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callEchoJSON(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *jSONSerializationJSONClient) callEchoJSON(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "") ctx = ctxsetters.WithServiceName(ctx, "JSONSerialization") ctx = ctxsetters.WithMethodName(ctx, "EchoJSON") @@ -290,7 +346,7 @@ func (s *jSONSerializationServer) serveEchoJSONJSON(ctx context.Context, resp ht func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.JSONSerialization.EchoJSON(ctx, typedReq) }, @@ -298,7 +354,7 @@ func (s *jSONSerializationServer) serveEchoJSONJSON(ctx context.Context, resp ht if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } @@ -372,7 +428,7 @@ func (s *jSONSerializationServer) serveEchoJSONProtobuf(ctx context.Context, res func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.JSONSerialization.EchoJSON(ctx, typedReq) }, @@ -380,7 +436,7 @@ func (s *jSONSerializationServer) serveEchoJSONProtobuf(ctx context.Context, res if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/multiple/multiple1.twirp.go b/internal/twirptest/multiple/multiple1.twirp.go index b5590cdf..646475cf 100644 --- a/internal/twirptest/multiple/multiple1.twirp.go +++ b/internal/twirptest/multiple/multiple1.twirp.go @@ -52,9 +52,10 @@ type Svc1 interface { // ==================== type svc1ProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvc1ProtobufClient creates a Protobuf client that implements the Svc1 interface. @@ -77,13 +78,40 @@ func NewSvc1ProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clie } return &svc1ProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svc1ProtobufClient) Send(ctx context.Context, in *Msg1) (*Msg1, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg1) (*Msg1, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg1) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg1) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svc1ProtobufClient) callSend(ctx context.Context, in *Msg1) (*Msg1, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") ctx = ctxsetters.WithServiceName(ctx, "Svc1") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -108,9 +136,10 @@ func (c *svc1ProtobufClient) Send(ctx context.Context, in *Msg1) (*Msg1, error) // ================ type svc1JSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvc1JSONClient creates a JSON client that implements the Svc1 interface. @@ -133,13 +162,40 @@ func NewSvc1JSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOp } return &svc1JSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svc1JSONClient) Send(ctx context.Context, in *Msg1) (*Msg1, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg1) (*Msg1, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg1) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg1) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svc1JSONClient) callSend(ctx context.Context, in *Msg1) (*Msg1, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") ctx = ctxsetters.WithServiceName(ctx, "Svc1") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -294,7 +350,7 @@ func (s *svc1Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg1) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor") } return s.Svc1.Send(ctx, typedReq) }, @@ -302,7 +358,7 @@ func (s *svc1Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter if resp != nil { typedResp, ok := resp.(*Msg1) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor") } return typedResp, err } @@ -376,7 +432,7 @@ func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg1) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor") } return s.Svc1.Send(ctx, typedReq) }, @@ -384,7 +440,7 @@ func (s *svc1Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr if resp != nil { typedResp, ok := resp.(*Msg1) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/multiple/multiple2.twirp.go b/internal/twirptest/multiple/multiple2.twirp.go index ed49dd0e..bf55d2ae 100644 --- a/internal/twirptest/multiple/multiple2.twirp.go +++ b/internal/twirptest/multiple/multiple2.twirp.go @@ -37,9 +37,10 @@ type Svc2 interface { // ==================== type svc2ProtobufClient struct { - client HTTPClient - urls [2]string - opts twirp.ClientOptions + client HTTPClient + urls [2]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvc2ProtobufClient creates a Protobuf client that implements the Svc2 interface. @@ -63,13 +64,40 @@ func NewSvc2ProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clie } return &svc2ProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svc2ProtobufClient) Send(ctx context.Context, in *Msg2) (*Msg2, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg2) (*Msg2, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg2) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg2) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg2) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg2) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svc2ProtobufClient) callSend(ctx context.Context, in *Msg2) (*Msg2, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") ctx = ctxsetters.WithServiceName(ctx, "Svc2") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -90,6 +118,32 @@ func (c *svc2ProtobufClient) Send(ctx context.Context, in *Msg2) (*Msg2, error) } func (c *svc2ProtobufClient) SamePackageProtoImport(ctx context.Context, in *Msg1) (*Msg1, error) { + caller := c.callSamePackageProtoImport + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg1) (*Msg1, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg1) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor") + } + return c.callSamePackageProtoImport(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg1) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svc2ProtobufClient) callSamePackageProtoImport(ctx context.Context, in *Msg1) (*Msg1, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") ctx = ctxsetters.WithServiceName(ctx, "Svc2") ctx = ctxsetters.WithMethodName(ctx, "SamePackageProtoImport") @@ -114,9 +168,10 @@ func (c *svc2ProtobufClient) SamePackageProtoImport(ctx context.Context, in *Msg // ================ type svc2JSONClient struct { - client HTTPClient - urls [2]string - opts twirp.ClientOptions + client HTTPClient + urls [2]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvc2JSONClient creates a JSON client that implements the Svc2 interface. @@ -140,13 +195,40 @@ func NewSvc2JSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOp } return &svc2JSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svc2JSONClient) Send(ctx context.Context, in *Msg2) (*Msg2, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg2) (*Msg2, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg2) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg2) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg2) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg2) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svc2JSONClient) callSend(ctx context.Context, in *Msg2) (*Msg2, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") ctx = ctxsetters.WithServiceName(ctx, "Svc2") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -167,6 +249,32 @@ func (c *svc2JSONClient) Send(ctx context.Context, in *Msg2) (*Msg2, error) { } func (c *svc2JSONClient) SamePackageProtoImport(ctx context.Context, in *Msg1) (*Msg1, error) { + caller := c.callSamePackageProtoImport + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg1) (*Msg1, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg1) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor") + } + return c.callSamePackageProtoImport(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg1) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svc2JSONClient) callSamePackageProtoImport(ctx context.Context, in *Msg1) (*Msg1, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") ctx = ctxsetters.WithServiceName(ctx, "Svc2") ctx = ctxsetters.WithMethodName(ctx, "SamePackageProtoImport") @@ -324,7 +432,7 @@ func (s *svc2Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg2) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg2) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg2) when calling interceptor") } return s.Svc2.Send(ctx, typedReq) }, @@ -332,7 +440,7 @@ func (s *svc2Server) serveSendJSON(ctx context.Context, resp http.ResponseWriter if resp != nil { typedResp, ok := resp.(*Msg2) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg2) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg2) when calling interceptor") } return typedResp, err } @@ -406,7 +514,7 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg2) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg2) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg2) when calling interceptor") } return s.Svc2.Send(ctx, typedReq) }, @@ -414,7 +522,7 @@ func (s *svc2Server) serveSendProtobuf(ctx context.Context, resp http.ResponseWr if resp != nil { typedResp, ok := resp.(*Msg2) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg2) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg2) when calling interceptor") } return typedResp, err } @@ -499,7 +607,7 @@ func (s *svc2Server) serveSamePackageProtoImportJSON(ctx context.Context, resp h func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg1) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor") } return s.Svc2.SamePackageProtoImport(ctx, typedReq) }, @@ -507,7 +615,7 @@ func (s *svc2Server) serveSamePackageProtoImportJSON(ctx context.Context, resp h if resp != nil { typedResp, ok := resp.(*Msg1) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor") } return typedResp, err } @@ -581,7 +689,7 @@ func (s *svc2Server) serveSamePackageProtoImportProtobuf(ctx context.Context, re func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg1) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg1) when calling interceptor") } return s.Svc2.SamePackageProtoImport(ctx, typedReq) }, @@ -589,7 +697,7 @@ func (s *svc2Server) serveSamePackageProtoImportProtobuf(ctx context.Context, re if resp != nil { typedResp, ok := resp.(*Msg1) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg1) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/no_package_name/no_package_name.twirp.go b/internal/twirptest/no_package_name/no_package_name.twirp.go index 476538fa..14f95a76 100644 --- a/internal/twirptest/no_package_name/no_package_name.twirp.go +++ b/internal/twirptest/no_package_name/no_package_name.twirp.go @@ -48,9 +48,10 @@ type Svc interface { // =================== type svcProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvcProtobufClient creates a Protobuf client that implements the Svc interface. @@ -73,13 +74,40 @@ func NewSvcProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clien } return &svcProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svcProtobufClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "") ctx = ctxsetters.WithServiceName(ctx, "Svc") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -104,9 +132,10 @@ func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { // =============== type svcJSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvcJSONClient creates a JSON client that implements the Svc interface. @@ -129,13 +158,40 @@ func NewSvcJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOpt } return &svcJSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svcJSONClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "") ctx = ctxsetters.WithServiceName(ctx, "Svc") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -290,7 +346,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.Svc.Send(ctx, typedReq) }, @@ -298,7 +354,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } @@ -372,7 +428,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.Svc.Send(ctx, typedReq) }, @@ -380,7 +436,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go b/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go index 323a88ca..e3d3bba6 100644 --- a/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go +++ b/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go @@ -50,9 +50,10 @@ type Svc2 interface { // ==================== type svc2ProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvc2ProtobufClient creates a Protobuf client that implements the Svc2 interface. @@ -75,13 +76,40 @@ func NewSvc2ProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clie } return &svc2ProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svc2ProtobufClient) Method(ctx context.Context, in *no_package_name.Msg) (*no_package_name.Msg, error) { + caller := c.callMethod + if c.interceptor != nil { + caller = func(ctx context.Context, req *no_package_name.Msg) (*no_package_name.Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*no_package_name.Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*no_package_name.Msg) when calling interceptor") + } + return c.callMethod(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*no_package_name.Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*no_package_name.Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svc2ProtobufClient) callMethod(ctx context.Context, in *no_package_name.Msg) (*no_package_name.Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "") ctx = ctxsetters.WithServiceName(ctx, "Svc2") ctx = ctxsetters.WithMethodName(ctx, "Method") @@ -106,9 +134,10 @@ func (c *svc2ProtobufClient) Method(ctx context.Context, in *no_package_name.Msg // ================ type svc2JSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvc2JSONClient creates a JSON client that implements the Svc2 interface. @@ -131,13 +160,40 @@ func NewSvc2JSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOp } return &svc2JSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svc2JSONClient) Method(ctx context.Context, in *no_package_name.Msg) (*no_package_name.Msg, error) { + caller := c.callMethod + if c.interceptor != nil { + caller = func(ctx context.Context, req *no_package_name.Msg) (*no_package_name.Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*no_package_name.Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*no_package_name.Msg) when calling interceptor") + } + return c.callMethod(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*no_package_name.Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*no_package_name.Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svc2JSONClient) callMethod(ctx context.Context, in *no_package_name.Msg) (*no_package_name.Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "") ctx = ctxsetters.WithServiceName(ctx, "Svc2") ctx = ctxsetters.WithMethodName(ctx, "Method") @@ -292,7 +348,7 @@ func (s *svc2Server) serveMethodJSON(ctx context.Context, resp http.ResponseWrit func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*no_package_name.Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*no_package_name.Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*no_package_name.Msg) when calling interceptor") } return s.Svc2.Method(ctx, typedReq) }, @@ -300,7 +356,7 @@ func (s *svc2Server) serveMethodJSON(ctx context.Context, resp http.ResponseWrit if resp != nil { typedResp, ok := resp.(*no_package_name.Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*no_package_name.Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*no_package_name.Msg) when calling interceptor") } return typedResp, err } @@ -374,7 +430,7 @@ func (s *svc2Server) serveMethodProtobuf(ctx context.Context, resp http.Response func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*no_package_name.Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*no_package_name.Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*no_package_name.Msg) when calling interceptor") } return s.Svc2.Method(ctx, typedReq) }, @@ -382,7 +438,7 @@ func (s *svc2Server) serveMethodProtobuf(ctx context.Context, resp http.Response if resp != nil { typedResp, ok := resp.(*no_package_name.Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*no_package_name.Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*no_package_name.Msg) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/proto/proto.twirp.go b/internal/twirptest/proto/proto.twirp.go index f1b6cf1f..1452b2ad 100644 --- a/internal/twirptest/proto/proto.twirp.go +++ b/internal/twirptest/proto/proto.twirp.go @@ -51,9 +51,10 @@ type Svc interface { // =================== type svcProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvcProtobufClient creates a Protobuf client that implements the Svc interface. @@ -76,13 +77,40 @@ func NewSvcProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clien } return &svcProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svcProtobufClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.proto") ctx = ctxsetters.WithServiceName(ctx, "Svc") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -107,9 +135,10 @@ func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { // =============== type svcJSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvcJSONClient creates a JSON client that implements the Svc interface. @@ -132,13 +161,40 @@ func NewSvcJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOpt } return &svcJSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callSend + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callSend(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svcJSONClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.proto") ctx = ctxsetters.WithServiceName(ctx, "Svc") ctx = ctxsetters.WithMethodName(ctx, "Send") @@ -293,7 +349,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.Svc.Send(ctx, typedReq) }, @@ -301,7 +357,7 @@ func (s *svcServer) serveSendJSON(ctx context.Context, resp http.ResponseWriter, if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } @@ -375,7 +431,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.Svc.Send(ctx, typedReq) }, @@ -383,7 +439,7 @@ func (s *svcServer) serveSendProtobuf(ctx context.Context, resp http.ResponseWri if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/service.twirp.go b/internal/twirptest/service.twirp.go index 58ce4491..452110ab 100644 --- a/internal/twirptest/service.twirp.go +++ b/internal/twirptest/service.twirp.go @@ -50,9 +50,10 @@ type Haberdasher interface { // =========================== type haberdasherProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewHaberdasherProtobufClient creates a Protobuf client that implements the Haberdasher interface. @@ -75,13 +76,40 @@ func NewHaberdasherProtobufClient(baseURL string, client HTTPClient, opts ...twi } return &haberdasherProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *haberdasherProtobufClient) MakeHat(ctx context.Context, in *Size) (*Hat, error) { + caller := c.callMakeHat + if c.interceptor != nil { + caller = func(ctx context.Context, req *Size) (*Hat, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Size) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor") + } + return c.callMakeHat(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Hat) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *haberdasherProtobufClient) callMakeHat(ctx context.Context, in *Size) (*Hat, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest") ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") ctx = ctxsetters.WithMethodName(ctx, "MakeHat") @@ -106,9 +134,10 @@ func (c *haberdasherProtobufClient) MakeHat(ctx context.Context, in *Size) (*Hat // ======================= type haberdasherJSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewHaberdasherJSONClient creates a JSON client that implements the Haberdasher interface. @@ -131,13 +160,40 @@ func NewHaberdasherJSONClient(baseURL string, client HTTPClient, opts ...twirp.C } return &haberdasherJSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *haberdasherJSONClient) MakeHat(ctx context.Context, in *Size) (*Hat, error) { + caller := c.callMakeHat + if c.interceptor != nil { + caller = func(ctx context.Context, req *Size) (*Hat, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Size) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor") + } + return c.callMakeHat(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Hat) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *haberdasherJSONClient) callMakeHat(ctx context.Context, in *Size) (*Hat, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest") ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") ctx = ctxsetters.WithMethodName(ctx, "MakeHat") @@ -292,7 +348,7 @@ func (s *haberdasherServer) serveMakeHatJSON(ctx context.Context, resp http.Resp func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Size) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor") } return s.Haberdasher.MakeHat(ctx, typedReq) }, @@ -300,7 +356,7 @@ func (s *haberdasherServer) serveMakeHatJSON(ctx context.Context, resp http.Resp if resp != nil { typedResp, ok := resp.(*Hat) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor") } return typedResp, err } @@ -374,7 +430,7 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Size) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Size) when calling interceptor") } return s.Haberdasher.MakeHat(ctx, typedReq) }, @@ -382,7 +438,7 @@ func (s *haberdasherServer) serveMakeHatProtobuf(ctx context.Context, resp http. if resp != nil { typedResp, ok := resp.(*Hat) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Hat) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go b/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go index b0dc7391..f382471a 100644 --- a/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go +++ b/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go @@ -48,9 +48,10 @@ type Echo interface { // ==================== type echoProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewEchoProtobufClient creates a Protobuf client that implements the Echo interface. @@ -73,13 +74,40 @@ func NewEchoProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clie } return &echoProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *echoProtobufClient) Echo(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callEcho + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callEcho(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *echoProtobufClient) callEcho(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "") ctx = ctxsetters.WithServiceName(ctx, "Echo") ctx = ctxsetters.WithMethodName(ctx, "Echo") @@ -104,9 +132,10 @@ func (c *echoProtobufClient) Echo(ctx context.Context, in *Msg) (*Msg, error) { // ================ type echoJSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewEchoJSONClient creates a JSON client that implements the Echo interface. @@ -129,13 +158,40 @@ func NewEchoJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOp } return &echoJSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *echoJSONClient) Echo(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callEcho + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callEcho(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *echoJSONClient) callEcho(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "") ctx = ctxsetters.WithServiceName(ctx, "Echo") ctx = ctxsetters.WithMethodName(ctx, "Echo") @@ -290,7 +346,7 @@ func (s *echoServer) serveEchoJSON(ctx context.Context, resp http.ResponseWriter func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.Echo.Echo(ctx, typedReq) }, @@ -298,7 +354,7 @@ func (s *echoServer) serveEchoJSON(ctx context.Context, resp http.ResponseWriter if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } @@ -372,7 +428,7 @@ func (s *echoServer) serveEchoProtobuf(ctx context.Context, resp http.ResponseWr func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.Echo.Echo(ctx, typedReq) }, @@ -380,7 +436,7 @@ func (s *echoServer) serveEchoProtobuf(ctx context.Context, resp http.ResponseWr if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/snake_case_names/snake_case_names.twirp.go b/internal/twirptest/snake_case_names/snake_case_names.twirp.go index 9b866322..f7eb8f84 100644 --- a/internal/twirptest/snake_case_names/snake_case_names.twirp.go +++ b/internal/twirptest/snake_case_names/snake_case_names.twirp.go @@ -53,9 +53,10 @@ type HaberdasherV1 interface { // ============================= type haberdasherV1ProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewHaberdasherV1ProtobufClient creates a Protobuf client that implements the HaberdasherV1 interface. @@ -78,13 +79,40 @@ func NewHaberdasherV1ProtobufClient(baseURL string, client HTTPClient, opts ...t } return &haberdasherV1ProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *haberdasherV1ProtobufClient) MakeHatV1(ctx context.Context, in *MakeHatArgsV1_SizeV1) (*MakeHatArgsV1_HatV1, error) { + caller := c.callMakeHatV1 + if c.interceptor != nil { + caller = func(ctx context.Context, req *MakeHatArgsV1_SizeV1) (*MakeHatArgsV1_HatV1, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*MakeHatArgsV1_SizeV1) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*MakeHatArgsV1_SizeV1) when calling interceptor") + } + return c.callMakeHatV1(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*MakeHatArgsV1_HatV1) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*MakeHatArgsV1_HatV1) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *haberdasherV1ProtobufClient) callMakeHatV1(ctx context.Context, in *MakeHatArgsV1_SizeV1) (*MakeHatArgsV1_HatV1, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.snake_case_names") ctx = ctxsetters.WithServiceName(ctx, "HaberdasherV1") ctx = ctxsetters.WithMethodName(ctx, "MakeHatV1") @@ -109,9 +137,10 @@ func (c *haberdasherV1ProtobufClient) MakeHatV1(ctx context.Context, in *MakeHat // ========================= type haberdasherV1JSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewHaberdasherV1JSONClient creates a JSON client that implements the HaberdasherV1 interface. @@ -134,13 +163,40 @@ func NewHaberdasherV1JSONClient(baseURL string, client HTTPClient, opts ...twirp } return &haberdasherV1JSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *haberdasherV1JSONClient) MakeHatV1(ctx context.Context, in *MakeHatArgsV1_SizeV1) (*MakeHatArgsV1_HatV1, error) { + caller := c.callMakeHatV1 + if c.interceptor != nil { + caller = func(ctx context.Context, req *MakeHatArgsV1_SizeV1) (*MakeHatArgsV1_HatV1, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*MakeHatArgsV1_SizeV1) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*MakeHatArgsV1_SizeV1) when calling interceptor") + } + return c.callMakeHatV1(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*MakeHatArgsV1_HatV1) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*MakeHatArgsV1_HatV1) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *haberdasherV1JSONClient) callMakeHatV1(ctx context.Context, in *MakeHatArgsV1_SizeV1) (*MakeHatArgsV1_HatV1, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.snake_case_names") ctx = ctxsetters.WithServiceName(ctx, "HaberdasherV1") ctx = ctxsetters.WithMethodName(ctx, "MakeHatV1") @@ -295,7 +351,7 @@ func (s *haberdasherV1Server) serveMakeHatV1JSON(ctx context.Context, resp http. func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*MakeHatArgsV1_SizeV1) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*MakeHatArgsV1_SizeV1) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*MakeHatArgsV1_SizeV1) when calling interceptor") } return s.HaberdasherV1.MakeHatV1(ctx, typedReq) }, @@ -303,7 +359,7 @@ func (s *haberdasherV1Server) serveMakeHatV1JSON(ctx context.Context, resp http. if resp != nil { typedResp, ok := resp.(*MakeHatArgsV1_HatV1) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*MakeHatArgsV1_HatV1) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*MakeHatArgsV1_HatV1) when calling interceptor") } return typedResp, err } @@ -377,7 +433,7 @@ func (s *haberdasherV1Server) serveMakeHatV1Protobuf(ctx context.Context, resp h func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*MakeHatArgsV1_SizeV1) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*MakeHatArgsV1_SizeV1) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*MakeHatArgsV1_SizeV1) when calling interceptor") } return s.HaberdasherV1.MakeHatV1(ctx, typedReq) }, @@ -385,7 +441,7 @@ func (s *haberdasherV1Server) serveMakeHatV1Protobuf(ctx context.Context, resp h if resp != nil { typedResp, ok := resp.(*MakeHatArgsV1_HatV1) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*MakeHatArgsV1_HatV1) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*MakeHatArgsV1_HatV1) when calling interceptor") } return typedResp, err } diff --git a/internal/twirptest/source_relative/source_relative.twirp.go b/internal/twirptest/source_relative/source_relative.twirp.go index fd6de8c7..b17d3f14 100644 --- a/internal/twirptest/source_relative/source_relative.twirp.go +++ b/internal/twirptest/source_relative/source_relative.twirp.go @@ -48,9 +48,10 @@ type Svc interface { // =================== type svcProtobufClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvcProtobufClient creates a Protobuf client that implements the Svc interface. @@ -73,13 +74,40 @@ func NewSvcProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clien } return &svcProtobufClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svcProtobufClient) Method(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callMethod + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callMethod(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svcProtobufClient) callMethod(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.source_relative") ctx = ctxsetters.WithServiceName(ctx, "Svc") ctx = ctxsetters.WithMethodName(ctx, "Method") @@ -104,9 +132,10 @@ func (c *svcProtobufClient) Method(ctx context.Context, in *Msg) (*Msg, error) { // =============== type svcJSONClient struct { - client HTTPClient - urls [1]string - opts twirp.ClientOptions + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions } // NewSvcJSONClient creates a JSON client that implements the Svc interface. @@ -129,13 +158,40 @@ func NewSvcJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOpt } return &svcJSONClient{ - client: client, - urls: urls, - opts: clientOpts, + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, } } func (c *svcJSONClient) Method(ctx context.Context, in *Msg) (*Msg, error) { + caller := c.callMethod + if c.interceptor != nil { + caller = func(ctx context.Context, req *Msg) (*Msg, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") + } + return c.callMethod(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*Msg) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *svcJSONClient) callMethod(ctx context.Context, in *Msg) (*Msg, error) { ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.source_relative") ctx = ctxsetters.WithServiceName(ctx, "Svc") ctx = ctxsetters.WithMethodName(ctx, "Method") @@ -290,7 +346,7 @@ func (s *svcServer) serveMethodJSON(ctx context.Context, resp http.ResponseWrite func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.Svc.Method(ctx, typedReq) }, @@ -298,7 +354,7 @@ func (s *svcServer) serveMethodJSON(ctx context.Context, resp http.ResponseWrite if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } @@ -372,7 +428,7 @@ func (s *svcServer) serveMethodProtobuf(ctx context.Context, resp http.ResponseW func(ctx context.Context, req interface{}) (interface{}, error) { typedReq, ok := req.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion req.(*Msg) when calling interceptor") } return s.Svc.Method(ctx, typedReq) }, @@ -380,7 +436,7 @@ func (s *svcServer) serveMethodProtobuf(ctx context.Context, resp http.ResponseW if resp != nil { typedResp, ok := resp.(*Msg) if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor handler") + return nil, twirp.InternalError("failed type assertion resp.(*Msg) when calling interceptor") } return typedResp, err } diff --git a/protoc-gen-twirp/generator.go b/protoc-gen-twirp/generator.go index 938e2585..856f14b1 100644 --- a/protoc-gen-twirp/generator.go +++ b/protoc-gen-twirp/generator.go @@ -969,6 +969,7 @@ func (t *twirp) generateClient(name string, file *descriptor.FileDescriptorProto t.P(`type `, structName, ` struct {`) t.P(` client HTTPClient`) t.P(` urls [`, methCnt, `]string`) + t.P(` interceptor `, t.pkgs["twirp"], `.Interceptor`) t.P(` opts `, t.pkgs["twirp"], `.ClientOptions`) t.P(`}`) t.P() @@ -998,6 +999,7 @@ func (t *twirp) generateClient(name string, file *descriptor.FileDescriptorProto t.P(` return &`, structName, `{`) t.P(` client: client,`) t.P(` urls: urls,`) + t.P(` interceptor: `, t.pkgs["twirp"], `.ChainInterceptors(clientOpts.Interceptors...),`) t.P(` opts: clientOpts,`) t.P(` }`) t.P(`}`) @@ -1008,8 +1010,15 @@ func (t *twirp) generateClient(name string, file *descriptor.FileDescriptorProto pkgName := pkgName(file) inputType := t.goTypeName(method.GetInputType()) outputType := t.goTypeName(method.GetOutputType()) - t.P(`func (c *`, structName, `) `, methName, `(ctx `, t.pkgs["context"], `.Context, in *`, inputType, `) (*`, outputType, `, error) {`) + t.P(` caller := c.call`, methName) + t.P(` if c.interceptor != nil {`) + t.generateClientInterceptorCaller(method) + t.P(` }`) + t.P(` return caller(ctx, in)`) + t.P(`}`) + t.P() + t.P(`func (c *`, structName, `) call`, methName, `(ctx `, t.pkgs["context"], `.Context, in *`, inputType, `) (*`, outputType, `, error) {`) t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithPackageName(ctx, "`, pkgName, `")`) t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithServiceName(ctx, "`, servName, `")`) t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithMethodName(ctx, "`, methName, `")`) @@ -1347,25 +1356,39 @@ func (t *twirp) generateServerProtobufMethod(service *descriptor.ServiceDescript t.P() } +func (t *twirp) generateClientInterceptorCaller(method *descriptor.MethodDescriptorProto) { + methName := methodNameCamelCased(method) + t.generateInterceptorFunc("c", "caller", "c.call"+methName, method) +} + func (t *twirp) generateServerInterceptorHandler(service *descriptor.ServiceDescriptorProto, method *descriptor.MethodDescriptorProto) { methName := methodNameCamelCased(method) servName := serviceNameCamelCased(service) + t.generateInterceptorFunc("s", "handler", "s."+servName+"."+methName, method) +} + +func (t *twirp) generateInterceptorFunc( + receiverName string, + varName string, + delegateFuncName string, + method *descriptor.MethodDescriptorProto, +) { inputType := t.goTypeName(method.GetInputType()) outputType := t.goTypeName(method.GetOutputType()) - t.P(` handler = func(ctx `, t.pkgs["context"], `.Context, req *`, inputType, `) (*`, outputType, `, error) {`) - t.P(` resp, err := s.interceptor(`) + t.P(` `, varName, ` = func(ctx `, t.pkgs["context"], `.Context, req *`, inputType, `) (*`, outputType, `, error) {`) + t.P(` resp, err := `, receiverName, `.interceptor(`) t.P(` func(ctx `, t.pkgs["context"], ` .Context, req interface{}) (interface{}, error) {`) t.P(` typedReq, ok := req.(*`, inputType, `)`) t.P(` if !ok {`) - t.P(` return nil, `, t.pkgs["twirp"], `.InternalError("failed type assertion req.(*`, inputType, `) when calling interceptor handler")`) + t.P(` return nil, `, t.pkgs["twirp"], `.InternalError("failed type assertion req.(*`, inputType, `) when calling interceptor")`) t.P(` }`) - t.P(` return s.`, servName, `.`, methName, `(ctx, typedReq)`) + t.P(` return `, delegateFuncName, `(ctx, typedReq)`) t.P(` },`) t.P(` )(ctx, req)`) t.P(` if resp != nil {`) t.P(` typedResp, ok := resp.(*`, outputType, `)`) t.P(` if !ok {`) - t.P(` return nil, `, t.pkgs["twirp"], `.InternalError("failed type assertion resp.(*`, outputType, `) when calling interceptor handler")`) + t.P(` return nil, `, t.pkgs["twirp"], `.InternalError("failed type assertion resp.(*`, outputType, `) when calling interceptor")`) t.P(` }`) t.P(` return typedResp, err`) t.P(` }`) From ae97c56db77f1c3a9ac6b9b60ceb8db2fadffbea Mon Sep 17 00:00:00 2001 From: bufdev Date: Fri, 18 Sep 2020 13:01:34 -0400 Subject: [PATCH 6/6] Add client interceptors --- .../clientcompat/clientcompat.twirp.go | 24 ++++---- docs/interceptors.md | 8 ++- example/service.twirp.go | 12 ++-- internal/twirptest/client_test.go | 61 +++++++++++++++++++ .../twirptest/gogo_compat/service.twirp.go | 12 ++-- .../google_protobuf_imports/service.twirp.go | 12 ++-- .../twirptest/importable/importable.twirp.go | 12 ++-- internal/twirptest/importer/importer.twirp.go | 12 ++-- .../importer_local/importer_local.twirp.go | 12 ++-- internal/twirptest/importmapping/x/x.twirp.go | 12 ++-- .../json_serialization.twirp.go | 12 ++-- .../twirptest/multiple/multiple1.twirp.go | 12 ++-- .../twirptest/multiple/multiple2.twirp.go | 24 ++++---- .../no_package_name/no_package_name.twirp.go | 12 ++-- .../no_package_name_importer.twirp.go | 12 ++-- internal/twirptest/proto/proto.twirp.go | 12 ++-- internal/twirptest/service.twirp.go | 12 ++-- .../service_method_same_name.twirp.go | 12 ++-- internal/twirptest/service_test.go | 1 - .../snake_case_names.twirp.go | 12 ++-- .../source_relative/source_relative.twirp.go | 12 ++-- protoc-gen-twirp/generator.go | 6 +- 22 files changed, 190 insertions(+), 126 deletions(-) diff --git a/clientcompat/internal/clientcompat/clientcompat.twirp.go b/clientcompat/internal/clientcompat/clientcompat.twirp.go index e403bb94..4acf117b 100644 --- a/clientcompat/internal/clientcompat/clientcompat.twirp.go +++ b/clientcompat/internal/clientcompat/clientcompat.twirp.go @@ -85,6 +85,9 @@ func NewCompatServiceProtobufClient(baseURL string, client HTTPClient, opts ...t } func (c *compatServiceProtobufClient) Method(ctx context.Context, in *Req) (*Resp, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.clientcompat") + ctx = ctxsetters.WithServiceName(ctx, "CompatService") + ctx = ctxsetters.WithMethodName(ctx, "Method") caller := c.callMethod if c.interceptor != nil { caller = func(ctx context.Context, req *Req) (*Resp, error) { @@ -111,9 +114,6 @@ func (c *compatServiceProtobufClient) Method(ctx context.Context, in *Req) (*Res } func (c *compatServiceProtobufClient) callMethod(ctx context.Context, in *Req) (*Resp, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.clientcompat") - ctx = ctxsetters.WithServiceName(ctx, "CompatService") - ctx = ctxsetters.WithMethodName(ctx, "Method") out := new(Resp) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -131,6 +131,9 @@ func (c *compatServiceProtobufClient) callMethod(ctx context.Context, in *Req) ( } func (c *compatServiceProtobufClient) NoopMethod(ctx context.Context, in *Empty) (*Empty, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.clientcompat") + ctx = ctxsetters.WithServiceName(ctx, "CompatService") + ctx = ctxsetters.WithMethodName(ctx, "NoopMethod") caller := c.callNoopMethod if c.interceptor != nil { caller = func(ctx context.Context, req *Empty) (*Empty, error) { @@ -157,9 +160,6 @@ func (c *compatServiceProtobufClient) NoopMethod(ctx context.Context, in *Empty) } func (c *compatServiceProtobufClient) callNoopMethod(ctx context.Context, in *Empty) (*Empty, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.clientcompat") - ctx = ctxsetters.WithServiceName(ctx, "CompatService") - ctx = ctxsetters.WithMethodName(ctx, "NoopMethod") out := new(Empty) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[1], in, out) if err != nil { @@ -216,6 +216,9 @@ func NewCompatServiceJSONClient(baseURL string, client HTTPClient, opts ...twirp } func (c *compatServiceJSONClient) Method(ctx context.Context, in *Req) (*Resp, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.clientcompat") + ctx = ctxsetters.WithServiceName(ctx, "CompatService") + ctx = ctxsetters.WithMethodName(ctx, "Method") caller := c.callMethod if c.interceptor != nil { caller = func(ctx context.Context, req *Req) (*Resp, error) { @@ -242,9 +245,6 @@ func (c *compatServiceJSONClient) Method(ctx context.Context, in *Req) (*Resp, e } func (c *compatServiceJSONClient) callMethod(ctx context.Context, in *Req) (*Resp, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.clientcompat") - ctx = ctxsetters.WithServiceName(ctx, "CompatService") - ctx = ctxsetters.WithMethodName(ctx, "Method") out := new(Resp) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -262,6 +262,9 @@ func (c *compatServiceJSONClient) callMethod(ctx context.Context, in *Req) (*Res } func (c *compatServiceJSONClient) NoopMethod(ctx context.Context, in *Empty) (*Empty, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.clientcompat") + ctx = ctxsetters.WithServiceName(ctx, "CompatService") + ctx = ctxsetters.WithMethodName(ctx, "NoopMethod") caller := c.callNoopMethod if c.interceptor != nil { caller = func(ctx context.Context, req *Empty) (*Empty, error) { @@ -288,9 +291,6 @@ func (c *compatServiceJSONClient) NoopMethod(ctx context.Context, in *Empty) (*E } func (c *compatServiceJSONClient) callNoopMethod(ctx context.Context, in *Empty) (*Empty, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.clientcompat") - ctx = ctxsetters.WithServiceName(ctx, "CompatService") - ctx = ctxsetters.WithMethodName(ctx, "NoopMethod") out := new(Empty) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[1], in, out) if err != nil { diff --git a/docs/interceptors.md b/docs/interceptors.md index 895ec544..a86c130d 100644 --- a/docs/interceptors.md +++ b/docs/interceptors.md @@ -4,11 +4,15 @@ title: "Interceptors" sidebar_label: "Interceptors" --- -The service constructor can use the option `twirp.WithServerInterceptors(interceptors ...twirp.Interceptor)` +The client and service constructors can use the options +`twirp.WithClientInterceptors(interceptors ...twirp.Interceptor)` +and `twirp.WithServerInterceptors(interceptors ...twirp.Interceptor)` to plug in additional functionality: ```go -server := NewHaberdasherServer(svcImpl, twirp.WithInterceptor(NewLogInterceptor(logger.New(os.Stderr, "", 0)))) +client := NewHaberdasherProtobufClient(url, &http.Client{}, twirp.WithClientInterceptors(NewLogInterceptor(logger.New(os.Stderr, "", 0)))) + +server := NewHaberdasherServer(svcImpl, twirp.WithServerInterceptors(NewLogInterceptor(logger.New(os.Stderr, "", 0)))) // NewLogInterceptor logs various parts of a request using a standard Logger. func NewLogInterceptor(l *log.Logger) twirp.Interceptor { diff --git a/example/service.twirp.go b/example/service.twirp.go index bd1497b2..62b29cc1 100644 --- a/example/service.twirp.go +++ b/example/service.twirp.go @@ -84,6 +84,9 @@ func NewHaberdasherProtobufClient(baseURL string, client HTTPClient, opts ...twi } func (c *haberdasherProtobufClient) MakeHat(ctx context.Context, in *Size) (*Hat, error) { + ctx = ctxsetters.WithPackageName(ctx, "twitch.twirp.example") + ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") + ctx = ctxsetters.WithMethodName(ctx, "MakeHat") caller := c.callMakeHat if c.interceptor != nil { caller = func(ctx context.Context, req *Size) (*Hat, error) { @@ -110,9 +113,6 @@ func (c *haberdasherProtobufClient) MakeHat(ctx context.Context, in *Size) (*Hat } func (c *haberdasherProtobufClient) callMakeHat(ctx context.Context, in *Size) (*Hat, error) { - ctx = ctxsetters.WithPackageName(ctx, "twitch.twirp.example") - ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") - ctx = ctxsetters.WithMethodName(ctx, "MakeHat") out := new(Hat) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -168,6 +168,9 @@ func NewHaberdasherJSONClient(baseURL string, client HTTPClient, opts ...twirp.C } func (c *haberdasherJSONClient) MakeHat(ctx context.Context, in *Size) (*Hat, error) { + ctx = ctxsetters.WithPackageName(ctx, "twitch.twirp.example") + ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") + ctx = ctxsetters.WithMethodName(ctx, "MakeHat") caller := c.callMakeHat if c.interceptor != nil { caller = func(ctx context.Context, req *Size) (*Hat, error) { @@ -194,9 +197,6 @@ func (c *haberdasherJSONClient) MakeHat(ctx context.Context, in *Size) (*Hat, er } func (c *haberdasherJSONClient) callMakeHat(ctx context.Context, in *Size) (*Hat, error) { - ctx = ctxsetters.WithPackageName(ctx, "twitch.twirp.example") - ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") - ctx = ctxsetters.WithMethodName(ctx, "MakeHat") out := new(Hat) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/client_test.go b/internal/twirptest/client_test.go index ecdb9e0c..97224110 100644 --- a/internal/twirptest/client_test.go +++ b/internal/twirptest/client_test.go @@ -402,6 +402,67 @@ func TestClientContextToHook(t *testing.T) { } } +func TestClientInterceptor(t *testing.T) { + interceptor := func(next twirp.Method) twirp.Method { + return func(ctx context.Context, request interface{}) (interface{}, error) { + methodName, _ := twirp.MethodName(ctx) + if methodName != "MakeHat" { + return nil, fmt.Errorf("unexpected methodName: %q", methodName) + } + serviceName, _ := twirp.ServiceName(ctx) + if serviceName != "Haberdasher" { + return nil, fmt.Errorf("unexpected serviceName: %q", serviceName) + } + packageName, _ := twirp.PackageName(ctx) + if packageName != "twirp.internal.twirptest" { + return nil, fmt.Errorf("unexpected packageName: %q", packageName) + } + size, ok := request.(*Size) + if !ok { + return nil, fmt.Errorf("could not cast %T to a *Size", request) + } + size.Inches = size.Inches + 1 + response, err := next(ctx, request) + hat, ok := response.(*Hat) + if ok && hat != nil { + hat.Color = hat.Color + "x" + return hat, err + } + return nil, err + } + } + h := PickyHatmaker(3) + + s := httptest.NewServer(NewHaberdasherServer(h)) + defer s.Close() + client := NewHaberdasherProtobufClient( + s.URL, + http.DefaultClient, + twirp.WithClientInterceptors( + interceptor, + interceptor, + ), + ) + hat, clientErr := client.MakeHat(context.Background(), &Size{Inches: 1}) + if clientErr != nil { + t.Fatalf("client err=%q", clientErr) + } + if hat.Size != 3 { + t.Errorf("hat size expected=3 actual=%v", hat.Size) + } + if hat.Color != "bluexx" { + t.Errorf("hat color expected=bluexx actual=%v", hat.Color) + } + _, clientErr = client.MakeHat(context.Background(), &Size{Inches: 3}) + twerr, ok := clientErr.(twirp.Error) + if !ok { + t.Fatalf("expected twirp.Error type error, have %T", clientErr) + } + if twerr.Code() != twirp.InvalidArgument { + t.Errorf("expected error type to be InvalidArgument, buf found %q", twerr.Code()) + } +} + func TestClientIntermediaryErrors(t *testing.T) { testcase := func(body string, code int, expectedErrorCode twirp.ErrorCode, clientMaker func(string, HTTPClient, ...twirp.ClientOption) Haberdasher) func(*testing.T) { return func(t *testing.T) { diff --git a/internal/twirptest/gogo_compat/service.twirp.go b/internal/twirptest/gogo_compat/service.twirp.go index 1e742bc9..87f2d3a6 100644 --- a/internal/twirptest/gogo_compat/service.twirp.go +++ b/internal/twirptest/gogo_compat/service.twirp.go @@ -86,6 +86,9 @@ func NewSvcProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clien } func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.gogo_compat") + ctx = ctxsetters.WithServiceName(ctx, "Svc") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -112,9 +115,6 @@ func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { } func (c *svcProtobufClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.gogo_compat") - ctx = ctxsetters.WithServiceName(ctx, "Svc") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(Msg) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -170,6 +170,9 @@ func NewSvcJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOpt } func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.gogo_compat") + ctx = ctxsetters.WithServiceName(ctx, "Svc") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -196,9 +199,6 @@ func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { } func (c *svcJSONClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.gogo_compat") - ctx = ctxsetters.WithServiceName(ctx, "Svc") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(Msg) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/google_protobuf_imports/service.twirp.go b/internal/twirptest/google_protobuf_imports/service.twirp.go index 2c7472b0..6579c187 100644 --- a/internal/twirptest/google_protobuf_imports/service.twirp.go +++ b/internal/twirptest/google_protobuf_imports/service.twirp.go @@ -85,6 +85,9 @@ func NewSvcProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clien } func (c *svcProtobufClient) Send(ctx context.Context, in *google_protobuf1.StringValue) (*google_protobuf.Empty, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.use_empty") + ctx = ctxsetters.WithServiceName(ctx, "Svc") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *google_protobuf1.StringValue) (*google_protobuf.Empty, error) { @@ -111,9 +114,6 @@ func (c *svcProtobufClient) Send(ctx context.Context, in *google_protobuf1.Strin } func (c *svcProtobufClient) callSend(ctx context.Context, in *google_protobuf1.StringValue) (*google_protobuf.Empty, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.use_empty") - ctx = ctxsetters.WithServiceName(ctx, "Svc") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(google_protobuf.Empty) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -169,6 +169,9 @@ func NewSvcJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOpt } func (c *svcJSONClient) Send(ctx context.Context, in *google_protobuf1.StringValue) (*google_protobuf.Empty, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.use_empty") + ctx = ctxsetters.WithServiceName(ctx, "Svc") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *google_protobuf1.StringValue) (*google_protobuf.Empty, error) { @@ -195,9 +198,6 @@ func (c *svcJSONClient) Send(ctx context.Context, in *google_protobuf1.StringVal } func (c *svcJSONClient) callSend(ctx context.Context, in *google_protobuf1.StringValue) (*google_protobuf.Empty, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.use_empty") - ctx = ctxsetters.WithServiceName(ctx, "Svc") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(google_protobuf.Empty) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/importable/importable.twirp.go b/internal/twirptest/importable/importable.twirp.go index e30ccb47..25c0a5c1 100644 --- a/internal/twirptest/importable/importable.twirp.go +++ b/internal/twirptest/importable/importable.twirp.go @@ -85,6 +85,9 @@ func NewSvcProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clien } func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importable") + ctx = ctxsetters.WithServiceName(ctx, "Svc") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -111,9 +114,6 @@ func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { } func (c *svcProtobufClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importable") - ctx = ctxsetters.WithServiceName(ctx, "Svc") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(Msg) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -169,6 +169,9 @@ func NewSvcJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOpt } func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importable") + ctx = ctxsetters.WithServiceName(ctx, "Svc") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -195,9 +198,6 @@ func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { } func (c *svcJSONClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importable") - ctx = ctxsetters.WithServiceName(ctx, "Svc") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(Msg) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/importer/importer.twirp.go b/internal/twirptest/importer/importer.twirp.go index bc887ccf..d2a05716 100644 --- a/internal/twirptest/importer/importer.twirp.go +++ b/internal/twirptest/importer/importer.twirp.go @@ -87,6 +87,9 @@ func NewSvc2ProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clie } func (c *svc2ProtobufClient) Send(ctx context.Context, in *twirp_internal_twirptest_importable.Msg) (*twirp_internal_twirptest_importable.Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importer") + ctx = ctxsetters.WithServiceName(ctx, "Svc2") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *twirp_internal_twirptest_importable.Msg) (*twirp_internal_twirptest_importable.Msg, error) { @@ -113,9 +116,6 @@ func (c *svc2ProtobufClient) Send(ctx context.Context, in *twirp_internal_twirpt } func (c *svc2ProtobufClient) callSend(ctx context.Context, in *twirp_internal_twirptest_importable.Msg) (*twirp_internal_twirptest_importable.Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importer") - ctx = ctxsetters.WithServiceName(ctx, "Svc2") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(twirp_internal_twirptest_importable.Msg) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -171,6 +171,9 @@ func NewSvc2JSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOp } func (c *svc2JSONClient) Send(ctx context.Context, in *twirp_internal_twirptest_importable.Msg) (*twirp_internal_twirptest_importable.Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importer") + ctx = ctxsetters.WithServiceName(ctx, "Svc2") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *twirp_internal_twirptest_importable.Msg) (*twirp_internal_twirptest_importable.Msg, error) { @@ -197,9 +200,6 @@ func (c *svc2JSONClient) Send(ctx context.Context, in *twirp_internal_twirptest_ } func (c *svc2JSONClient) callSend(ctx context.Context, in *twirp_internal_twirptest_importable.Msg) (*twirp_internal_twirptest_importable.Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importer") - ctx = ctxsetters.WithServiceName(ctx, "Svc2") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(twirp_internal_twirptest_importable.Msg) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/importer_local/importer_local.twirp.go b/internal/twirptest/importer_local/importer_local.twirp.go index 148681b8..119c5123 100644 --- a/internal/twirptest/importer_local/importer_local.twirp.go +++ b/internal/twirptest/importer_local/importer_local.twirp.go @@ -82,6 +82,9 @@ func NewSvcProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clien } func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importer_local") + ctx = ctxsetters.WithServiceName(ctx, "Svc") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -108,9 +111,6 @@ func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { } func (c *svcProtobufClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importer_local") - ctx = ctxsetters.WithServiceName(ctx, "Svc") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(Msg) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -166,6 +166,9 @@ func NewSvcJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOpt } func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importer_local") + ctx = ctxsetters.WithServiceName(ctx, "Svc") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -192,9 +195,6 @@ func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { } func (c *svcJSONClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importer_local") - ctx = ctxsetters.WithServiceName(ctx, "Svc") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(Msg) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/importmapping/x/x.twirp.go b/internal/twirptest/importmapping/x/x.twirp.go index ce1ac527..106e3d3d 100644 --- a/internal/twirptest/importmapping/x/x.twirp.go +++ b/internal/twirptest/importmapping/x/x.twirp.go @@ -84,6 +84,9 @@ func NewSvc1ProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clie } func (c *svc1ProtobufClient) Send(ctx context.Context, in *twirp_internal_twirptest_importmapping_y.MsgY) (*twirp_internal_twirptest_importmapping_y.MsgY, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importmapping.x") + ctx = ctxsetters.WithServiceName(ctx, "Svc1") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *twirp_internal_twirptest_importmapping_y.MsgY) (*twirp_internal_twirptest_importmapping_y.MsgY, error) { @@ -110,9 +113,6 @@ func (c *svc1ProtobufClient) Send(ctx context.Context, in *twirp_internal_twirpt } func (c *svc1ProtobufClient) callSend(ctx context.Context, in *twirp_internal_twirptest_importmapping_y.MsgY) (*twirp_internal_twirptest_importmapping_y.MsgY, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importmapping.x") - ctx = ctxsetters.WithServiceName(ctx, "Svc1") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(twirp_internal_twirptest_importmapping_y.MsgY) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -168,6 +168,9 @@ func NewSvc1JSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOp } func (c *svc1JSONClient) Send(ctx context.Context, in *twirp_internal_twirptest_importmapping_y.MsgY) (*twirp_internal_twirptest_importmapping_y.MsgY, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importmapping.x") + ctx = ctxsetters.WithServiceName(ctx, "Svc1") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *twirp_internal_twirptest_importmapping_y.MsgY) (*twirp_internal_twirptest_importmapping_y.MsgY, error) { @@ -194,9 +197,6 @@ func (c *svc1JSONClient) Send(ctx context.Context, in *twirp_internal_twirptest_ } func (c *svc1JSONClient) callSend(ctx context.Context, in *twirp_internal_twirptest_importmapping_y.MsgY) (*twirp_internal_twirptest_importmapping_y.MsgY, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.importmapping.x") - ctx = ctxsetters.WithServiceName(ctx, "Svc1") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(twirp_internal_twirptest_importmapping_y.MsgY) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/json_serialization/json_serialization.twirp.go b/internal/twirptest/json_serialization/json_serialization.twirp.go index 14aaeec1..640b288c 100644 --- a/internal/twirptest/json_serialization/json_serialization.twirp.go +++ b/internal/twirptest/json_serialization/json_serialization.twirp.go @@ -82,6 +82,9 @@ func NewJSONSerializationProtobufClient(baseURL string, client HTTPClient, opts } func (c *jSONSerializationProtobufClient) EchoJSON(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "") + ctx = ctxsetters.WithServiceName(ctx, "JSONSerialization") + ctx = ctxsetters.WithMethodName(ctx, "EchoJSON") caller := c.callEchoJSON if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -108,9 +111,6 @@ func (c *jSONSerializationProtobufClient) EchoJSON(ctx context.Context, in *Msg) } func (c *jSONSerializationProtobufClient) callEchoJSON(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "") - ctx = ctxsetters.WithServiceName(ctx, "JSONSerialization") - ctx = ctxsetters.WithMethodName(ctx, "EchoJSON") out := new(Msg) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -166,6 +166,9 @@ func NewJSONSerializationJSONClient(baseURL string, client HTTPClient, opts ...t } func (c *jSONSerializationJSONClient) EchoJSON(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "") + ctx = ctxsetters.WithServiceName(ctx, "JSONSerialization") + ctx = ctxsetters.WithMethodName(ctx, "EchoJSON") caller := c.callEchoJSON if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -192,9 +195,6 @@ func (c *jSONSerializationJSONClient) EchoJSON(ctx context.Context, in *Msg) (*M } func (c *jSONSerializationJSONClient) callEchoJSON(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "") - ctx = ctxsetters.WithServiceName(ctx, "JSONSerialization") - ctx = ctxsetters.WithMethodName(ctx, "EchoJSON") out := new(Msg) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/multiple/multiple1.twirp.go b/internal/twirptest/multiple/multiple1.twirp.go index 646475cf..aa00b581 100644 --- a/internal/twirptest/multiple/multiple1.twirp.go +++ b/internal/twirptest/multiple/multiple1.twirp.go @@ -86,6 +86,9 @@ func NewSvc1ProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clie } func (c *svc1ProtobufClient) Send(ctx context.Context, in *Msg1) (*Msg1, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") + ctx = ctxsetters.WithServiceName(ctx, "Svc1") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *Msg1) (*Msg1, error) { @@ -112,9 +115,6 @@ func (c *svc1ProtobufClient) Send(ctx context.Context, in *Msg1) (*Msg1, error) } func (c *svc1ProtobufClient) callSend(ctx context.Context, in *Msg1) (*Msg1, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") - ctx = ctxsetters.WithServiceName(ctx, "Svc1") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(Msg1) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -170,6 +170,9 @@ func NewSvc1JSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOp } func (c *svc1JSONClient) Send(ctx context.Context, in *Msg1) (*Msg1, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") + ctx = ctxsetters.WithServiceName(ctx, "Svc1") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *Msg1) (*Msg1, error) { @@ -196,9 +199,6 @@ func (c *svc1JSONClient) Send(ctx context.Context, in *Msg1) (*Msg1, error) { } func (c *svc1JSONClient) callSend(ctx context.Context, in *Msg1) (*Msg1, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") - ctx = ctxsetters.WithServiceName(ctx, "Svc1") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(Msg1) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/multiple/multiple2.twirp.go b/internal/twirptest/multiple/multiple2.twirp.go index bf55d2ae..644bddf9 100644 --- a/internal/twirptest/multiple/multiple2.twirp.go +++ b/internal/twirptest/multiple/multiple2.twirp.go @@ -72,6 +72,9 @@ func NewSvc2ProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clie } func (c *svc2ProtobufClient) Send(ctx context.Context, in *Msg2) (*Msg2, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") + ctx = ctxsetters.WithServiceName(ctx, "Svc2") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *Msg2) (*Msg2, error) { @@ -98,9 +101,6 @@ func (c *svc2ProtobufClient) Send(ctx context.Context, in *Msg2) (*Msg2, error) } func (c *svc2ProtobufClient) callSend(ctx context.Context, in *Msg2) (*Msg2, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") - ctx = ctxsetters.WithServiceName(ctx, "Svc2") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(Msg2) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -118,6 +118,9 @@ func (c *svc2ProtobufClient) callSend(ctx context.Context, in *Msg2) (*Msg2, err } func (c *svc2ProtobufClient) SamePackageProtoImport(ctx context.Context, in *Msg1) (*Msg1, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") + ctx = ctxsetters.WithServiceName(ctx, "Svc2") + ctx = ctxsetters.WithMethodName(ctx, "SamePackageProtoImport") caller := c.callSamePackageProtoImport if c.interceptor != nil { caller = func(ctx context.Context, req *Msg1) (*Msg1, error) { @@ -144,9 +147,6 @@ func (c *svc2ProtobufClient) SamePackageProtoImport(ctx context.Context, in *Msg } func (c *svc2ProtobufClient) callSamePackageProtoImport(ctx context.Context, in *Msg1) (*Msg1, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") - ctx = ctxsetters.WithServiceName(ctx, "Svc2") - ctx = ctxsetters.WithMethodName(ctx, "SamePackageProtoImport") out := new(Msg1) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[1], in, out) if err != nil { @@ -203,6 +203,9 @@ func NewSvc2JSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOp } func (c *svc2JSONClient) Send(ctx context.Context, in *Msg2) (*Msg2, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") + ctx = ctxsetters.WithServiceName(ctx, "Svc2") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *Msg2) (*Msg2, error) { @@ -229,9 +232,6 @@ func (c *svc2JSONClient) Send(ctx context.Context, in *Msg2) (*Msg2, error) { } func (c *svc2JSONClient) callSend(ctx context.Context, in *Msg2) (*Msg2, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") - ctx = ctxsetters.WithServiceName(ctx, "Svc2") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(Msg2) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -249,6 +249,9 @@ func (c *svc2JSONClient) callSend(ctx context.Context, in *Msg2) (*Msg2, error) } func (c *svc2JSONClient) SamePackageProtoImport(ctx context.Context, in *Msg1) (*Msg1, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") + ctx = ctxsetters.WithServiceName(ctx, "Svc2") + ctx = ctxsetters.WithMethodName(ctx, "SamePackageProtoImport") caller := c.callSamePackageProtoImport if c.interceptor != nil { caller = func(ctx context.Context, req *Msg1) (*Msg1, error) { @@ -275,9 +278,6 @@ func (c *svc2JSONClient) SamePackageProtoImport(ctx context.Context, in *Msg1) ( } func (c *svc2JSONClient) callSamePackageProtoImport(ctx context.Context, in *Msg1) (*Msg1, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.multiple") - ctx = ctxsetters.WithServiceName(ctx, "Svc2") - ctx = ctxsetters.WithMethodName(ctx, "SamePackageProtoImport") out := new(Msg1) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[1], in, out) if err != nil { diff --git a/internal/twirptest/no_package_name/no_package_name.twirp.go b/internal/twirptest/no_package_name/no_package_name.twirp.go index 14f95a76..202122fd 100644 --- a/internal/twirptest/no_package_name/no_package_name.twirp.go +++ b/internal/twirptest/no_package_name/no_package_name.twirp.go @@ -82,6 +82,9 @@ func NewSvcProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clien } func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "") + ctx = ctxsetters.WithServiceName(ctx, "Svc") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -108,9 +111,6 @@ func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { } func (c *svcProtobufClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "") - ctx = ctxsetters.WithServiceName(ctx, "Svc") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(Msg) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -166,6 +166,9 @@ func NewSvcJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOpt } func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "") + ctx = ctxsetters.WithServiceName(ctx, "Svc") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -192,9 +195,6 @@ func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { } func (c *svcJSONClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "") - ctx = ctxsetters.WithServiceName(ctx, "Svc") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(Msg) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go b/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go index e3d3bba6..8708d225 100644 --- a/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go +++ b/internal/twirptest/no_package_name_importer/no_package_name_importer.twirp.go @@ -84,6 +84,9 @@ func NewSvc2ProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clie } func (c *svc2ProtobufClient) Method(ctx context.Context, in *no_package_name.Msg) (*no_package_name.Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "") + ctx = ctxsetters.WithServiceName(ctx, "Svc2") + ctx = ctxsetters.WithMethodName(ctx, "Method") caller := c.callMethod if c.interceptor != nil { caller = func(ctx context.Context, req *no_package_name.Msg) (*no_package_name.Msg, error) { @@ -110,9 +113,6 @@ func (c *svc2ProtobufClient) Method(ctx context.Context, in *no_package_name.Msg } func (c *svc2ProtobufClient) callMethod(ctx context.Context, in *no_package_name.Msg) (*no_package_name.Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "") - ctx = ctxsetters.WithServiceName(ctx, "Svc2") - ctx = ctxsetters.WithMethodName(ctx, "Method") out := new(no_package_name.Msg) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -168,6 +168,9 @@ func NewSvc2JSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOp } func (c *svc2JSONClient) Method(ctx context.Context, in *no_package_name.Msg) (*no_package_name.Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "") + ctx = ctxsetters.WithServiceName(ctx, "Svc2") + ctx = ctxsetters.WithMethodName(ctx, "Method") caller := c.callMethod if c.interceptor != nil { caller = func(ctx context.Context, req *no_package_name.Msg) (*no_package_name.Msg, error) { @@ -194,9 +197,6 @@ func (c *svc2JSONClient) Method(ctx context.Context, in *no_package_name.Msg) (* } func (c *svc2JSONClient) callMethod(ctx context.Context, in *no_package_name.Msg) (*no_package_name.Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "") - ctx = ctxsetters.WithServiceName(ctx, "Svc2") - ctx = ctxsetters.WithMethodName(ctx, "Method") out := new(no_package_name.Msg) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/proto/proto.twirp.go b/internal/twirptest/proto/proto.twirp.go index 1452b2ad..4f3393e0 100644 --- a/internal/twirptest/proto/proto.twirp.go +++ b/internal/twirptest/proto/proto.twirp.go @@ -85,6 +85,9 @@ func NewSvcProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clien } func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.proto") + ctx = ctxsetters.WithServiceName(ctx, "Svc") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -111,9 +114,6 @@ func (c *svcProtobufClient) Send(ctx context.Context, in *Msg) (*Msg, error) { } func (c *svcProtobufClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.proto") - ctx = ctxsetters.WithServiceName(ctx, "Svc") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(Msg) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -169,6 +169,9 @@ func NewSvcJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOpt } func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.proto") + ctx = ctxsetters.WithServiceName(ctx, "Svc") + ctx = ctxsetters.WithMethodName(ctx, "Send") caller := c.callSend if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -195,9 +198,6 @@ func (c *svcJSONClient) Send(ctx context.Context, in *Msg) (*Msg, error) { } func (c *svcJSONClient) callSend(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.proto") - ctx = ctxsetters.WithServiceName(ctx, "Svc") - ctx = ctxsetters.WithMethodName(ctx, "Send") out := new(Msg) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/service.twirp.go b/internal/twirptest/service.twirp.go index 452110ab..1714de19 100644 --- a/internal/twirptest/service.twirp.go +++ b/internal/twirptest/service.twirp.go @@ -84,6 +84,9 @@ func NewHaberdasherProtobufClient(baseURL string, client HTTPClient, opts ...twi } func (c *haberdasherProtobufClient) MakeHat(ctx context.Context, in *Size) (*Hat, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest") + ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") + ctx = ctxsetters.WithMethodName(ctx, "MakeHat") caller := c.callMakeHat if c.interceptor != nil { caller = func(ctx context.Context, req *Size) (*Hat, error) { @@ -110,9 +113,6 @@ func (c *haberdasherProtobufClient) MakeHat(ctx context.Context, in *Size) (*Hat } func (c *haberdasherProtobufClient) callMakeHat(ctx context.Context, in *Size) (*Hat, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest") - ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") - ctx = ctxsetters.WithMethodName(ctx, "MakeHat") out := new(Hat) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -168,6 +168,9 @@ func NewHaberdasherJSONClient(baseURL string, client HTTPClient, opts ...twirp.C } func (c *haberdasherJSONClient) MakeHat(ctx context.Context, in *Size) (*Hat, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest") + ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") + ctx = ctxsetters.WithMethodName(ctx, "MakeHat") caller := c.callMakeHat if c.interceptor != nil { caller = func(ctx context.Context, req *Size) (*Hat, error) { @@ -194,9 +197,6 @@ func (c *haberdasherJSONClient) MakeHat(ctx context.Context, in *Size) (*Hat, er } func (c *haberdasherJSONClient) callMakeHat(ctx context.Context, in *Size) (*Hat, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest") - ctx = ctxsetters.WithServiceName(ctx, "Haberdasher") - ctx = ctxsetters.WithMethodName(ctx, "MakeHat") out := new(Hat) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go b/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go index f382471a..f8ba2a55 100644 --- a/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go +++ b/internal/twirptest/service_method_same_name/service_method_same_name.twirp.go @@ -82,6 +82,9 @@ func NewEchoProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clie } func (c *echoProtobufClient) Echo(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "") + ctx = ctxsetters.WithServiceName(ctx, "Echo") + ctx = ctxsetters.WithMethodName(ctx, "Echo") caller := c.callEcho if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -108,9 +111,6 @@ func (c *echoProtobufClient) Echo(ctx context.Context, in *Msg) (*Msg, error) { } func (c *echoProtobufClient) callEcho(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "") - ctx = ctxsetters.WithServiceName(ctx, "Echo") - ctx = ctxsetters.WithMethodName(ctx, "Echo") out := new(Msg) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -166,6 +166,9 @@ func NewEchoJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOp } func (c *echoJSONClient) Echo(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "") + ctx = ctxsetters.WithServiceName(ctx, "Echo") + ctx = ctxsetters.WithMethodName(ctx, "Echo") caller := c.callEcho if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -192,9 +195,6 @@ func (c *echoJSONClient) Echo(ctx context.Context, in *Msg) (*Msg, error) { } func (c *echoJSONClient) callEcho(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "") - ctx = ctxsetters.WithServiceName(ctx, "Echo") - ctx = ctxsetters.WithMethodName(ctx, "Echo") out := new(Msg) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/service_test.go b/internal/twirptest/service_test.go index 0c7e888a..8b664c92 100644 --- a/internal/twirptest/service_test.go +++ b/internal/twirptest/service_test.go @@ -643,7 +643,6 @@ func TestInterceptor(t *testing.T) { if !ok { t.Fatalf("expected twirp.Error type error, have %T", clientErr) } - if twerr.Code() != twirp.InvalidArgument { t.Errorf("expected error type to be InvalidArgument, buf found %q", twerr.Code()) } diff --git a/internal/twirptest/snake_case_names/snake_case_names.twirp.go b/internal/twirptest/snake_case_names/snake_case_names.twirp.go index f7eb8f84..49c5703e 100644 --- a/internal/twirptest/snake_case_names/snake_case_names.twirp.go +++ b/internal/twirptest/snake_case_names/snake_case_names.twirp.go @@ -87,6 +87,9 @@ func NewHaberdasherV1ProtobufClient(baseURL string, client HTTPClient, opts ...t } func (c *haberdasherV1ProtobufClient) MakeHatV1(ctx context.Context, in *MakeHatArgsV1_SizeV1) (*MakeHatArgsV1_HatV1, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.snake_case_names") + ctx = ctxsetters.WithServiceName(ctx, "HaberdasherV1") + ctx = ctxsetters.WithMethodName(ctx, "MakeHatV1") caller := c.callMakeHatV1 if c.interceptor != nil { caller = func(ctx context.Context, req *MakeHatArgsV1_SizeV1) (*MakeHatArgsV1_HatV1, error) { @@ -113,9 +116,6 @@ func (c *haberdasherV1ProtobufClient) MakeHatV1(ctx context.Context, in *MakeHat } func (c *haberdasherV1ProtobufClient) callMakeHatV1(ctx context.Context, in *MakeHatArgsV1_SizeV1) (*MakeHatArgsV1_HatV1, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.snake_case_names") - ctx = ctxsetters.WithServiceName(ctx, "HaberdasherV1") - ctx = ctxsetters.WithMethodName(ctx, "MakeHatV1") out := new(MakeHatArgsV1_HatV1) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -171,6 +171,9 @@ func NewHaberdasherV1JSONClient(baseURL string, client HTTPClient, opts ...twirp } func (c *haberdasherV1JSONClient) MakeHatV1(ctx context.Context, in *MakeHatArgsV1_SizeV1) (*MakeHatArgsV1_HatV1, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.snake_case_names") + ctx = ctxsetters.WithServiceName(ctx, "HaberdasherV1") + ctx = ctxsetters.WithMethodName(ctx, "MakeHatV1") caller := c.callMakeHatV1 if c.interceptor != nil { caller = func(ctx context.Context, req *MakeHatArgsV1_SizeV1) (*MakeHatArgsV1_HatV1, error) { @@ -197,9 +200,6 @@ func (c *haberdasherV1JSONClient) MakeHatV1(ctx context.Context, in *MakeHatArgs } func (c *haberdasherV1JSONClient) callMakeHatV1(ctx context.Context, in *MakeHatArgsV1_SizeV1) (*MakeHatArgsV1_HatV1, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.snake_case_names") - ctx = ctxsetters.WithServiceName(ctx, "HaberdasherV1") - ctx = ctxsetters.WithMethodName(ctx, "MakeHatV1") out := new(MakeHatArgsV1_HatV1) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/internal/twirptest/source_relative/source_relative.twirp.go b/internal/twirptest/source_relative/source_relative.twirp.go index b17d3f14..755e9609 100644 --- a/internal/twirptest/source_relative/source_relative.twirp.go +++ b/internal/twirptest/source_relative/source_relative.twirp.go @@ -82,6 +82,9 @@ func NewSvcProtobufClient(baseURL string, client HTTPClient, opts ...twirp.Clien } func (c *svcProtobufClient) Method(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.source_relative") + ctx = ctxsetters.WithServiceName(ctx, "Svc") + ctx = ctxsetters.WithMethodName(ctx, "Method") caller := c.callMethod if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -108,9 +111,6 @@ func (c *svcProtobufClient) Method(ctx context.Context, in *Msg) (*Msg, error) { } func (c *svcProtobufClient) callMethod(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.source_relative") - ctx = ctxsetters.WithServiceName(ctx, "Svc") - ctx = ctxsetters.WithMethodName(ctx, "Method") out := new(Msg) ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { @@ -166,6 +166,9 @@ func NewSvcJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOpt } func (c *svcJSONClient) Method(ctx context.Context, in *Msg) (*Msg, error) { + ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.source_relative") + ctx = ctxsetters.WithServiceName(ctx, "Svc") + ctx = ctxsetters.WithMethodName(ctx, "Method") caller := c.callMethod if c.interceptor != nil { caller = func(ctx context.Context, req *Msg) (*Msg, error) { @@ -192,9 +195,6 @@ func (c *svcJSONClient) Method(ctx context.Context, in *Msg) (*Msg, error) { } func (c *svcJSONClient) callMethod(ctx context.Context, in *Msg) (*Msg, error) { - ctx = ctxsetters.WithPackageName(ctx, "twirp.internal.twirptest.source_relative") - ctx = ctxsetters.WithServiceName(ctx, "Svc") - ctx = ctxsetters.WithMethodName(ctx, "Method") out := new(Msg) ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) if err != nil { diff --git a/protoc-gen-twirp/generator.go b/protoc-gen-twirp/generator.go index 856f14b1..ba2cdec3 100644 --- a/protoc-gen-twirp/generator.go +++ b/protoc-gen-twirp/generator.go @@ -1011,6 +1011,9 @@ func (t *twirp) generateClient(name string, file *descriptor.FileDescriptorProto inputType := t.goTypeName(method.GetInputType()) outputType := t.goTypeName(method.GetOutputType()) t.P(`func (c *`, structName, `) `, methName, `(ctx `, t.pkgs["context"], `.Context, in *`, inputType, `) (*`, outputType, `, error) {`) + t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithPackageName(ctx, "`, pkgName, `")`) + t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithServiceName(ctx, "`, servName, `")`) + t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithMethodName(ctx, "`, methName, `")`) t.P(` caller := c.call`, methName) t.P(` if c.interceptor != nil {`) t.generateClientInterceptorCaller(method) @@ -1019,9 +1022,6 @@ func (t *twirp) generateClient(name string, file *descriptor.FileDescriptorProto t.P(`}`) t.P() t.P(`func (c *`, structName, `) call`, methName, `(ctx `, t.pkgs["context"], `.Context, in *`, inputType, `) (*`, outputType, `, error) {`) - t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithPackageName(ctx, "`, pkgName, `")`) - t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithServiceName(ctx, "`, servName, `")`) - t.P(` ctx = `, t.pkgs["ctxsetters"], `.WithMethodName(ctx, "`, methName, `")`) t.P(` out := new(`, outputType, `)`) t.P(` ctx, err := do`, name, `Request(ctx, c.client, c.opts.Hooks, c.urls[`, strconv.Itoa(i), `], in, out)`) t.P(` if err != nil {`)