Skip to content

Commit 5f5b81c

Browse files
committed
add grpc request validator
1 parent 5edaef0 commit 5f5b81c

File tree

8 files changed

+278
-0
lines changed

8 files changed

+278
-0
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package com.linecorp.armeria.server.grpc.validation;
2+
3+
import io.grpc.Metadata;
4+
import io.grpc.ServerCall;
5+
import io.grpc.ServerCallHandler;
6+
import io.grpc.ServerInterceptor;
7+
8+
public class RequestValidationInterceptor implements ServerInterceptor {
9+
10+
private RequestValidatorResolver requestValidatorResolver;
11+
12+
public RequestValidationInterceptor(RequestValidatorResolver requestValidatorResolver) {
13+
this.requestValidatorResolver = requestValidatorResolver;
14+
}
15+
16+
@Override
17+
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
18+
ServerCall<ReqT, RespT> call,
19+
Metadata headers,
20+
ServerCallHandler<ReqT, RespT> next
21+
) {
22+
ServerCall.Listener<ReqT> delegate = next.startCall(call, headers);
23+
24+
return new RequestValidationListener<>(delegate, call, headers, requestValidatorResolver);
25+
}
26+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package com.linecorp.armeria.server.grpc.validation;
2+
3+
import com.google.protobuf.MessageLiteOrBuilder;
4+
import io.grpc.ForwardingServerCallListener;
5+
import io.grpc.Metadata;
6+
import io.grpc.ServerCall;
7+
import io.grpc.Status;
8+
9+
public class RequestValidationListener<ReqT, ResT> extends ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT> {
10+
11+
private ServerCall<ReqT, ResT> serverCall;
12+
private Metadata headers;
13+
private RequestValidatorResolver requestValidatorResolver;
14+
15+
public RequestValidationListener(
16+
ServerCall.Listener<ReqT> delegate,
17+
ServerCall<ReqT, ResT> serverCall,
18+
Metadata headers,
19+
RequestValidatorResolver requestValidatorResolver
20+
) {
21+
super(delegate);
22+
this.serverCall = serverCall;
23+
this.headers = headers;
24+
this.requestValidatorResolver = requestValidatorResolver;
25+
}
26+
27+
@Override
28+
public void onMessage(ReqT message) {
29+
MessageLiteOrBuilder convertMessage = (MessageLiteOrBuilder) message;
30+
RequestValidator<MessageLiteOrBuilder> validator = requestValidatorResolver.find(convertMessage.getClass().getTypeName());
31+
32+
if (validator == null) {
33+
super.onMessage(message);
34+
} else {
35+
try {
36+
ValidationResult validationResult = validator.isValid(convertMessage);
37+
38+
if (validationResult.isValid()) {
39+
super.onMessage(message);
40+
} else {
41+
Status status = Status.INVALID_ARGUMENT
42+
.withDescription("invalid argument. " + validationResult.getMessage());
43+
handleInvalidRequest(status);
44+
}
45+
} catch (Exception e) {
46+
Status status = Status.INTERNAL.withDescription(e.getMessage());
47+
48+
handleInvalidRequest(status);
49+
}
50+
}
51+
}
52+
53+
private void handleInvalidRequest(Status status) {
54+
if (!serverCall.isCancelled()) {
55+
serverCall.close(status, headers);
56+
}
57+
}
58+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package com.linecorp.armeria.server.grpc.validation;
2+
3+
import com.google.protobuf.MessageLiteOrBuilder;
4+
5+
interface RequestValidator<T extends MessageLiteOrBuilder> {
6+
ValidationResult isValid(T request);
7+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package com.linecorp.armeria.server.grpc.validation;
2+
3+
import com.google.protobuf.MessageLiteOrBuilder;
4+
5+
import java.lang.reflect.ParameterizedType;
6+
import java.lang.reflect.Type;
7+
import java.util.List;
8+
import java.util.Map;
9+
import java.util.stream.Collectors;
10+
11+
public class RequestValidatorResolver {
12+
13+
private List<RequestValidator<MessageLiteOrBuilder>> validators;
14+
15+
private Map<String, RequestValidator<MessageLiteOrBuilder>> validatorMap;
16+
17+
public RequestValidatorResolver(List<RequestValidator<MessageLiteOrBuilder>> validators) {
18+
this.validators = validators;
19+
20+
validatorMap = validators.stream()
21+
.collect(Collectors.toMap(this::getClassName, it -> it));
22+
}
23+
24+
private String getClassName(RequestValidator<MessageLiteOrBuilder> it) {
25+
Type[] genericInterfaces = it.getClass().getGenericInterfaces();
26+
27+
if (genericInterfaces.length == 0) {
28+
return null;
29+
}
30+
31+
return ((ParameterizedType) genericInterfaces[0]).getActualTypeArguments()[0].getTypeName();
32+
}
33+
34+
public RequestValidator<MessageLiteOrBuilder> find(String typeName) {
35+
return validatorMap.get(typeName);
36+
}
37+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package com.linecorp.armeria.server.grpc.validation;
2+
3+
public class ValidationResult {
4+
5+
private boolean isValid;
6+
7+
private String message;
8+
9+
public ValidationResult(boolean isValid, String message) {
10+
this.isValid = isValid;
11+
this.message = message;
12+
}
13+
14+
public boolean isValid() {
15+
return isValid;
16+
}
17+
18+
public String getMessage() {
19+
return message;
20+
}
21+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package com.linecorp.armeria.internal.common.grpc;
2+
3+
import com.linecorp.armeria.grpc.testing.Hello;
4+
import com.linecorp.armeria.grpc.testing.HelloServiceGrpc;
5+
import io.grpc.stub.StreamObserver;
6+
7+
public class HeloServiceImpl extends HelloServiceGrpc.HelloServiceImplBase {
8+
9+
@Override
10+
public void hello(Hello.HelloRequest request, StreamObserver<Hello.HelloResponse> responseObserver) {
11+
Hello.HelloResponse response = Hello.HelloResponse.newBuilder()
12+
.setMessage("success")
13+
.build();
14+
15+
responseObserver.onNext(response);
16+
responseObserver.onCompleted();
17+
}
18+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package com.linecorp.armeria.server.grpc.validation;
2+
3+
import com.google.protobuf.MessageLiteOrBuilder;
4+
import com.linecorp.armeria.client.grpc.GrpcClients;
5+
import com.linecorp.armeria.grpc.testing.Hello;
6+
import com.linecorp.armeria.grpc.testing.HelloServiceGrpc.HelloServiceBlockingStub;
7+
import com.linecorp.armeria.internal.common.grpc.HeloServiceImpl;
8+
import com.linecorp.armeria.server.ServerBuilder;
9+
import com.linecorp.armeria.server.grpc.GrpcService;
10+
import com.linecorp.armeria.testing.junit5.server.ServerExtension;
11+
import io.grpc.Status;
12+
import io.grpc.StatusRuntimeException;
13+
import org.junit.jupiter.api.Test;
14+
import org.junit.jupiter.api.extension.RegisterExtension;
15+
16+
import java.util.ArrayList;
17+
import java.util.List;
18+
19+
import static org.assertj.core.api.Assertions.assertThat;
20+
import static org.assertj.core.api.Assertions.catchThrowable;
21+
22+
class RequestValidationInterceptorTest {
23+
24+
static String ERROR_MESSAGE = "invalid argument";
25+
26+
@RegisterExtension
27+
static ServerExtension server = new ServerExtension() {
28+
@Override
29+
protected void configure(ServerBuilder sb) {
30+
List<RequestValidator<MessageLiteOrBuilder>> validators = new ArrayList<>();
31+
32+
validators.add((RequestValidator) new HelloRequestValidator());
33+
34+
RequestValidatorResolver requestValidatorResolver = new RequestValidatorResolver(validators);
35+
sb.service(GrpcService.builder()
36+
.addService(new HeloServiceImpl())
37+
.intercept(new RequestValidationInterceptor(requestValidatorResolver))
38+
.build());
39+
}
40+
};
41+
42+
@Test
43+
void validation_fail_test() {
44+
HelloServiceBlockingStub client = GrpcClients.builder(server.httpUri())
45+
.build(HelloServiceBlockingStub.class);
46+
47+
final Throwable cause = catchThrowable(() -> client.hello(Hello.HelloRequest.getDefaultInstance()));
48+
assertThat(cause).isInstanceOf(StatusRuntimeException.class);
49+
assertThat(((StatusRuntimeException) cause).getStatus().getCode()).isEqualTo(Status.INVALID_ARGUMENT.getCode());
50+
}
51+
52+
@Test
53+
void validation_success_test() {
54+
HelloServiceBlockingStub client = GrpcClients.builder(server.httpUri())
55+
.build(HelloServiceBlockingStub.class);
56+
57+
Hello.HelloResponse response = client.hello(
58+
Hello.HelloRequest.newBuilder()
59+
.setMessage("success")
60+
.build()
61+
);
62+
63+
assertThat(response.getMessage()).isEqualTo("success");
64+
}
65+
66+
private static class HelloRequestValidator implements RequestValidator<Hello.HelloRequest> {
67+
68+
@Override
69+
public ValidationResult isValid(Hello.HelloRequest request) {
70+
if (request.getMessage().equals("success")) {
71+
return new ValidationResult(true, null);
72+
}
73+
74+
return new ValidationResult(false, ERROR_MESSAGE);
75+
}
76+
}
77+
78+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright 2022 LINE Corporation
2+
//
3+
// LINE Corporation licenses this file to you under the Apache License,
4+
// version 2.0 (the "License"); you may not use this file except in compliance
5+
// with the License. You may obtain a copy of the License at:
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
// License for the specific language governing permissions and limitations
13+
// under the License.
14+
15+
syntax = "proto3";
16+
17+
package armeria.grpc.testing;
18+
19+
option java_package = "com.linecorp.armeria.grpc.testing";
20+
21+
import "google/api/annotations.proto";
22+
23+
service HelloService {
24+
rpc hello (HelloRequest) returns (HelloResponse) {}
25+
}
26+
27+
message HelloRequest {
28+
string message = 1;
29+
}
30+
31+
message HelloResponse {
32+
string message = 1;
33+
}

0 commit comments

Comments
 (0)