Skip to content

Commit 0e1142d

Browse files
authored
Removed deciders; Cleaned up validators. (#554)
Signed-off-by: bwplotka <[email protected]>
1 parent 8c53766 commit 0e1142d

File tree

6 files changed

+78
-110
lines changed

6 files changed

+78
-110
lines changed

interceptors/validator/interceptors.go

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,12 @@ import (
1212
// UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages.
1313
//
1414
// Invalid messages will be rejected with `InvalidArgument` before reaching any userspace handlers.
15-
// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor
16-
// returns ALL validation error as a wrapped multi-error.
17-
// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging.
1815
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
1916
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
2017
func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
21-
o := evaluateServerOpt(opts)
18+
o := evaluateOpts(opts)
2219
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
23-
if err := validate(req, o.shouldFailFast, o.level, o.logger); err != nil {
20+
if err := validate(ctx, req, o.shouldFailFast, o.onValidationErrFunc); err != nil {
2421
return nil, err
2522
}
2623
return handler(ctx, req)
@@ -30,15 +27,12 @@ func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
3027
// UnaryClientInterceptor returns a new unary client interceptor that validates outgoing messages.
3128
//
3229
// Invalid messages will be rejected with `InvalidArgument` before sending the request to server.
33-
// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor
34-
// returns ALL validation error as a wrapped multi-error.
35-
// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging.
3630
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
3731
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
3832
func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor {
39-
o := evaluateClientOpt(opts)
33+
o := evaluateOpts(opts)
4034
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
41-
if err := validate(req, o.shouldFailFast, o.level, o.logger); err != nil {
35+
if err := validate(ctx, req, o.shouldFailFast, o.onValidationErrFunc); err != nil {
4236
return err
4337
}
4438
return invoker(ctx, method, req, reply, cc, opts...)
@@ -47,17 +41,14 @@ func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor {
4741

4842
// StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages.
4943
//
50-
// If `WithFailFast` used it will interceptor and returns the first validation error. Otherwise, the interceptor
51-
// returns ALL validation error as a wrapped multi-error.
52-
// If `WithLogger` used it will log all the validation errors. Otherwise, no default logging.
5344
// Note that generated codes prior to protoc-gen-validate v0.6.0 do not provide an all-validation
5445
// interface. In this case the interceptor fallbacks to legacy validation and `all` is ignored.
5546
// The stage at which invalid messages will be rejected with `InvalidArgument` varies based on the
5647
// type of the RPC. For `ServerStream` (1:m) requests, it will happen before reaching any userspace
5748
// handlers. For `ClientStream` (n:1) or `BidiStream` (n:m) RPCs, the messages will be rejected on
5849
// calls to `stream.Recv()`.
5950
func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor {
60-
o := evaluateServerOpt(opts)
51+
o := evaluateOpts(opts)
6152
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
6253
wrapper := &recvWrapper{
6354
options: o,
@@ -77,7 +68,7 @@ func (s *recvWrapper) RecvMsg(m any) error {
7768
if err := s.ServerStream.RecvMsg(m); err != nil {
7869
return err
7970
}
80-
if err := validate(m, s.shouldFailFast, s.level, s.logger); err != nil {
71+
if err := validate(s.Context(), m, s.shouldFailFast, s.onValidationErrFunc); err != nil {
8172
return err
8273
}
8374
return nil

interceptors/validator/interceptors_test.go

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"io"
99
"testing"
1010

11-
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
1211
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/validator"
1312
"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
1413
"github.com/stretchr/testify/assert"
@@ -19,10 +18,6 @@ import (
1918
"google.golang.org/grpc/status"
2019
)
2120

22-
type TestLogger struct{}
23-
24-
func (l *TestLogger) Log(ctx context.Context, level logging.Level, msg string, fields ...any) {}
25-
2621
type ValidatorTestSuite struct {
2722
*testpb.InterceptorTestSuite
2823
}
@@ -104,35 +99,42 @@ func TestValidatorTestSuite(t *testing.T) {
10499
}
105100
suite.Run(t, sWithNoArgs)
106101

107-
sWithWithFailFastArgs := &ValidatorTestSuite{
102+
sWithFailFastArgs := &ValidatorTestSuite{
108103
InterceptorTestSuite: &testpb.InterceptorTestSuite{
109104
ServerOpts: []grpc.ServerOption{
110105
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast())),
111106
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast())),
112107
},
113108
},
114109
}
115-
suite.Run(t, sWithWithFailFastArgs)
110+
suite.Run(t, sWithFailFastArgs)
116111

117-
sWithWithLoggerArgs := &ValidatorTestSuite{
112+
var gotErrMsgs []string
113+
onErr := func(ctx context.Context, err error) {
114+
gotErrMsgs = append(gotErrMsgs, err.Error())
115+
}
116+
sWithOnErrFuncArgs := &ValidatorTestSuite{
118117
InterceptorTestSuite: &testpb.InterceptorTestSuite{
119118
ServerOpts: []grpc.ServerOption{
120-
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
121-
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
119+
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithOnValidationErrFunc(onErr))),
120+
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithOnValidationErrFunc(onErr))),
122121
},
123122
},
124123
}
125-
suite.Run(t, sWithWithLoggerArgs)
124+
suite.Run(t, sWithOnErrFuncArgs)
125+
require.Equal(t, []string{"cannot sleep for more than 10s", "cannot sleep for more than 10s", "cannot sleep for more than 10s"}, gotErrMsgs)
126126

127+
gotErrMsgs = gotErrMsgs[:0]
127128
sAll := &ValidatorTestSuite{
128129
InterceptorTestSuite: &testpb.InterceptorTestSuite{
129130
ServerOpts: []grpc.ServerOption{
130-
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast(), validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
131-
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast(), validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
131+
grpc.StreamInterceptor(validator.StreamServerInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))),
132+
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))),
132133
},
133134
},
134135
}
135136
suite.Run(t, sAll)
137+
require.Equal(t, []string{"cannot sleep for more than 10s", "cannot sleep for more than 10s", "cannot sleep for more than 10s"}, gotErrMsgs)
136138

