Skip to content

Commit 872267c

Browse files
Provide Runtime Hints for Beans used in Pre/PostAuthorize Expressions
Closes spring-projectsgh-14652
1 parent ce54a6d commit 872267c

File tree

5 files changed

+450
-0
lines changed

5 files changed

+450
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
/*
2+
* Copyright 2002-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.config.annotation.method.configuration;
18+
19+
import java.lang.annotation.Annotation;
20+
import java.lang.reflect.AnnotatedElement;
21+
import java.lang.reflect.Field;
22+
import java.lang.reflect.Method;
23+
import java.util.ArrayList;
24+
import java.util.Arrays;
25+
import java.util.HashSet;
26+
import java.util.List;
27+
import java.util.Set;
28+
import java.util.function.Function;
29+
import java.util.stream.Collectors;
30+
import java.util.stream.Stream;
31+
32+
import org.apache.commons.logging.Log;
33+
import org.apache.commons.logging.LogFactory;
34+
35+
import org.springframework.aot.generate.GenerationContext;
36+
import org.springframework.aot.hint.MemberCategory;
37+
import org.springframework.aot.hint.RuntimeHints;
38+
import org.springframework.aot.hint.TypeReference;
39+
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
40+
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotContribution;
41+
import org.springframework.beans.factory.aot.BeanFactoryInitializationAotProcessor;
42+
import org.springframework.beans.factory.aot.BeanFactoryInitializationCode;
43+
import org.springframework.beans.factory.config.BeanDefinition;
44+
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
45+
import org.springframework.beans.factory.support.RegisteredBean;
46+
import org.springframework.core.annotation.AnnotationConfigurationException;
47+
import org.springframework.core.annotation.MergedAnnotation;
48+
import org.springframework.core.annotation.MergedAnnotations;
49+
import org.springframework.core.annotation.RepeatableContainers;
50+
import org.springframework.core.log.LogMessage;
51+
import org.springframework.expression.spel.SpelNode;
52+
import org.springframework.expression.spel.ast.BeanReference;
53+
import org.springframework.expression.spel.standard.SpelExpression;
54+
import org.springframework.expression.spel.standard.SpelExpressionParser;
55+
import org.springframework.security.access.prepost.PostAuthorize;
56+
import org.springframework.security.access.prepost.PreAuthorize;
57+
import org.springframework.util.ReflectionUtils;
58+
59+
/**
60+
* AOT BeanFactoryInitializationAotProcessor that detects the presence of
61+
* {@link PreAuthorize} and {@link PostAuthorize} on annotated elements of all registered
62+
* beans and register runtime hints for the beans used within the security expressions.
63+
*
64+
* @author Marcus da Coregio
65+
* @since 6.3
66+
*/
67+
class PrePostAuthorizeBeanFactoryInitializationAotProcessor implements BeanFactoryInitializationAotProcessor {
68+
69+
@Override
70+
public BeanFactoryInitializationAotContribution processAheadOfTime(ConfigurableListableBeanFactory beanFactory) {
71+
Class<?>[] beanTypes = Arrays.stream(beanFactory.getBeanDefinitionNames())
72+
.map((beanName) -> RegisteredBean.of(beanFactory, beanName).getBeanClass())
73+
.toArray(Class<?>[]::new);
74+
return new PrePostAuthorizeContribution(beanTypes, beanFactory);
75+
}
76+
77+
private static class PrePostAuthorizeContribution implements BeanFactoryInitializationAotContribution {
78+
79+
private final Log logger = LogFactory.getLog(getClass());
80+
81+
private final Class<?>[] types;
82+
83+
private final ConfigurableListableBeanFactory beanFactory;
84+
85+
private final SpelExpressionParser expressionParser = new SpelExpressionParser();
86+
87+
PrePostAuthorizeContribution(Class<?>[] types, ConfigurableListableBeanFactory beanFactory) {
88+
this.types = types;
89+
this.beanFactory = beanFactory;
90+
}
91+
92+
@Override
93+
public void applyTo(GenerationContext generationContext,
94+
BeanFactoryInitializationCode beanFactoryInitializationCode) {
95+
List<PreAuthorize> preAuthorizes = new ArrayList<>();
96+
List<PostAuthorize> postAuthorizes = new ArrayList<>();
97+
for (Class<?> type : this.types) {
98+
preAuthorizes.addAll(collectAnnotations(type, PreAuthorize.class));
99+
postAuthorizes.addAll(collectAnnotations(type, PostAuthorize.class));
100+
}
101+
Set<String> expressions = Stream
102+
.concat(preAuthorizes.stream().map(PreAuthorize::value),
103+
postAuthorizes.stream().map(PostAuthorize::value))
104+
.collect(Collectors.toSet());
105+
Set<String> beanNames = new HashSet<>();
106+
for (String expr : expressions) {
107+
beanNames.addAll(extractBeanNames(expr));
108+
}
109+
registerHints(beanNames, generationContext.getRuntimeHints());
110+
}
111+
112+
private void registerHints(Set<String> beanNames, RuntimeHints runtimeHints) {
113+
for (String beanName : beanNames) {
114+
try {
115+
BeanDefinition definition = this.beanFactory.getBeanDefinition(beanName);
116+
runtimeHints.reflection()
117+
.registerType(TypeReference.of(definition.getBeanClassName()),
118+
MemberCategory.INVOKE_DECLARED_METHODS);
119+
}
120+
catch (NoSuchBeanDefinitionException ex) {
121+
this.logger.debug(LogMessage.format(
122+
"""
123+
Could not register runtime hints for bean with name [%s] because it is not available, please provide
124+
the needed hints manually""",
125+
beanName));
126+
}
127+
}
128+
}
129+
130+
private <A extends Annotation> List<A> collectAnnotations(Class<?> type, Class<A> annotationType) {
131+
List<A> annotations = new ArrayList<>();
132+
A classAnnotation = findDistinctAnnotation(type, annotationType, MergedAnnotation::synthesize);
133+
if (classAnnotation != null) {
134+
annotations.add(classAnnotation);
135+
}
136+
for (Method method : type.getDeclaredMethods()) {
137+
A methodAnnotation = findDistinctAnnotation(method, annotationType, MergedAnnotation::synthesize);
138+
if (methodAnnotation != null) {
139+
annotations.add(methodAnnotation);
140+
}
141+
}
142+
return annotations;
143+
}
144+
145+
private Set<String> extractBeanNames(String rawExpression) {
146+
SpelExpression expression = this.expressionParser.parseRaw(rawExpression);
147+
SpelNode node = expression.getAST();
148+
Set<String> beanNames = new HashSet<>();
149+
resolveBeanNames(beanNames, node);
150+
return beanNames;
151+
}
152+
153+
private void resolveBeanNames(Set<String> beanNames, SpelNode node) {
154+
if (node instanceof BeanReference br) {
155+
beanNames.add(resolveBeanName(br));
156+
}
157+
int childCount = node.getChildCount();
158+
if (childCount == 0) {
159+
return;
160+
}
161+
for (int i = 0; i < childCount; i++) {
162+
resolveBeanNames(beanNames, node.getChild(i));
163+
}
164+
}
165+
166+
private String resolveBeanName(BeanReference br) {
167+
try {
168+
Field field = ReflectionUtils.findField(BeanReference.class, "beanName");
169+
field.setAccessible(true);
170+
return (String) field.get(br);
171+
}
172+
catch (IllegalAccessException ex) {
173+
throw new IllegalStateException("Could not resolve beanName for BeanReference [%s]".formatted(br), ex);
174+
}
175+
}
176+
177+
private static <A extends Annotation> A findDistinctAnnotation(AnnotatedElement annotatedElement,
178+
Class<A> annotationType, Function<MergedAnnotation<A>, A> map) {
179+
MergedAnnotations mergedAnnotations = MergedAnnotations.from(annotatedElement,
180+
MergedAnnotations.SearchStrategy.TYPE_HIERARCHY, RepeatableContainers.none());
181+
List<A> annotations = mergedAnnotations.stream(annotationType)
182+
.map(MergedAnnotation::withNonMergedAttributes)
183+
.map(map)
184+
.distinct()
185+
.toList();
186+
187+
return switch (annotations.size()) {
188+
case 0 -> null;
189+
case 1 -> annotations.get(0);
190+
default -> throw new AnnotationConfigurationException("""
191+
Please ensure there is one unique annotation of type @%s attributed to %s. \
192+
Found %d competing annotations: %s""".formatted(annotationType.getName(), annotatedElement,
193+
annotations.size(), annotations));
194+
};
195+
}
196+
197+
}
198+
199+
}

config/src/main/java/org/springframework/security/config/annotation/method/configuration/PrePostMethodSecurityConfiguration.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ static MethodInterceptor postFilterAuthorizationMethodInterceptor(
152152
});
153153
}
154154

155+
@Bean
156+
@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
157+
static PrePostAuthorizeBeanFactoryInitializationAotProcessor prePostAuthorizeBeanFactoryInitializationAotProcessor() {
158+
return new PrePostAuthorizeBeanFactoryInitializationAotProcessor();
159+
}
160+
155161
private static MethodSecurityExpressionHandler defaultExpressionHandler(
156162
ObjectProvider<GrantedAuthorityDefaults> defaultsProvider,
157163
ObjectProvider<RoleHierarchy> roleHierarchyProvider, ApplicationContext context) {

0 commit comments

Comments
 (0)