Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v2: validator support for protoc-gen-validate 0.6.0 #418

Merged
merged 1 commit into from
Apr 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions interceptors/validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,40 @@ import (
"google.golang.org/grpc/status"
)

// The validate interface starting with protoc-gen-validate v0.6.0.
// See https://github.com/envoyproxy/protoc-gen-validate/pull/455.
type validator interface {
Validate(all bool) error
}

// The validate interface prior to protoc-gen-validate v0.6.0.
type validatorLegacy interface {
Validate() error
}

// Calls the Validate function on a proto message using either the current or legacy interface if the Validate function
// is present. If validation fails, the error is wrapped with `InvalidArgument` and returned.
func validate(req interface{}) error {
switch v := req.(type) {
case validatorLegacy:
if err := v.Validate(); err != nil {
return status.Error(codes.InvalidArgument, err.Error())
}
case validator:
if err := v.Validate(false); err != nil {
return status.Error(codes.InvalidArgument, err.Error())
}
}
return nil
}

// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages.
//
// Invalid messages will be rejected with `InvalidArgument` before reaching any userspace handlers.
func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
if v, ok := req.(validator); ok {
if err := v.Validate(); err != nil {
return nil, status.Errorf(codes.InvalidArgument, err.Error())
}
if err := validate(req); err != nil {
return nil, err
}
return handler(ctx, req)
}
Expand All @@ -34,10 +55,8 @@ func UnaryServerInterceptor() grpc.UnaryServerInterceptor {
// Invalid messages will be rejected with `InvalidArgument` before sending the request to server.
func UnaryClientInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
if v, ok := req.(validator); ok {
if err := v.Validate(); err != nil {
return status.Errorf(codes.InvalidArgument, err.Error())
}
if err := validate(req); err != nil {
return err
}
return invoker(ctx, method, req, reply, cc, opts...)
}
Expand All @@ -64,10 +83,8 @@ func (s *recvWrapper) RecvMsg(m interface{}) error {
if err := s.ServerStream.RecvMsg(m); err != nil {
return err
}
if v, ok := m.(validator); ok {
if err := v.Validate(); err != nil {
return status.Errorf(codes.InvalidArgument, err.Error())
}
if err := validate(m); err != nil {
return err
}
return nil
}
17 changes: 12 additions & 5 deletions interceptors/validator/validator_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms.

package validator_test
package validator

import (
"io"
Expand All @@ -14,16 +14,23 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator"
"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
)

func TestValidateWrapper(t *testing.T) {
assert.NoError(t, validate(testpb.GoodPing))
assert.Error(t, validate(testpb.BadPing))

assert.NoError(t, validate(testpb.GoodPingResponse))
assert.Error(t, validate(testpb.BadPingResponse))
}

func TestValidatorTestSuite(t *testing.T) {
s := &ValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ServerOpts: []grpc.ServerOption{
grpc.StreamInterceptor(validator.StreamServerInterceptor()),
grpc.UnaryInterceptor(validator.UnaryServerInterceptor()),
grpc.StreamInterceptor(StreamServerInterceptor()),
grpc.UnaryInterceptor(UnaryServerInterceptor()),
},
},
}
Expand All @@ -32,7 +39,7 @@ func TestValidatorTestSuite(t *testing.T) {
cs := &ClientValidatorTestSuite{
InterceptorTestSuite: &testpb.InterceptorTestSuite{
ClientOpts: []grpc.DialOption{
grpc.WithUnaryInterceptor(validator.UnaryClientInterceptor()),
grpc.WithUnaryInterceptor(UnaryClientInterceptor()),
},
},
}
Expand Down
25 changes: 20 additions & 5 deletions testing/testpb/test.manual_validator.pb.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,48 @@

package testpb

import "github.com/pkg/errors"
import (
"math"

func (x *PingRequest) Validate() error {
"github.com/pkg/errors"
)

func (x *PingRequest) Validate(bool) error {
if x.SleepTimeMs > 10000 {
return errors.New("cannot sleep for more than 10s")
}
return nil
}

func (x *PingErrorRequest) Validate() error {
func (x *PingErrorRequest) Validate(bool) error {
if x.SleepTimeMs > 10000 {
return errors.New("cannot sleep for more than 10s")
}
return nil
}

func (x *PingListRequest) Validate() error {
func (x *PingListRequest) Validate(bool) error {
if x.SleepTimeMs > 10000 {
return errors.New("cannot sleep for more than 10s")
}
return nil
}

func (x *PingStreamRequest) Validate() error {
func (x *PingStreamRequest) Validate(bool) error {
if x.SleepTimeMs > 10000 {
return errors.New("cannot sleep for more than 10s")
}
return nil
}

// Implements the legacy validation interface from protoc-gen-validate.
func (x *PingResponse) Validate() error {
if x.Counter > math.MaxInt16 {
return errors.New("ping allocation exceeded")
}
return nil
}

var (
GoodPing = &PingRequest{Value: "something", SleepTimeMs: 9999}
GoodPingError = &PingErrorRequest{Value: "something", SleepTimeMs: 9999}
Expand All @@ -42,4 +54,7 @@ var (
BadPingError = &PingErrorRequest{Value: "something", SleepTimeMs: 10001}
BadPingList = &PingListRequest{Value: "something", SleepTimeMs: 10001}
BadPingStream = &PingStreamRequest{Value: "something", SleepTimeMs: 10001}

GoodPingResponse = &PingResponse{Counter: 100}
BadPingResponse = &PingResponse{Counter: math.MaxInt16 + 1}
)