Skip to content

Commit 2713075

Browse files
committed
Mark Observations with Firewall Failures
Closes gh-11994
1 parent 46ab846 commit 2713075

File tree

4 files changed

+70
-4
lines changed

4 files changed

+70
-4
lines changed

Diff for: config/src/main/java/org/springframework/security/config/annotation/web/builders/WebSecurity.java

+6
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.springframework.security.web.access.intercept.FilterSecurityInterceptor;
5858
import org.springframework.security.web.debug.DebugFilter;
5959
import org.springframework.security.web.firewall.HttpFirewall;
60+
import org.springframework.security.web.firewall.ObservationMarkingRequestRejectedHandler;
6061
import org.springframework.security.web.firewall.RequestRejectedHandler;
6162
import org.springframework.security.web.firewall.StrictHttpFirewall;
6263
import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -307,6 +308,10 @@ protected Filter performBuild() throws Exception {
307308
if (this.requestRejectedHandler != null) {
308309
filterChainProxy.setRequestRejectedHandler(this.requestRejectedHandler);
309310
}
311+
else if (!this.observationRegistry.isNoop()) {
312+
filterChainProxy
313+
.setRequestRejectedHandler(new ObservationMarkingRequestRejectedHandler(this.observationRegistry));
314+
}
310315
filterChainProxy.setFilterChainDecorator(getFilterChainDecorator());
311316
filterChainProxy.afterPropertiesSet();
312317

@@ -319,6 +324,7 @@ protected Filter performBuild() throws Exception {
319324
+ "********************************************************************\n\n");
320325
result = new DebugFilter(filterChainProxy);
321326
}
327+
322328
this.postBuildAction.run();
323329
return result;
324330
}

Diff for: config/src/main/java/org/springframework/security/config/http/HttpFirewallBeanDefinitionParser.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public BeanDefinition parse(Element element, ParserContext pc) {
4040
pc.getReaderContext().error("ref attribute is required", pc.extractSource(element));
4141
}
4242
// Ensure the FCP is registered.
43-
HttpSecurityBeanDefinitionParser.registerFilterChainProxyIfNecessary(pc, pc.extractSource(element));
43+
HttpSecurityBeanDefinitionParser.registerFilterChainProxyIfNecessary(pc, element);
4444
BeanDefinition filterChainProxy = pc.getRegistry().getBeanDefinition(BeanIds.FILTER_CHAIN_PROXY);
4545
filterChainProxy.getPropertyValues().addPropertyValue("firewall", new RuntimeBeanReference(ref));
4646
return null;

Diff for: config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java

