diff --git a/README.md b/README.md index a0b0dc198..a45451a48 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,8 @@ This list covers known interceptors that users use for their Go microservices (b All paths should work with `go get `. #### Auth -* [`github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth`](interceptors/auth) - a customizable (via `AuthFunc`) piece of auth middleware. +* [`github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/auth`](interceptors/auth) - a customizable via `AuthFunc` piece of auth middleware. +* (external) [`google.golang.org/grpc/authz`](https://github.com/grpc/grpc-go/blob/master/authz/grpc_authz_server_interceptors.go) - more complex, customizable via auth polices (RBAC like), piece of auth middleware. #### Observability * Metrics with [`github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus`⚡](providers/prometheus) - Prometheus client-side and server-side monitoring middleware. Supports exemplars. Moved from deprecated now [`go-grpc-prometheus`](https://github.com/grpc-ecosystem/go-grpc-prometheus). It's a separate module, so core module has limited number of dependencies. @@ -54,8 +55,9 @@ All paths should work with `go get `. * (external) [`github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing`](https://pkg.go.dev/github.com/grpc-ecosystem/go-grpc-middleware@v1.4.0/tracing/opentracing) - deprecated [OpenTracing](http://opentracing.io/) client-side and server-side interceptors if you still need it! #### Client -* [`github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/retry`](interceptors/retry) - a generic gRPC response code retry mechanism, client-side middleware -* [`github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/timeout`](interceptors/timeout) - a generic gRPC request timeout, client-side middleware +* [`github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/retry`](interceptors/retry) - a generic gRPC response code retry mechanism, client-side middleware. + * NOTE: grpc-go has native retries too with advanced policies (https://github.com/grpc/grpc-go/blob/v1.54.0/examples/features/retry/client/main.go) +* [`github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/timeout`](interceptors/timeout) - a generic gRPC request timeout, client-side middleware. #### Server * [`github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator`](interceptors/validator) - codegen inbound message validation from `.proto` options. diff --git a/interceptors/auth/auth.go b/interceptors/auth/auth.go index beaa77d7d..b015d60e5 100644 --- a/interceptors/auth/auth.go +++ b/interceptors/auth/auth.go @@ -32,6 +32,7 @@ type ServiceAuthFuncOverride interface { } // UnaryServerInterceptor returns a new unary server interceptors that performs per-request auth. +// NOTE(bwplotka): For more complex auth interceptor see https://github.com/grpc/grpc-go/blob/master/authz/grpc_authz_server_interceptors.go. func UnaryServerInterceptor(authFunc AuthFunc) grpc.UnaryServerInterceptor { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { var newCtx context.Context @@ -49,6 +50,7 @@ func UnaryServerInterceptor(authFunc AuthFunc) grpc.UnaryServerInterceptor { } // StreamServerInterceptor returns a new unary server interceptors that performs per-request auth. +// NOTE(bwplotka): For more complex auth interceptor see https://github.com/grpc/grpc-go/blob/master/authz/grpc_authz_server_interceptors.go. func StreamServerInterceptor(authFunc AuthFunc) grpc.StreamServerInterceptor { return func(srv any, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { var newCtx context.Context diff --git a/interceptors/retry/backoff.go b/interceptors/retry/backoff.go index 49a9abe92..5b8aaa683 100644 --- a/interceptors/retry/backoff.go +++ b/interceptors/retry/backoff.go @@ -4,13 +4,14 @@ package retry import ( + "context" "math/rand" "time" ) // BackoffLinear is very simple: it waits for a fixed period of time between calls. func BackoffLinear(waitBetween time.Duration) BackoffFunc { - return func(attempt uint) time.Duration { + return func(ctx context.Context, attempt uint) time.Duration { return waitBetween } } @@ -31,7 +32,7 @@ func exponentBase2(a uint) uint { // BackoffLinearWithJitter waits a set period of time, allowing for jitter (fractional adjustment). // For example waitBetween=1s and jitter=0.10 can generate waits between 900ms and 1100ms. func BackoffLinearWithJitter(waitBetween time.Duration, jitterFraction float64) BackoffFunc { - return func(attempt uint) time.Duration { + return func(ctx context.Context, attempt uint) time.Duration { return jitterUp(waitBetween, jitterFraction) } } @@ -40,7 +41,7 @@ func BackoffLinearWithJitter(waitBetween time.Duration, jitterFraction float64) // The scalar is multiplied times 2 raised to the current attempt. So the first // retry with a scalar of 100ms is 100ms, while the 5th attempt would be 1.6s. func BackoffExponential(scalar time.Duration) BackoffFunc { - return func(attempt uint) time.Duration { + return func(ctx context.Context, attempt uint) time.Duration { return scalar * time.Duration(exponentBase2(attempt)) } } @@ -48,7 +49,7 @@ func BackoffExponential(scalar time.Duration) BackoffFunc { // BackoffExponentialWithJitter creates an exponential backoff like // BackoffExponential does, but adds jitter. func BackoffExponentialWithJitter(scalar time.Duration, jitterFraction float64) BackoffFunc { - return func(attempt uint) time.Duration { + return func(ctx context.Context, attempt uint) time.Duration { return jitterUp(scalar*time.Duration(exponentBase2(attempt)), jitterFraction) } } diff --git a/interceptors/retry/options.go b/interceptors/retry/options.go index 9b1440d70..649db2028 100644 --- a/interceptors/retry/options.go +++ b/interceptors/retry/options.go @@ -23,9 +23,7 @@ var ( perCallTimeout: 0, // disabled includeHeader: true, codes: DefaultRetriableCodes, - backoffFunc: BackoffFuncContext(func(ctx context.Context, attempt uint) time.Duration { - return BackoffLinearWithJitter(50*time.Millisecond /*jitter*/, 0.10)(attempt) - }), + backoffFunc: BackoffLinearWithJitter(50*time.Millisecond /*jitter*/, 0.10), onRetryCallback: OnRetryCallback(func(ctx context.Context, attempt uint, err error) { logTrace(ctx, "grpc_retry attempt: %d, backoff for %v", attempt, err) }), @@ -37,16 +35,8 @@ var ( // They are called with an identifier of the attempt, and should return a time the system client should // hold off for. If the time returned is longer than the `context.Context.Deadline` of the request // the deadline of the request takes precedence and the wait will be interrupted before proceeding -// with the next iteration. -type BackoffFunc func(attempt uint) time.Duration - -// BackoffFuncContext denotes a family of functions that control the backoff duration between call retries. -// -// They are called with an identifier of the attempt, and should return a time the system client should -// hold off for. If the time returned is longer than the `context.Context.Deadline` of the request -// the deadline of the request takes precedence and the wait will be interrupted before proceeding // with the next iteration. The context can be used to extract request scoped metadata and context values. -type BackoffFuncContext func(ctx context.Context, attempt uint) time.Duration +type BackoffFunc func(ctx context.Context, attempt uint) time.Duration // OnRetryCallback is the type of function called when a retry occurs. type OnRetryCallback func(ctx context.Context, attempt uint, err error) @@ -67,15 +57,6 @@ func WithMax(maxRetries uint) CallOption { // WithBackoff sets the `BackoffFunc` used to control time between retries. func WithBackoff(bf BackoffFunc) CallOption { - return CallOption{applyFunc: func(o *options) { - o.backoffFunc = BackoffFuncContext(func(ctx context.Context, attempt uint) time.Duration { - return bf(attempt) - }) - }} -} - -// WithBackoffContext sets the `BackoffFuncContext` used to control time between retries. -func WithBackoffContext(bf BackoffFuncContext) CallOption { return CallOption{applyFunc: func(o *options) { o.backoffFunc = bf }} @@ -124,7 +105,7 @@ type options struct { perCallTimeout time.Duration includeHeader bool codes []codes.Code - backoffFunc BackoffFuncContext + backoffFunc BackoffFunc onRetryCallback OnRetryCallback } diff --git a/interceptors/retry/retry_test.go b/interceptors/retry/retry_test.go index 08b1dc8f6..b423a93bf 100644 --- a/interceptors/retry/retry_test.go +++ b/interceptors/retry/retry_test.go @@ -178,40 +178,6 @@ func (s *RetrySuite) TestUnary_OverrideFromDialOpts() { require.EqualValues(s.T(), 5, s.srv.requestCount(), "five requests should have been made") } -func (s *RetrySuite) TestUnary_PerCallDeadline_Succeeds() { - s.T().Skip("TODO(bwplotka): Mock time & unskip, this is too flaky on GH Actions.") - - // This tests 5 requests, with first 4 sleeping for 10 millisecond, and the retry logic firing - // a retry call with a 5 millisecond deadline. The 5th one doesn't sleep and succeeds. - deadlinePerCall := 5 * time.Millisecond - s.srv.resetFailingConfiguration(5, codes.NotFound, 2*deadlinePerCall) - out, err := s.Client.Ping(s.SimpleCtx(), testpb.GoodPing, WithPerRetryTimeout(deadlinePerCall), - WithMax(5)) - require.NoError(s.T(), err, "the fifth invocation should succeed") - require.NotNil(s.T(), out, "Pong must be not nil") - require.EqualValues(s.T(), 5, s.srv.requestCount(), "five requests should have been made") -} - -func (s *RetrySuite) TestUnary_PerCallDeadline_FailsOnParent() { - s.T().Skip("TODO(bwplotka): Mock time & unskip, this is too flaky on GH Actions.") - - // This tests that the parent context (passed to the invocation) takes precedence over retries. - // The parent context has 150 milliseconds of deadline. - // Each failed call sleeps for 100milliseconds, and there is 5 milliseconds between each one. - // This means that unlike in TestUnary_PerCallDeadline_Succeeds, the fifth successful call won't - // be made. - parentDeadline := 150 * time.Millisecond - deadlinePerCall := 50 * time.Millisecond - // All 0-4 requests should have 10 millisecond sleeps and deadline, while the last one works. - s.srv.resetFailingConfiguration(5, codes.NotFound, 2*deadlinePerCall) - ctx, cancel := context.WithTimeout(context.TODO(), parentDeadline) - defer cancel() - _, err := s.Client.Ping(ctx, testpb.GoodPing, WithPerRetryTimeout(deadlinePerCall), - WithMax(5)) - require.Error(s.T(), err, "the retries must fail due to context deadline exceeded") - require.Equal(s.T(), codes.DeadlineExceeded, status.Code(err), "failre code must be a gRPC error of Deadline class") -} - func (s *RetrySuite) TestUnary_OnRetryCallbackCalled() { retryCallbackCount := 0 @@ -243,41 +209,6 @@ func (s *RetrySuite) TestServerStream_OverrideFromContext() { require.EqualValues(s.T(), 5, s.srv.requestCount(), "three requests should have been made") } -func (s *RetrySuite) TestServerStream_PerCallDeadline_Succeeds() { - s.T().Skip("TODO(bwplotka): Mock time & unskip, this is too flaky on GH Actions.") - - // This tests 5 requests, with first 4 sleeping for 100 millisecond, and the retry logic firing - // a retry call with a 50 millisecond deadline. The 5th one doesn't sleep and succeeds. - deadlinePerCall := 100 * time.Millisecond - s.srv.resetFailingConfiguration(5, codes.NotFound, 2*deadlinePerCall) - stream, err := s.Client.PingList(s.SimpleCtx(), testpb.GoodPingList, WithPerRetryTimeout(deadlinePerCall), - WithMax(5)) - require.NoError(s.T(), err, "establishing the connection must always succeed") - s.assertPingListWasCorrect(stream) - require.EqualValues(s.T(), 5, s.srv.requestCount(), "three requests should have been made") -} - -func (s *RetrySuite) TestServerStream_PerCallDeadline_FailsOnParent() { - s.T().Skip("TODO(bwplotka): Mock time & unskip, this is too flaky on GH Actions.") - - // This tests that the parent context (passed to the invocation) takes precedence over retries. - // The parent context has 150 milliseconds of deadline. - // Each failed call sleeps for 50milliseconds, and there is 25 milliseconds between each one. - // This means that unlike in TestServerStream_PerCallDeadline_Succeeds, the fifth successful call won't - // be made. - parentDeadline := 150 * time.Millisecond - deadlinePerCall := 50 * time.Millisecond - // All 0-4 requests should have 10 millisecond sleeps and deadline, while the last one works. - s.srv.resetFailingConfiguration(5, codes.NotFound, 2*deadlinePerCall) - parentCtx, cancel := context.WithTimeout(context.TODO(), parentDeadline) - defer cancel() - stream, err := s.Client.PingList(parentCtx, testpb.GoodPingList, WithPerRetryTimeout(deadlinePerCall), - WithMax(5)) - require.NoError(s.T(), err, "establishing the connection must always succeed") - _, err = stream.Recv() - require.Equal(s.T(), codes.DeadlineExceeded, status.Code(err), "failre code must be a gRPC error of Deadline class") -} - func (s *RetrySuite) TestServerStream_OnRetryCallbackCalled() { retryCallbackCount := 0 diff --git a/interceptors/validator/interceptors.go b/interceptors/validator/interceptors.go index 46c5ed0bd..7f259ac24 100644 --- a/interceptors/validator/interceptors.go +++ b/interceptors/validator/interceptors.go @@ -17,7 +17,7 @@ import ( func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { o := evaluateOpts(opts) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { - if err := validate(ctx, req, o.shouldFailFast, o.onValidationErrFunc); err != nil { + if err := validate(ctx, req, o.shouldFailFast, o.onValidationErrCallback); err != nil { return nil, err } return handler(ctx, req) @@ -32,7 +32,7 @@ func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { o := evaluateOpts(opts) return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { - if err := validate(ctx, req, o.shouldFailFast, o.onValidationErrFunc); err != nil { + if err := validate(ctx, req, o.shouldFailFast, o.onValidationErrCallback); err != nil { return err } return invoker(ctx, method, req, reply, cc, opts...) @@ -68,7 +68,7 @@ func (s *recvWrapper) RecvMsg(m any) error { if err := s.ServerStream.RecvMsg(m); err != nil { return err } - if err := validate(s.Context(), m, s.shouldFailFast, s.onValidationErrFunc); err != nil { + if err := validate(s.Context(), m, s.shouldFailFast, s.onValidationErrCallback); err != nil { return err } return nil diff --git a/interceptors/validator/interceptors_test.go b/interceptors/validator/interceptors_test.go index 43f8462ee..4e10bcbe2 100644 --- a/interceptors/validator/interceptors_test.go +++ b/interceptors/validator/interceptors_test.go @@ -116,8 +116,8 @@ func TestValidatorTestSuite(t *testing.T) { sWithOnErrFuncArgs := &ValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ServerOpts: []grpc.ServerOption{ - grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithOnValidationErrFunc(onErr))), - grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithOnValidationErrFunc(onErr))), + grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithOnValidationErrCallback(onErr))), + grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithOnValidationErrCallback(onErr))), }, }, } @@ -128,8 +128,8 @@ func TestValidatorTestSuite(t *testing.T) { sAll := &ValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ServerOpts: []grpc.ServerOption{ - grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))), - grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))), + grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast(), validator.WithOnValidationErrCallback(onErr))), + grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast(), validator.WithOnValidationErrCallback(onErr))), }, }, } @@ -158,7 +158,7 @@ func TestValidatorTestSuite(t *testing.T) { csWithOnErrFuncArgs := &ClientValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ServerOpts: []grpc.ServerOption{ - grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithOnValidationErrFunc(onErr))), + grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithOnValidationErrCallback(onErr))), }, }, } @@ -169,7 +169,7 @@ func TestValidatorTestSuite(t *testing.T) { csAll := &ClientValidatorTestSuite{ InterceptorTestSuite: &testpb.InterceptorTestSuite{ ClientOpts: []grpc.DialOption{ - grpc.WithUnaryInterceptor(validator.UnaryClientInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))), + grpc.WithUnaryInterceptor(validator.UnaryClientInterceptor(validator.WithFailFast(), validator.WithOnValidationErrCallback(onErr))), }, }, } diff --git a/interceptors/validator/options.go b/interceptors/validator/options.go index 3f4cc946b..6e392b4e8 100644 --- a/interceptors/validator/options.go +++ b/interceptors/validator/options.go @@ -8,8 +8,8 @@ import ( ) type options struct { - shouldFailFast bool - onValidationErrFunc OnValidationErr + shouldFailFast bool + onValidationErrCallback OnValidationErrCallback } type Option func(*options) @@ -21,12 +21,12 @@ func evaluateOpts(opts []Option) *options { return optCopy } -type OnValidationErr func(ctx context.Context, err error) +type OnValidationErrCallback func(ctx context.Context, err error) -// WithOnValidationErrFunc registers function that will be invoked on validation error(s). -func WithOnValidationErrFunc(onValidationErrFunc OnValidationErr) Option { +// WithOnValidationErrCallback registers function that will be invoked on validation error(s). +func WithOnValidationErrCallback(onValidationErrCallback OnValidationErrCallback) Option { return func(o *options) { - o.onValidationErrFunc = onValidationErrFunc + o.onValidationErrCallback = onValidationErrCallback } } diff --git a/interceptors/validator/validator.go b/interceptors/validator/validator.go index d6d72558c..ad6d25e09 100644 --- a/interceptors/validator/validator.go +++ b/interceptors/validator/validator.go @@ -27,7 +27,7 @@ type validatorLegacy interface { Validate() error } -func validate(ctx context.Context, reqOrRes interface{}, shouldFailFast bool, onValidationErrFunc OnValidationErr) (err error) { +func validate(ctx context.Context, reqOrRes interface{}, shouldFailFast bool, onValidationErrCallback OnValidationErrCallback) (err error) { if shouldFailFast { switch v := reqOrRes.(type) { case validatorLegacy: @@ -50,8 +50,8 @@ func validate(ctx context.Context, reqOrRes interface{}, shouldFailFast bool, on return nil } - if onValidationErrFunc != nil { - onValidationErrFunc(ctx, err) + if onValidationErrCallback != nil { + onValidationErrCallback(ctx, err) } return status.Error(codes.InvalidArgument, err.Error()) }