diff --git a/interceptors/protovalidate/protovalidate.go b/interceptors/protovalidate/protovalidate.go index cf337db11..fac113ba2 100644 --- a/interceptors/protovalidate/protovalidate.go +++ b/interceptors/protovalidate/protovalidate.go @@ -29,7 +29,7 @@ func UnaryServerInterceptor(validator *protovalidate.Validator, opts ...Option) break } if err = validator.Validate(msg); err != nil { - return nil, status.Error(codes.InvalidArgument, err.Error()) + return nil, validationErrToStatus(err).Err() } default: return nil, errors.New("unsupported message type") @@ -63,12 +63,15 @@ func (w *wrappedServerStream) RecvMsg(m interface{}) error { return err } - msg := m.(proto.Message) + msg, ok := m.(proto.Message) + if !ok { + return errors.New("unsupported message type") + } if w.options.shouldIgnoreMessage(msg.ProtoReflect().Type()) { return nil } if err := w.validator.Validate(msg); err != nil { - return status.Error(codes.InvalidArgument, err.Error()) + return validationErrToStatus(err).Err() } return nil @@ -93,3 +96,17 @@ func (w *wrappedServerStream) Context() context.Context { func wrapServerStream(stream grpc.ServerStream) *wrappedServerStream { return &wrappedServerStream{ServerStream: stream, wrappedContext: stream.Context()} } + +func validationErrToStatus(err error) *status.Status { + // Message is invalid. + if valErr := new(protovalidate.ValidationError); errors.As(err, &valErr) { + st := status.New(codes.InvalidArgument, err.Error()) + ds, detErr := st.WithDetails(valErr.ToProto()) + if detErr != nil { + return st + } + return ds + } + // CEL expression doesn't compile or type-check. + return status.New(codes.Unknown, err.Error()) +} diff --git a/interceptors/protovalidate/protovalidate_test.go b/interceptors/protovalidate/protovalidate_test.go index f626f395c..052d2395b 100644 --- a/interceptors/protovalidate/protovalidate_test.go +++ b/interceptors/protovalidate/protovalidate_test.go @@ -9,16 +9,19 @@ import ( "net" "testing" + "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate" "github.com/bufbuild/protovalidate-go" protovalidate_middleware "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/protovalidate" "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testvalidate" testvalidatev1 "github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testvalidate/v1" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" "google.golang.org/grpc/test/bufconn" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" ) @@ -31,25 +34,27 @@ func TestUnaryServerInterceptor(t *testing.T) { handler := func(ctx context.Context, req any) (any, error) { return "good", nil } + info := &grpc.UnaryServerInfo{FullMethod: "FakeMethod"} t.Run("valid_email", func(t *testing.T) { - info := &grpc.UnaryServerInfo{ - FullMethod: "FakeMethod", - } - resp, err := interceptor(context.TODO(), testvalidate.GoodUnaryRequest, info, handler) assert.Nil(t, err) assert.Equal(t, resp, "good") }) t.Run("invalid_email", func(t *testing.T) { - info := &grpc.UnaryServerInfo{ - FullMethod: "FakeMethod", - } - _, err = interceptor(context.TODO(), testvalidate.BadUnaryRequest, info, handler) + assertEqualViolation(t, &validate.Violation{ + FieldPath: "message", + ConstraintId: "string.email", + Message: "value must be a valid email address", + }, err) + }) + + t.Run("not_protobuf", func(t *testing.T) { + _, err = interceptor(context.Background(), "not protobuf", info, handler) assert.Error(t, err) - assert.Equal(t, codes.InvalidArgument, status.Code(err)) + assert.Equal(t, codes.Unknown, status.Code(err)) }) interceptor = protovalidate_middleware.UnaryServerInterceptor(validator, @@ -57,10 +62,6 @@ func TestUnaryServerInterceptor(t *testing.T) { ) t.Run("invalid_email_ignored", func(t *testing.T) { - info := &grpc.UnaryServerInfo{ - FullMethod: "FakeMethod", - } - resp, err := interceptor(context.TODO(), testvalidate.BadUnaryRequest, info, handler) assert.Nil(t, err) assert.Equal(t, resp, "good") @@ -145,8 +146,11 @@ func TestStreamServerInterceptor(t *testing.T) { assert.Nil(t, err) _, err = out.Recv() - assert.Error(t, err) - assert.Equal(t, codes.InvalidArgument, status.Code(err)) + assertEqualViolation(t, &validate.Violation{ + FieldPath: "message", + ConstraintId: "string.email", + Message: "value must be a valid email address", + }, err) }) t.Run("invalid_email_ignored", func(t *testing.T) { @@ -161,3 +165,19 @@ func TestStreamServerInterceptor(t *testing.T) { assert.Nil(t, err) }) } + +func assertEqualViolation(tb testing.TB, want *validate.Violation, got error) bool { + require.Error(tb, got) + st := status.Convert(got) + assert.Equal(tb, codes.InvalidArgument, st.Code()) + details := st.Proto().GetDetails() + require.Len(tb, details, 1) + gotpb, unwrapErr := details[0].UnmarshalNew() + require.Nil(tb, unwrapErr) + violations := &validate.Violations{ + Violations: []*validate.Violation{want}, + } + tb.Logf("got: %v", gotpb) + tb.Logf("want: %v", violations) + return assert.True(tb, proto.Equal(gotpb, violations)) +}