+19-3
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.springframework.security.web.FilterChainProxy;
5959
import org.springframework.security.web.ObservationFilterChainDecorator;
6060
import org.springframework.security.web.PortResolverImpl;
61+
import org.springframework.security.web.firewall.ObservationMarkingRequestRejectedHandler;
6162
import org.springframework.security.web.util.matcher.AnyRequestMatcher;
6263
import org.springframework.util.StringUtils;
6364
import org.springframework.util.xml.DomUtils;
@@ -120,7 +121,7 @@ public BeanDefinition parse(Element element, ParserContext pc) {
120121
CompositeComponentDefinition compositeDef = new CompositeComponentDefinition(element.getTagName(),
121122
pc.extractSource(element));
122123
pc.pushContainingComponent(compositeDef);
123-
registerFilterChainProxyIfNecessary(pc, pc.extractSource(element));
124+
registerFilterChainProxyIfNecessary(pc, element);
124125
// Obtain the filter chains and add the new chain to it
125126
BeanDefinition listFactoryBean = pc.getRegistry().getBeanDefinition(BeanIds.FILTER_CHAINS);
126127
List<BeanReference> filterChains = (List<BeanReference>) listFactoryBean.getPropertyValues()
@@ -351,7 +352,8 @@ else if (StringUtils.hasText(before)) {
351352
return customFilters;
352353
}
353354

354-
static void registerFilterChainProxyIfNecessary(ParserContext pc, Object source) {
355+
static void registerFilterChainProxyIfNecessary(ParserContext pc, Element element) {
356+
Object source = pc.extractSource(element);
355357
BeanDefinitionRegistry registry = pc.getRegistry();
356358
if (registry.containsBeanDefinition(BeanIds.FILTER_CHAIN_PROXY)) {
357359
return;
@@ -378,6 +380,7 @@ static void registerFilterChainProxyIfNecessary(ParserContext pc, Object source)
378380
requestRejected.addConstructorArgValue("requestRejectedHandler");
379381
requestRejected.addConstructorArgValue(BeanIds.FILTER_CHAIN_PROXY);
380382
requestRejected.addConstructorArgValue("requestRejectedHandler");
383+
requestRejected.addPropertyValue("observationRegistry", getObservationRegistry(element));
381384
AbstractBeanDefinition requestRejectedBean = requestRejected.getBeanDefinition();
382385
String requestRejectedPostProcessorName = pc.getReaderContext().generateBeanName(requestRejectedBean);
383386
registry.registerBeanDefinition(requestRejectedPostProcessorName, requestRejectedBean);
@@ -391,14 +394,16 @@ private static BeanMetadataElement getObservationRegistry(Element methodSecurity
391394
return BeanDefinitionBuilder.rootBeanDefinition(ObservationRegistryFactory.class).getBeanDefinition();
392395
}
393396

394-
static class RequestRejectedHandlerPostProcessor implements BeanDefinitionRegistryPostProcessor {
397+
public static class RequestRejectedHandlerPostProcessor implements BeanDefinitionRegistryPostProcessor {
395398

396399
private final String beanName;
397400

398401
private final String targetBeanName;
399402

400403
private final String targetPropertyName;
401404

405+
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
406+
402407
RequestRejectedHandlerPostProcessor(String beanName, String targetBeanName, String targetPropertyName) {
403408
this.beanName = beanName;
404409
this.targetBeanName = targetBeanName;
@@ -412,13 +417,24 @@ public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) t
412417
beanDefinition.getPropertyValues().add(this.targetPropertyName,
413418
new RuntimeBeanReference(this.beanName));
414419
}
420+
else if (!this.observationRegistry.isNoop()) {
421+
BeanDefinition observable = BeanDefinitionBuilder
422+
.rootBeanDefinition(ObservationMarkingRequestRejectedHandler.class)
423+
.addConstructorArgValue(this.observationRegistry).getBeanDefinition();
424+
BeanDefinition beanDefinition = registry.getBeanDefinition(this.targetBeanName);
425+
beanDefinition.getPropertyValues().add(this.targetPropertyName, observable);
426+
}
415427
}
416428

417429
@Override
418430
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
419431

420432
}
421433

434+
public void setObservationRegistry(ObservationRegistry registry) {
435+
this.observationRegistry = registry;
436+
}
437+
422438
}
423439

424440
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.web.firewall;
18+
19+
import java.io.IOException;
20+
21+
import io.micrometer.observation.Observation;
22+
import io.micrometer.observation.ObservationRegistry;
23+
import jakarta.servlet.ServletException;
24+
import jakarta.servlet.http.HttpServletRequest;
25+
import jakarta.servlet.http.HttpServletResponse;
26+
27+
public final class ObservationMarkingRequestRejectedHandler implements RequestRejectedHandler {
28+
29+
private final ObservationRegistry registry;
30+
31+
public ObservationMarkingRequestRejectedHandler(ObservationRegistry registry) {
32+
this.registry = registry;
33+
}
34+
35+
@Override
36+
public void handle(HttpServletRequest request, HttpServletResponse response, RequestRejectedException exception)
37+
throws IOException, ServletException {
38+
Observation observation = this.registry.getCurrentObservation();
39+
if (observation != null) {
40+
observation.error(exception);
41+
}
42+
}
43+
44+
}

0 commit comments

Comments
 (0)