Skip to content

Commit 2f3cf56

Browse files
committed
Fix servlet component scanning in a mock web environment
Closes gh-39736
1 parent b961662 commit 2f3cf56

File tree

2 files changed

+149
-3
lines changed

2 files changed

+149
-3
lines changed

spring-boot-project/spring-boot/src/main/java/org/springframework/boot/web/servlet/ServletComponentRegisteringPostProcessor.java

+9-3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
import org.springframework.context.ApplicationContext;
3838
import org.springframework.context.ApplicationContextAware;
3939
import org.springframework.context.annotation.ClassPathScanningCandidateComponentProvider;
40+
import org.springframework.mock.web.MockServletContext;
41+
import org.springframework.util.ClassUtils;
4042
import org.springframework.web.context.WebApplicationContext;
4143

4244
/**
@@ -50,6 +52,9 @@
5052
class ServletComponentRegisteringPostProcessor
5153
implements BeanFactoryPostProcessor, ApplicationContextAware, BeanFactoryInitializationAotProcessor {
5254

55+
private static final boolean MOCK_SERVLET_CONTEXT_AVAILABLE = ClassUtils
56+
.isPresent("org.springframework.mock.web.MockServletContext", null);
57+
5358
private static final List<ServletComponentHandler> HANDLERS;
5459

5560
static {
@@ -70,7 +75,7 @@ class ServletComponentRegisteringPostProcessor
7075

7176
@Override
7277
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
73-
if (isRunningInEmbeddedWebServer()) {
78+
if (eligibleForServletComponentScanning()) {
7479
ClassPathScanningCandidateComponentProvider componentProvider = createComponentProvider();
7580
for (String packageToScan : this.packagesToScan) {
7681
scanPackage(componentProvider, packageToScan);
@@ -88,9 +93,10 @@ private void scanPackage(ClassPathScanningCandidateComponentProvider componentPr
8893
}
8994
}
9095

91-
private boolean isRunningInEmbeddedWebServer() {
96+
private boolean eligibleForServletComponentScanning() {
9297
return this.applicationContext instanceof WebApplicationContext webApplicationContext
93-
&& webApplicationContext.getServletContext() == null;
98+
&& (webApplicationContext.getServletContext() == null || (MOCK_SERVLET_CONTEXT_AVAILABLE
99+
&& webApplicationContext.getServletContext() instanceof MockServletContext));
94100
}
95101

96102
private ClassPathScanningCandidateComponentProvider createComponentProvider() {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
/*
2+
* Copyright 2012-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.boot.web.servlet;
18+
19+
import java.io.File;
20+
import java.io.FileWriter;
21+
import java.io.IOException;
22+
import java.net.URL;
23+
import java.net.URLClassLoader;
24+
import java.util.Map;
25+
import java.util.Properties;
26+
27+
import jakarta.servlet.MultipartConfigElement;
28+
import jakarta.servlet.annotation.WebFilter;
29+
import jakarta.servlet.annotation.WebListener;
30+
import jakarta.servlet.annotation.WebServlet;
31+
import org.junit.jupiter.api.AfterEach;
32+
import org.junit.jupiter.api.Test;
33+
import org.junit.jupiter.api.io.TempDir;
34+
35+
import org.springframework.boot.testsupport.classpath.ForkedClassPath;
36+
import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebApplicationContext;
37+
import org.springframework.boot.web.servlet.testcomponents.filter.TestFilter;
38+
import org.springframework.boot.web.servlet.testcomponents.listener.TestListener;
39+
import org.springframework.boot.web.servlet.testcomponents.servlet.TestMultipartServlet;
40+
import org.springframework.boot.web.servlet.testcomponents.servlet.TestServlet;
41+
import org.springframework.mock.web.MockServletContext;
42+
43+
import static org.assertj.core.api.Assertions.assertThat;
44+
import static org.mockito.BDDMockito.then;
45+
import static org.mockito.Mockito.mock;
46+
47+
/**
48+
* Integration tests for {@link ServletComponentScan @ServletComponentScan} with a mock
49+
* web environment.
50+
*
51+
* @author Andy Wilkinson
52+
*/
53+
class MockWebEnvironmentServletComponentScanIntegrationTests {
54+
55+
private AnnotationConfigServletWebApplicationContext context;
56+
57+
@TempDir
58+
File temp;
59+
60+
@AfterEach
61+
void cleanUp() {
62+
if (this.context != null) {
63+
this.context.close();
64+
}
65+
}
66+
67+
@Test
68+
@ForkedClassPath
69+
void componentsAreRegistered() {
70+
prepareContext();
71+
this.context.refresh();
72+
Map<String, RegistrationBean> registrationBeans = this.context.getBeansOfType(RegistrationBean.class);
73+
assertThat(registrationBeans).hasSize(3);
74+
assertThat(registrationBeans.keySet()).containsExactlyInAnyOrder(TestServlet.class.getName(),
75+
TestFilter.class.getName(), TestMultipartServlet.class.getName());
76+
WebListenerRegistry registry = mock(WebListenerRegistry.class);
77+
this.context.getBean(WebListenerRegistrar.class).register(registry);
78+
then(registry).should().addWebListeners(TestListener.class.getName());
79+
}
80+
81+
@Test
82+
@ForkedClassPath
83+
void indexedComponentsAreRegistered() throws IOException {
84+
writeIndex(this.temp);
85+
prepareContext();
86+
try (URLClassLoader classLoader = new URLClassLoader(new URL[] { this.temp.toURI().toURL() },
87+
getClass().getClassLoader())) {
88+
this.context.setClassLoader(classLoader);
89+
this.context.refresh();
90+
Map<String, RegistrationBean> registrationBeans = this.context.getBeansOfType(RegistrationBean.class);
91+
assertThat(registrationBeans).hasSize(2);
92+
assertThat(registrationBeans.keySet()).containsExactlyInAnyOrder(TestServlet.class.getName(),
93+
TestFilter.class.getName());
94+
WebListenerRegistry registry = mock(WebListenerRegistry.class);
95+
this.context.getBean(WebListenerRegistrar.class).register(registry);
96+
then(registry).should().addWebListeners(TestListener.class.getName());
97+
}
98+
}
99+
100+
@Test
101+
@ForkedClassPath
102+
void multipartConfigIsHonoured() {
103+
prepareContext();
104+
this.context.refresh();
105+
@SuppressWarnings("rawtypes")
106+
Map<String, ServletRegistrationBean> beans = this.context.getBeansOfType(ServletRegistrationBean.class);
107+
ServletRegistrationBean<?> servletRegistrationBean = beans.get(TestMultipartServlet.class.getName());
108+
assertThat(servletRegistrationBean).isNotNull();
109+
MultipartConfigElement multipartConfig = servletRegistrationBean.getMultipartConfig();
110+
assertThat(multipartConfig).isNotNull();
111+
assertThat(multipartConfig.getLocation()).isEqualTo("test");
112+
assertThat(multipartConfig.getMaxRequestSize()).isEqualTo(2048);
113+
assertThat(multipartConfig.getMaxFileSize()).isEqualTo(1024);
114+
assertThat(multipartConfig.getFileSizeThreshold()).isEqualTo(512);
115+
}
116+
117+
private void writeIndex(File temp) throws IOException {
118+
File metaInf = new File(temp, "META-INF");
119+
metaInf.mkdirs();
120+
Properties index = new Properties();
121+
index.setProperty(TestFilter.class.getName(), WebFilter.class.getName());
122+
index.setProperty(TestListener.class.getName(), WebListener.class.getName());
123+
index.setProperty(TestServlet.class.getName(), WebServlet.class.getName());
124+
try (FileWriter writer = new FileWriter(new File(metaInf, "spring.components"))) {
125+
index.store(writer, null);
126+
}
127+
}
128+
129+
private void prepareContext() {
130+
this.context = new AnnotationConfigServletWebApplicationContext();
131+
this.context.register(ScanningConfiguration.class);
132+
this.context.setServletContext(new MockServletContext());
133+
}
134+
135+
@ServletComponentScan(basePackages = "org.springframework.boot.web.servlet.testcomponents")
136+
static class ScanningConfiguration {
137+
138+
}
139+
140+
}

0 commit comments

Comments
 (0)