137139
csWithNoArgs := &ClientValidatorTestSuite{
138140
InterceptorTestSuite: &testpb.InterceptorTestSuite{
@@ -143,30 +145,34 @@ func TestValidatorTestSuite(t *testing.T) {
143145
}
144146
suite.Run(t, csWithNoArgs)
145147

146-
csWithWithFailFastArgs := &ClientValidatorTestSuite{
148+
csWithFailFastArgs := &ClientValidatorTestSuite{
147149
InterceptorTestSuite: &testpb.InterceptorTestSuite{
148150
ServerOpts: []grpc.ServerOption{
149151
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithFailFast())),
150152
},
151153
},
152154
}
153-
suite.Run(t, csWithWithFailFastArgs)
155+
suite.Run(t, csWithFailFastArgs)
154156

155-
csWithWithLoggerArgs := &ClientValidatorTestSuite{
157+
gotErrMsgs = gotErrMsgs[:0]
158+
csWithOnErrFuncArgs := &ClientValidatorTestSuite{
156159
InterceptorTestSuite: &testpb.InterceptorTestSuite{
157160
ServerOpts: []grpc.ServerOption{
158-
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithLogger(logging.LevelDebug, &TestLogger{}))),
161+
grpc.UnaryInterceptor(validator.UnaryServerInterceptor(validator.WithOnValidationErrFunc(onErr))),
159162
},
160163
},
161164
}
162-
suite.Run(t, csWithWithLoggerArgs)
165+
suite.Run(t, csWithOnErrFuncArgs)
166+
require.Equal(t, []string{"cannot sleep for more than 10s"}, gotErrMsgs)
163167

168+
gotErrMsgs = gotErrMsgs[:0]
164169
csAll := &ClientValidatorTestSuite{
165170
InterceptorTestSuite: &testpb.InterceptorTestSuite{
166171
ClientOpts: []grpc.DialOption{
167-
grpc.WithUnaryInterceptor(validator.UnaryClientInterceptor(validator.WithFailFast())),
172+
grpc.WithUnaryInterceptor(validator.UnaryClientInterceptor(validator.WithFailFast(), validator.WithOnValidationErrFunc(onErr))),
168173
},
169174
},
170175
}
171176
suite.Run(t, csAll)
177+
require.Equal(t, []string{"cannot sleep for more than 10s"}, gotErrMsgs)
172178
}

