Skip to content

Commit 70f07bd

Browse files
authored
bugfix: do not report violation on blocked request (#66)
1 parent 0055dd4 commit 70f07bd

File tree

14 files changed

+364
-70
lines changed

14 files changed

+364
-70
lines changed

openapi-validation-core/src/main/java/com/getyourguide/openapi/validation/core/OpenApiRequestValidationConfiguration.java

+1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
public class OpenApiRequestValidationConfiguration {
99
private double sampleRate;
1010
private int validationReportThrottleWaitSeconds;
11+
private boolean shouldFailOnRequestViolation;
1112
}

openapi-validation-core/src/main/java/com/getyourguide/openapi/validation/core/OpenApiRequestValidator.java

+8-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public class OpenApiRequestValidator {
2323
private final ThreadPoolExecutor threadPoolExecutor;
2424
private final OpenApiInteractionValidatorWrapper validator;
2525
private final ValidationReportHandler validationReportHandler;
26+
private final OpenApiRequestValidationConfiguration configuration;
2627

2728
public OpenApiRequestValidator(
2829
ThreadPoolExecutor threadPoolExecutor,
@@ -34,6 +35,7 @@ public OpenApiRequestValidator(
3435
this.threadPoolExecutor = threadPoolExecutor;
3536
this.validator = validator;
3637
this.validationReportHandler = validationReportHandler;
38+
this.configuration = configuration;
3739

3840
metricsReporter.reportStartup(
3941
validator != null,
@@ -74,7 +76,12 @@ public ValidationResult validateRequestObject(
7476
try {
7577
var simpleRequest = buildSimpleRequest(request, requestBody);
7678
var result = validator.validateRequest(simpleRequest);
77-
validationReportHandler.handleValidationReport(request, response, Direction.REQUEST, requestBody, result);
79+
// TODO this should not be done here, but currently the only way to do it -> Refactor this so that logging
80+
// is actually done in the interceptor/filter where logging can easily be skipped then.
81+
if (!configuration.isShouldFailOnRequestViolation()) {
82+
validationReportHandler
83+
.handleValidationReport(request, response, Direction.REQUEST, requestBody, result);
84+
}
7885
return buildValidationResult(result);
7986
} catch (Exception e) {
8087
log.error("Could not validate request", e);

spring-boot-starter/spring-boot-starter-core/src/main/java/com/getyourguide/openapi/validation/OpenApiValidationApplicationProperties.java

+1
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ public OpenApiRequestValidationConfiguration toOpenApiRequestValidationConfigura
8888
return OpenApiRequestValidationConfiguration.builder()
8989
.sampleRate(getSampleRate())
9090
.validationReportThrottleWaitSeconds(getValidationReportThrottleWaitSeconds())
91+
.shouldFailOnRequestViolation(getShouldFailOnRequestViolation() != null && getShouldFailOnRequestViolation())
9192
.build();
9293
}
9394
}
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
package com.getyourguide.openapi.validation.factory;
22

3+
import com.getyourguide.openapi.validation.filter.MultiReadContentCachingRequestWrapper;
34
import jakarta.servlet.http.HttpServletRequest;
45
import jakarta.servlet.http.HttpServletResponse;
56
import javax.annotation.Nullable;
6-
import org.springframework.web.util.ContentCachingRequestWrapper;
77
import org.springframework.web.util.ContentCachingResponseWrapper;
88
import org.springframework.web.util.WebUtils;
99

1010
public class ContentCachingWrapperFactory {
11-
public ContentCachingRequestWrapper buildContentCachingRequestWrapper(HttpServletRequest request) {
12-
if (request instanceof ContentCachingRequestWrapper) {
13-
return (ContentCachingRequestWrapper) request;
11+
public MultiReadContentCachingRequestWrapper buildContentCachingRequestWrapper(HttpServletRequest request) {
12+
if (request instanceof MultiReadContentCachingRequestWrapper) {
13+
return (MultiReadContentCachingRequestWrapper) request;
1414
}
1515

16-
return new ContentCachingRequestWrapper(request);
16+
return new MultiReadContentCachingRequestWrapper(request);
1717
}
1818

1919
public ContentCachingResponseWrapper buildContentCachingResponseWrapper(HttpServletResponse response) {
@@ -26,12 +26,12 @@ public ContentCachingResponseWrapper buildContentCachingResponseWrapper(HttpServ
2626
}
2727

2828
@Nullable
29-
public ContentCachingResponseWrapper getCachingResponse(final HttpServletResponse response) {
30-
return WebUtils.getNativeResponse(response, ContentCachingResponseWrapper.class);
29+
public MultiReadContentCachingRequestWrapper getCachingRequest(HttpServletRequest request) {
30+
return request instanceof MultiReadContentCachingRequestWrapper ? (MultiReadContentCachingRequestWrapper) request : null;
3131
}
3232

3333
@Nullable
34-
public ContentCachingRequestWrapper getCachingRequest(HttpServletRequest request) {
35-
return request instanceof ContentCachingRequestWrapper ? (ContentCachingRequestWrapper) request : null;
34+
public ContentCachingResponseWrapper getCachingResponse(final HttpServletResponse response) {
35+
return WebUtils.getNativeResponse(response, ContentCachingResponseWrapper.class);
3636
}
3737
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package com.getyourguide.openapi.validation.filter;
2+
3+
import jakarta.servlet.ReadListener;
4+
import jakarta.servlet.ServletInputStream;
5+
import jakarta.servlet.http.HttpServletRequest;
6+
import java.io.BufferedReader;
7+
import java.io.ByteArrayInputStream;
8+
import java.io.IOException;
9+
import java.io.InputStreamReader;
10+
import org.springframework.web.util.ContentCachingRequestWrapper;
11+
12+
public class MultiReadContentCachingRequestWrapper extends ContentCachingRequestWrapper {
13+
14+
public MultiReadContentCachingRequestWrapper(HttpServletRequest request) {
15+
super(request);
16+
}
17+
18+
public MultiReadContentCachingRequestWrapper(HttpServletRequest request, int contentCacheLimit) {
19+
super(request, contentCacheLimit);
20+
}
21+
22+
@Override
23+
public ServletInputStream getInputStream() throws IOException {
24+
var inputStream = super.getInputStream();
25+
if (inputStream.isFinished()) {
26+
return new CachedServletInputStream(getContentAsByteArray());
27+
}
28+
29+
return inputStream;
30+
}
31+
32+
@Override
33+
public BufferedReader getReader() throws IOException {
34+
return new BufferedReader(new InputStreamReader(getInputStream()));
35+
}
36+
37+
private static class CachedServletInputStream extends ServletInputStream {
38+
private final ByteArrayInputStream buffer;
39+
40+
public CachedServletInputStream(byte[] contents) {
41+
this.buffer = new ByteArrayInputStream(contents);
42+
}
43+
44+
@Override
45+
public int read() throws IOException {
46+
return buffer.read();
47+
}
48+
49+
@Override
50+
public boolean isFinished() {
51+
return buffer.available() == 0;
52+
}
53+
54+
@Override
55+
public boolean isReady() {
56+
return true;
57+
}
58+
59+
@Override
60+
public void setReadListener(ReadListener listener) {
61+
throw new UnsupportedOperationException("Not implemented");
62+
}
63+
}
64+
}

spring-boot-starter/spring-boot-starter-web/src/main/java/com/getyourguide/openapi/validation/filter/OpenApiValidationInterceptor.java

+14-5
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,16 @@
99
import com.getyourguide.openapi.validation.factory.ServletMetaDataFactory;
1010
import jakarta.servlet.http.HttpServletRequest;
1111
import jakarta.servlet.http.HttpServletResponse;
12+
import java.io.IOException;
1213
import java.nio.charset.StandardCharsets;
1314
import javax.annotation.Nullable;
1415
import lombok.AllArgsConstructor;
1516
import lombok.extern.slf4j.Slf4j;
1617
import org.springframework.http.HttpStatusCode;
18+
import org.springframework.util.StreamUtils;
1719
import org.springframework.web.server.ResponseStatusException;
1820
import org.springframework.web.servlet.AsyncHandlerInterceptor;
1921
import org.springframework.web.servlet.ModelAndView;
20-
import org.springframework.web.util.ContentCachingRequestWrapper;
2122
import org.springframework.web.util.ContentCachingResponseWrapper;
2223

2324
@Slf4j
@@ -114,6 +115,8 @@ private void validateResponse(
114115
);
115116
// Note: validateResponseResult will always be null on ASYNC
116117
if (validateResponseResult == ValidationResult.INVALID) {
118+
response.reset();
119+
response.setStatus(500);
117120
throw new ResponseStatusException(HttpStatusCode.valueOf(500), "Response validation failed");
118121
}
119122
}
@@ -126,7 +129,7 @@ private static RequestMetaData getRequestMetaData(HttpServletRequest request) {
126129
}
127130

128131
private ValidationResult validateRequest(
129-
ContentCachingRequestWrapper request,
132+
MultiReadContentCachingRequestWrapper request,
130133
RequestMetaData requestMetaData,
131134
@Nullable ResponseMetaData responseMetaData,
132135
RunType runType
@@ -137,9 +140,7 @@ private ValidationResult validateRequest(
137140
return ValidationResult.NOT_APPLICABLE;
138141
}
139142

140-
var requestBody = request.getContentType() != null
141-
? new String(request.getContentAsByteArray(), StandardCharsets.UTF_8)
142-
: null;
143+
var requestBody = request.getContentType() != null ? readBodyCatchingException(request) : null;
143144

144145
if (runType == RunType.ASYNC) {
145146
validator.validateRequestObjectAsync(requestMetaData, responseMetaData, requestBody);
@@ -149,6 +150,14 @@ private ValidationResult validateRequest(
149150
}
150151
}
151152

153+
private static String readBodyCatchingException(MultiReadContentCachingRequestWrapper request) {
154+
try {
155+
return StreamUtils.copyToString(request.getInputStream(), StandardCharsets.UTF_8);
156+
} catch (IOException e) {
157+
return null;
158+
}
159+
}
160+
152161
private ValidationResult validateResponse(
153162
HttpServletRequest request,
154163
ContentCachingResponseWrapper response,

spring-boot-starter/spring-boot-starter-web/src/test/java/com/getyourguide/openapi/validation/filter/BaseFilterTest.java

+18-11
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
import jakarta.servlet.ServletResponse;
1717
import jakarta.servlet.http.HttpServletRequest;
1818
import jakarta.servlet.http.HttpServletResponse;
19+
import java.io.ByteArrayInputStream;
20+
import java.io.IOException;
1921
import java.nio.charset.StandardCharsets;
2022
import java.util.HashMap;
2123
import lombok.Builder;
2224
import org.mockito.Mockito;
23-
import org.springframework.web.util.ContentCachingRequestWrapper;
25+
import org.springframework.mock.web.DelegatingServletInputStream;
2426
import org.springframework.web.util.ContentCachingResponseWrapper;
2527

2628
public class BaseFilterTest {
@@ -48,15 +50,12 @@ private static void mockRequestAttributes(ServletRequest request, HashMap<String
4850
}
4951

5052
protected MockSetupData mockSetup(MockConfiguration configuration) {
51-
var request = mock(ContentCachingRequestWrapper.class);
53+
var request = mock(MultiReadContentCachingRequestWrapper.class);
5254
var response = mock(ContentCachingResponseWrapper.class);
5355
var cachingRequest = mockContentCachingRequest(request, configuration);
5456
var cachingResponse = mockContentCachingResponse(response, configuration);
5557
mockRequestAttributes(request, cachingRequest);
5658

57-
when(request.getContentType()).thenReturn("application/json");
58-
when(request.getContentAsByteArray()).thenReturn(configuration.requestBody.getBytes(StandardCharsets.UTF_8));
59-
6059
when(response.getContentType()).thenReturn("application/json");
6160
when(response.getContentAsByteArray()).thenReturn(configuration.responseBody.getBytes(StandardCharsets.UTF_8));
6261

@@ -102,16 +101,24 @@ private ContentCachingResponseWrapper mockContentCachingResponse(
102101
return cachingResponse;
103102
}
104103

105-
private ContentCachingRequestWrapper mockContentCachingRequest(
104+
private MultiReadContentCachingRequestWrapper mockContentCachingRequest(
106105
HttpServletRequest request,
107106
MockConfiguration configuration
108107
) {
109-
var cachingRequest = mock(ContentCachingRequestWrapper.class);
108+
var cachingRequest = mock(MultiReadContentCachingRequestWrapper.class);
110109
when(contentCachingWrapperFactory.buildContentCachingRequestWrapper(request)).thenReturn(cachingRequest);
111-
if (configuration.responseBody != null) {
112-
when(cachingRequest.getContentType()).thenReturn("application/json");
113-
when(cachingRequest.getContentAsByteArray())
114-
.thenReturn(configuration.requestBody.getBytes(StandardCharsets.UTF_8));
110+
if (configuration.requestBody != null) {
111+
try {
112+
var sourceStream = new ByteArrayInputStream(configuration.requestBody.getBytes(StandardCharsets.UTF_8));
113+
when(request.getContentType()).thenReturn("application/json");
114+
when(request.getInputStream()).thenReturn(new DelegatingServletInputStream(sourceStream));
115+
116+
sourceStream = new ByteArrayInputStream(configuration.requestBody.getBytes(StandardCharsets.UTF_8));
117+
when(cachingRequest.getContentType()).thenReturn("application/json");
118+
when(cachingRequest.getInputStream()).thenReturn(new DelegatingServletInputStream(sourceStream));
119+
} catch (IOException e) {
120+
throw new IllegalStateException(e);
121+
}
115122
}
116123
return cachingRequest;
117124
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package com.getyourguide.openapi.validation.integration;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.mockito.ArgumentMatchers.any;
5+
import static org.mockito.Mockito.never;
6+
import static org.mockito.Mockito.verify;
7+
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
8+
import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
9+
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
10+
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
11+
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
12+
13+
import com.getyourguide.openapi.validation.integration.controller.DefaultRestController;
14+
import com.getyourguide.openapi.validation.test.TestViolationLogger;
15+
import java.util.Optional;
16+
import org.hamcrest.Matchers;
17+
import org.junit.jupiter.api.BeforeEach;
18+
import org.junit.jupiter.api.Test;
19+
import org.junit.jupiter.api.extension.ExtendWith;
20+
import org.springframework.beans.factory.annotation.Autowired;
21+
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc;
22+
import org.springframework.boot.test.context.SpringBootTest;
23+
import org.springframework.boot.test.mock.mockito.SpyBean;
24+
import org.springframework.http.MediaType;
25+
import org.springframework.test.context.junit.jupiter.SpringExtension;
26+
import org.springframework.test.web.servlet.MockMvc;
27+
28+
@SpringBootTest(properties = {
29+
"openapi.validation.should-fail-on-request-violation=true",
30+
"openapi.validation.should-fail-on-response-violation=true",
31+
})
32+
@AutoConfigureMockMvc
33+
@ExtendWith(SpringExtension.class)
34+
public class FailOnViolationIntegrationTest {
35+
36+
@Autowired
37+
private MockMvc mockMvc;
38+
39+
@Autowired
40+
private TestViolationLogger openApiViolationLogger;
41+
42+
@SpyBean
43+
private DefaultRestController defaultRestController;
44+
45+
@BeforeEach
46+
public void setup() {
47+
openApiViolationLogger.clearViolations();
48+
}
49+
50+
@Test
51+
public void whenValidRequestThenReturnSuccessfully() throws Exception {
52+
mockMvc.perform(post("/test")
53+
.content("{ \"value\": \"testing\", \"responseStatusCode\": 200 }").contentType(MediaType.APPLICATION_JSON))
54+
.andExpectAll(
55+
status().isOk(),
56+
jsonPath("$.value").value("testing")
57+
);
58+
Thread.sleep(100);
59+
60+
assertEquals(0, openApiViolationLogger.getViolations().size());
61+
verify(defaultRestController).postTest(any());
62+
}
63+
64+
@Test
65+
public void whenInvalidRequestThenReturn400AndNoViolationLogged() throws Exception {
66+
mockMvc.perform(post("/test").content("{ \"value\": 1 }").contentType(MediaType.APPLICATION_JSON))
67+
.andExpectAll(
68+
status().is4xxClientError(),
69+
content().string(Matchers.blankOrNullString())
70+
);
71+
Thread.sleep(100);
72+
73+
assertEquals(0, openApiViolationLogger.getViolations().size());
74+
verify(defaultRestController, never()).postTest(any());
75+
// TODO check that something else gets logged?
76+
}
77+
78+
@Test
79+
public void whenInvalidResponseThenReturn500AndViolationLogged() throws Exception {
80+
mockMvc.perform(get("/test").queryParam("value", "invalid-response-value!"))
81+
.andExpectAll(
82+
status().is5xxServerError(),
83+
content().string(Matchers.blankOrNullString())
84+
);
85+
Thread.sleep(100);
86+
87+
assertEquals(1, openApiViolationLogger.getViolations().size());
88+
var violation = openApiViolationLogger.getViolations().get(0);
89+
assertEquals("validation.response.body.schema.pattern", violation.getRule());
90+
assertEquals(Optional.of(200), violation.getResponseStatus());
91+
assertEquals(Optional.of("/value"), violation.getInstance());
92+
verify(defaultRestController).getTest(any(), any(), any());
93+
}
94+
}

spring-boot-starter/spring-boot-starter-web/src/test/java/com/getyourguide/openapi/validation/integration/OpenApiValidationIntegrationTest.java

-2
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ public void whenTestOptionsCallThenShouldNotValidate() throws Exception {
123123
assertEquals(0, openApiViolationLogger.getViolations().size());
124124
}
125125

126-
// TODO Add test that fails on request violation immediately (maybe needs separate test class & setup) should not log violation
127-
128126
@Nullable
129127
private OpenApiViolation getViolationByRule(List<OpenApiViolation> violations, String rule) {
130128
return violations.stream()

0 commit comments

Comments
 (0)