interceptors/validator/options.go

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,35 @@
33

44
package validator
55

6-
import "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
6+
import (
7+
"context"
8+
)
79

810
type options struct {
9-
level logging.Level
10-
logger logging.Logger
11-
shouldFailFast bool
11+
shouldFailFast bool
12+
onValidationErrFunc OnValidationErr
1213
}
1314
type Option func(*options)
1415

15-
func evaluateServerOpt(opts []Option) *options {
16+
func evaluateOpts(opts []Option) *options {
1617
optCopy := &options{}
1718
for _, o := range opts {
1819
o(optCopy)
1920
}
2021
return optCopy
2122
}
2223

23-
func evaluateClientOpt(opts []Option) *options {
24-
optCopy := &options{}
25-
for _, o := range opts {
26-
o(optCopy)
27-
}
28-
return optCopy
29-
}
24+
type OnValidationErr func(ctx context.Context, err error)
3025

31-
// WithLogger tells validator to log all the validation errors with the given log level.
32-
func WithLogger(level logging.Level, logger logging.Logger) Option {
26+
// WithOnValidationErrFunc registers function that will be invoked on validation error(s).
27+
func WithOnValidationErrFunc(onValidationErrFunc OnValidationErr) Option {
3328
return func(o *options) {
34-
o.level = level
35-
o.logger = logger
29+
o.onValidationErrFunc = onValidationErrFunc
3630
}
3731
}
3832

3933
// WithFailFast tells validator to immediately stop doing further validation after first validation error.
34+
// This option is ignored if message is only supporting validator.validatorLegacy interface.
4035
func WithFailFast() Option {
4136
return func(o *options) {
4237
o.shouldFailFast = true

interceptors/validator/validator.go

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ package validator
66
import (
77
"context"
88

9-
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
109
"google.golang.org/grpc/codes"
1110
"google.golang.org/grpc/status"
1211
)
@@ -28,51 +27,31 @@ type validatorLegacy interface {
2827
Validate() error
2928
}
3029

31-
func log(level logging.Level, logger logging.Logger, msg string) {
32-
if logger != nil {
33-
// TODO(bwplotka): Fix in separate PR.
34-
logger.Log(context.TODO(), level, msg)
35-
}
36-
}
37-
38-
func validate(req interface{}, shouldFailFast bool, level logging.Level, logger logging.Logger) error {
39-
// shouldFailFast tells validator to immediately stop doing further validation after first validation error.
30+
func validate(ctx context.Context, reqOrRes interface{}, shouldFailFast bool, onValidationErrFunc OnValidationErr) (err error) {
4031
if shouldFailFast {
41-
switch v := req.(type) {
32+
switch v := reqOrRes.(type) {
4233
case validatorLegacy:
43-
if err := v.Validate(); err != nil {
44-
log(level, logger, err.Error())
45-
return status.Error(codes.InvalidArgument, err.Error())
46-
}
34+
err = v.Validate()
4735
case validator:
48-
if err := v.Validate(false); err != nil {
49-
log(level, logger, err.Error())
50-
return status.Error(codes.InvalidArgument, err.Error())
51-
}
36+
err = v.Validate(false)
37+
}
38+
} else {
39+
switch v := reqOrRes.(type) {
40+
case validateAller:
41+
err = v.ValidateAll()
42+
case validator:
43+
err = v.Validate(true)
44+
case validatorLegacy:
45+
err = v.Validate()
5246
}
47+
}
5348

49+
if err == nil {
5450
return nil
5551
}
5652

57-
// shouldNotFailFast tells validator to continue doing further validation even if after a validation error.
58-
switch v := req.(type) {
59-
case validateAller:
60-
if err := v.ValidateAll(); err != nil {
61-
log(level, logger, err.Error())
62-
return status.Error(codes.InvalidArgument, err.Error())
63-
}
64-
case validator:
65-
if err := v.Validate(true); err != nil {
66-
log(level, logger, err.Error())
67-
return status.Error(codes.InvalidArgument, err.Error())
68-
}
69-
case validatorLegacy:
70-
// Fallback to legacy validator
71-
if err := v.Validate(); err != nil {
72-
log(level, logger, err.Error())
73-
return status.Error(codes.InvalidArgument, err.Error())
74-
}
53+
if onValidationErrFunc != nil {
54+
onValidationErrFunc(ctx, err)
7555
}
76-
77-
return nil
56+
return status.Error(codes.InvalidArgument, err.Error())
7857
}

interceptors/validator/validator_test.go

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,25 @@ import (
77
"context"
88
"testing"
99

10-
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
1110
"github.com/grpc-ecosystem/go-grpc-middleware/v2/testing/testpb"
1211
"github.com/stretchr/testify/assert"
1312
)
1413

15-
type TestLogger struct{}
16-
17-
func (l *TestLogger) Log(ctx context.Context, level logging.Level, msg string, fields ...any) {}
18-
1914
func TestValidateWrapper(t *testing.T) {
20-
assert.NoError(t, validate(testpb.GoodPing, false, logging.LevelError, &TestLogger{}))
21-
assert.Error(t, validate(testpb.BadPing, false, logging.LevelError, &TestLogger{}))
22-
assert.NoError(t, validate(testpb.GoodPing, true, logging.LevelError, &TestLogger{}))
23-
assert.Error(t, validate(testpb.BadPing, true, logging.LevelError, &TestLogger{}))
24-
25-
assert.NoError(t, validate(testpb.GoodPingError, false, logging.LevelError, &TestLogger{}))
26-
assert.Error(t, validate(testpb.BadPingError, false, logging.LevelError, &TestLogger{}))
27-
assert.NoError(t, validate(testpb.GoodPingError, true, logging.LevelError, &TestLogger{}))
28-
assert.Error(t, validate(testpb.BadPingError, true, logging.LevelError, &TestLogger{}))
29-
30-
assert.NoError(t, validate(testpb.GoodPingResponse, false, logging.LevelError, &TestLogger{}))
31-
assert.NoError(t, validate(testpb.GoodPingResponse, true, logging.LevelError, &TestLogger{}))
32-
assert.Error(t, validate(testpb.BadPingResponse, false, logging.LevelError, &TestLogger{}))
33-
assert.Error(t, validate(testpb.BadPingResponse, true, logging.LevelError, &TestLogger{}))
15+
ctx := context.Background()
16+
17+
assert.NoError(t, validate(ctx, testpb.GoodPing, false, nil))
18+
assert.Error(t, validate(ctx, testpb.BadPing, false, nil))
19+
assert.NoError(t, validate(ctx, testpb.GoodPing, true, nil))
20+
assert.Error(t, validate(ctx, testpb.BadPing, true, nil))
21+
22+
assert.NoError(t, validate(ctx, testpb.GoodPingError, false, nil))
23+
assert.Error(t, validate(ctx, testpb.BadPingError, false, nil))
24+
assert.NoError(t, validate(ctx, testpb.GoodPingError, true, nil))
25+
assert.Error(t, validate(ctx, testpb.BadPingError, true, nil))
26+
27+
assert.NoError(t, validate(ctx, testpb.GoodPingResponse, false, nil))
28+
assert.NoError(t, validate(ctx, testpb.GoodPingResponse, true, nil))
29+
assert.Error(t, validate(ctx, testpb.BadPingResponse, false, nil))
30+
assert.Error(t, validate(ctx, testpb.BadPingResponse, true, nil))
3431
}

testing/testpb/test.manual_validator.pb.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,15 @@ func (x *PingStreamRequest) Validate(bool) error {
3636
return nil
3737
}
3838

39-
// Implements the legacy validation interface from protoc-gen-validate.
39+
// Validate implements the legacy validation interface from protoc-gen-validate.
4040
func (x *PingResponse) Validate() error {
4141
if x.Counter > math.MaxInt16 {
4242
return errors.New("ping allocation exceeded")
4343
}
4444
return nil
4545
}
4646

47-
// Implements the new ValidateAll interface from protoc-gen-validate.
47+
// ValidateAll implements the new ValidateAll interface from protoc-gen-validate.
4848
func (x *PingResponse) ValidateAll() error {
4949
if x.Counter > math.MaxInt16 {
5050
return errors.New("ping allocation exceeded")

0 commit comments

Comments
 (0)