Skip to content

Commit dbf9b50

Browse files
author
Steve Riesenberg
committed
Simplify OAuth2 Client configuration
Issue gh-11783
1 parent af14d76 commit dbf9b50

11 files changed

+1555
-206
lines changed

Diff for: config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java

+208-133
Large diffs are not rendered by default.

Diff for: config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java

+18-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.security.config.annotation.web.configurers.oauth2.client;
1818

19+
import org.springframework.context.ApplicationContext;
20+
import org.springframework.core.ResolvableType;
1921
import org.springframework.security.authentication.AuthenticationManager;
2022
import org.springframework.security.config.Customizer;
2123
import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
@@ -307,7 +309,22 @@ private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> get
307309
if (this.accessTokenResponseClient != null) {
308310
return this.accessTokenResponseClient;
309311
}
310-
return new DefaultAuthorizationCodeTokenResponseClient();
312+
ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
313+
OAuth2AuthorizationCodeGrantRequest.class);
314+
OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> bean = getBeanOrNull(resolvableType);
315+
return (bean != null) ? bean : new DefaultAuthorizationCodeTokenResponseClient();
316+
}
317+
318+
@SuppressWarnings("unchecked")
319+
private <T> T getBeanOrNull(ResolvableType type) {
320+
ApplicationContext context = getBuilder().getSharedObject(ApplicationContext.class);
321+
if (context != null) {
322+
String[] names = context.getBeanNamesForType(type);
323+
if (names.length == 1) {
324+
return (T) context.getBean(names[0]);
325+
}
326+
}
327+
return null;
311328
}
312329

313330
}

Diff for: config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java

+11-4
Original file line numberDiff line numberDiff line change
@@ -330,10 +330,7 @@ public void init(B http) throws Exception {
330330
super.init(http);
331331
}
332332
}
333-
OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient = this.tokenEndpointConfig.accessTokenResponseClient;
334-
if (accessTokenResponseClient == null) {
335-
accessTokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient();
336-
}
333+
OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient = getAccessTokenResponseClient();
337334
OAuth2UserService<OAuth2UserRequest, OAuth2User> oauth2UserService = getOAuth2UserService();
338335
OAuth2LoginAuthenticationProvider oauth2LoginAuthenticationProvider = new OAuth2LoginAuthenticationProvider(
339336
accessTokenResponseClient, oauth2UserService);
@@ -441,6 +438,16 @@ private GrantedAuthoritiesMapper getGrantedAuthoritiesMapperBean() {
441438
return (!grantedAuthoritiesMapperMap.isEmpty() ? grantedAuthoritiesMapperMap.values().iterator().next() : null);
442439
}
443440

441+
private OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> getAccessTokenResponseClient() {
442+
if (this.tokenEndpointConfig.accessTokenResponseClient != null) {
443+
return this.tokenEndpointConfig.accessTokenResponseClient;
444+
}
445+
ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
446+
OAuth2AuthorizationCodeGrantRequest.class);
447+
OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> bean = getBeanOrNull(resolvableType);
448+
return (bean != null) ? bean : new DefaultAuthorizationCodeTokenResponseClient();
449+
}
450+
444451
private OAuth2UserService<OidcUserRequest, OidcUser> getOidcUserService() {
445452
if (this.userInfoEndpointConfig.oidcUserService != null) {
446453
return this.userInfoEndpointConfig.oidcUserService;

Diff for: config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -420,6 +420,8 @@ private void registerOAuth2ClientPostProcessors() {
420420
this.pc.getReaderContext()
421421
.registerWithGeneratedName(new RootBeanDefinition(OAuth2ClientWebMvcSecurityPostProcessor.class));
422422
}
423+
this.pc.getReaderContext()
424+
.registerWithGeneratedName(new RootBeanDefinition(OAuth2AuthorizedClientManagerRegistrar.class));
423425
}
424426

425427
private void createSaml2LoginFilter(BeanReference authenticationManager,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
/*
2+
* Copyright 2002-2023 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.config.http;
18+
19+
import java.util.ArrayList;
20+
import java.util.Collection;
21+
import java.util.List;
22+
import java.util.Set;
23+
import java.util.function.Consumer;
24+
25+
import org.springframework.beans.BeansException;
26+
import org.springframework.beans.factory.BeanFactory;
27+
import org.springframework.beans.factory.BeanFactoryAware;
28+
import org.springframework.beans.factory.BeanFactoryUtils;
29+
import org.springframework.beans.factory.BeanInitializationException;
30+
import org.springframework.beans.factory.ListableBeanFactory;
31+
import org.springframework.beans.factory.ObjectProvider;
32+
import org.springframework.beans.factory.config.BeanDefinition;
33+
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
34+
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
35+
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
36+
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
37+
import org.springframework.context.annotation.AnnotationBeanNameGenerator;
38+
import org.springframework.core.ResolvableType;
39+
import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider;
40+
import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
41+
import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider;
42+
import org.springframework.security.oauth2.client.JwtBearerOAuth2AuthorizedClientProvider;
43+
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
44+
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
45+
import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
46+
import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
47+
import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
48+
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
49+
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
50+
import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
51+
import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
52+
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
53+
import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
54+
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
55+
56+
/**
57+
* A registrar for registering the default {@link OAuth2AuthorizedClientManager} bean
58+
* definition, if not already present.
59+
* <p>
60+
* Note: This class is a direct copy of
61+
* {@link org.springframework.security.config.annotation.web.configuration.OAuth2ClientConfiguration.OAuth2AuthorizedClientManagerRegistrar}.
62+
*
63+
* @author Joe Grandja
64+
* @author Steve Riesenberg
65+
* @since 6.2.0
66+
*/
67+
final class OAuth2AuthorizedClientManagerRegistrar implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware {
68+
69+
// @formatter:off
70+
private static final Set<Class<?>> KNOWN_AUTHORIZED_CLIENT_PROVIDERS = Set.of(
71+
AuthorizationCodeOAuth2AuthorizedClientProvider.class,
72+
RefreshTokenOAuth2AuthorizedClientProvider.class,
73+
ClientCredentialsOAuth2AuthorizedClientProvider.class,
74+
PasswordOAuth2AuthorizedClientProvider.class,
75+
JwtBearerOAuth2AuthorizedClientProvider.class
76+
);
77+
// @formatter:on
78+
79+
private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator();
80+
81+
private ListableBeanFactory beanFactory;
82+
83+
@Override
84+
public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
85+
if (getBeanNamesForType(OAuth2AuthorizedClientManager.class).length != 0
86+
|| getBeanNamesForType(ClientRegistrationRepository.class).length != 1
87+
|| getBeanNamesForType(OAuth2AuthorizedClientRepository.class).length != 1) {
88+
return;
89+
}
90+
91+
BeanDefinition beanDefinition = BeanDefinitionBuilder
92+
.genericBeanDefinition(OAuth2AuthorizedClientManager.class, this::getAuthorizedClientManager)
93+
.getBeanDefinition();
94+
95+
registry.registerBeanDefinition(this.beanNameGenerator.generateBeanName(beanDefinition, registry),
96+
beanDefinition);
97+
}
98+
99+
@Override
100+
public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
101+
}
102+
103+
@Override
104+
public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
105+
this.beanFactory = (ListableBeanFactory) beanFactory;
106+
}
107+
108+
private OAuth2AuthorizedClientManager getAuthorizedClientManager() {
109+
ClientRegistrationRepository clientRegistrationRepository = BeanFactoryUtils
110+
.beanOfTypeIncludingAncestors(this.beanFactory, ClientRegistrationRepository.class, true, true);
111+
112+
OAuth2AuthorizedClientRepository authorizedClientRepository = BeanFactoryUtils
113+
.beanOfTypeIncludingAncestors(this.beanFactory, OAuth2AuthorizedClientRepository.class, true, true);
114+
115+
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviderBeans = BeanFactoryUtils
116+
.beansOfTypeIncludingAncestors(this.beanFactory, OAuth2AuthorizedClientProvider.class, true, true)
117+
.values();
118+
119+
OAuth2AuthorizedClientProvider authorizedClientProvider;
120+
if (hasDelegatingAuthorizedClientProvider(authorizedClientProviderBeans)) {
121+
authorizedClientProvider = authorizedClientProviderBeans.iterator().next();
122+
}
123+
else {
124+
List<OAuth2AuthorizedClientProvider> authorizedClientProviders = new ArrayList<>();
125+
authorizedClientProviders.add(getAuthorizationCodeAuthorizedClientProvider(authorizedClientProviderBeans));
126+
authorizedClientProviders.add(getRefreshTokenAuthorizedClientProvider(authorizedClientProviderBeans));
127+
authorizedClientProviders.add(getClientCredentialsAuthorizedClientProvider(authorizedClientProviderBeans));
128+
authorizedClientProviders.add(getPasswordAuthorizedClientProvider(authorizedClientProviderBeans));
129+
130+
OAuth2AuthorizedClientProvider jwtBearerAuthorizedClientProvider = getJwtBearerAuthorizedClientProvider(
131+
authorizedClientProviderBeans);
132+
if (jwtBearerAuthorizedClientProvider != null) {
133+
authorizedClientProviders.add(jwtBearerAuthorizedClientProvider);
134+
}
135+
136+
authorizedClientProviders.addAll(getAdditionalAuthorizedClientProviders(authorizedClientProviderBeans));
137+
authorizedClientProvider = new DelegatingOAuth2AuthorizedClientProvider(authorizedClientProviders);
138+
}
139+
140+
DefaultOAuth2AuthorizedClientManager authorizedClientManager = new DefaultOAuth2AuthorizedClientManager(
141+
clientRegistrationRepository, authorizedClientRepository);
142+
authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
143+
144+
Consumer<DefaultOAuth2AuthorizedClientManager> authorizedClientManagerConsumer = getBeanOfType(
145+
ResolvableType.forClassWithGenerics(Consumer.class, DefaultOAuth2AuthorizedClientManager.class));
146+
if (authorizedClientManagerConsumer != null) {
147+
authorizedClientManagerConsumer.accept(authorizedClientManager);
148+
}
149+
150+
return authorizedClientManager;
151+
}
152+
153+
private boolean hasDelegatingAuthorizedClientProvider(
154+
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
155+
if (authorizedClientProviders.size() != 1) {
156+
return false;
157+
}
158+
return authorizedClientProviders.iterator().next() instanceof DelegatingOAuth2AuthorizedClientProvider;
159+
}
160+
161+
private OAuth2AuthorizedClientProvider getAuthorizationCodeAuthorizedClientProvider(
162+
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
163+
AuthorizationCodeOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
164+
authorizedClientProviders, AuthorizationCodeOAuth2AuthorizedClientProvider.class);
165+
if (authorizedClientProvider == null) {
166+
authorizedClientProvider = new AuthorizationCodeOAuth2AuthorizedClientProvider();
167+
}
168+
169+
return authorizedClientProvider;
170+
}
171+
172+
private OAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider(
173+
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
174+
RefreshTokenOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
175+
authorizedClientProviders, RefreshTokenOAuth2AuthorizedClientProvider.class);
176+
if (authorizedClientProvider == null) {
177+
authorizedClientProvider = new RefreshTokenOAuth2AuthorizedClientProvider();
178+
}
179+
180+
OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient = getBeanOfType(
181+
ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
182+
OAuth2RefreshTokenGrantRequest.class));
183+
if (accessTokenResponseClient != null) {
184+
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
185+
}
186+
187+
return authorizedClientProvider;
188+
}
189+
190+
private OAuth2AuthorizedClientProvider getClientCredentialsAuthorizedClientProvider(
191+
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
192+
ClientCredentialsOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
193+
authorizedClientProviders, ClientCredentialsOAuth2AuthorizedClientProvider.class);
194+
if (authorizedClientProvider == null) {
195+
authorizedClientProvider = new ClientCredentialsOAuth2AuthorizedClientProvider();
196+
}
197+
198+
OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient = getBeanOfType(
199+
ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
200+
OAuth2ClientCredentialsGrantRequest.class));
201+
if (accessTokenResponseClient != null) {
202+
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
203+
}
204+
205+
return authorizedClientProvider;
206+
}
207+
208+
private OAuth2AuthorizedClientProvider getPasswordAuthorizedClientProvider(
209+
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
210+
PasswordOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
211+
authorizedClientProviders, PasswordOAuth2AuthorizedClientProvider.class);
212+
if (authorizedClientProvider == null) {
213+
authorizedClientProvider = new PasswordOAuth2AuthorizedClientProvider();
214+
}
215+
216+
OAuth2AccessTokenResponseClient<OAuth2PasswordGrantRequest> accessTokenResponseClient = getBeanOfType(
217+
ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
218+
OAuth2PasswordGrantRequest.class));
219+
if (accessTokenResponseClient != null) {
220+
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
221+
}
222+
223+
return authorizedClientProvider;
224+
}
225+
226+
private OAuth2AuthorizedClientProvider getJwtBearerAuthorizedClientProvider(
227+
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
228+
JwtBearerOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
229+
authorizedClientProviders, JwtBearerOAuth2AuthorizedClientProvider.class);
230+
231+
OAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient = getBeanOfType(ResolvableType
232+
.forClassWithGenerics(OAuth2AccessTokenResponseClient.class, JwtBearerGrantRequest.class));
233+
if (accessTokenResponseClient != null) {
234+
if (authorizedClientProvider == null) {
235+
authorizedClientProvider = new JwtBearerOAuth2AuthorizedClientProvider();
236+
}
237+
238+
authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
239+
}
240+
241+
return authorizedClientProvider;
242+
}
243+
244+
private List<OAuth2AuthorizedClientProvider> getAdditionalAuthorizedClientProviders(
245+
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders) {
246+
List<OAuth2AuthorizedClientProvider> additionalAuthorizedClientProviders = new ArrayList<>(
247+
authorizedClientProviders);
248+
additionalAuthorizedClientProviders
249+
.removeIf((provider) -> KNOWN_AUTHORIZED_CLIENT_PROVIDERS.contains(provider.getClass()));
250+
return additionalAuthorizedClientProviders;
251+
}
252+
253+
private <T extends OAuth2AuthorizedClientProvider> T getAuthorizedClientProviderByType(
254+
Collection<OAuth2AuthorizedClientProvider> authorizedClientProviders, Class<T> providerClass) {
255+
T authorizedClientProvider = null;
256+
for (OAuth2AuthorizedClientProvider current : authorizedClientProviders) {
257+
if (providerClass.isInstance(current)) {
258+
assertAuthorizedClientProviderIsNull(authorizedClientProvider);
259+
authorizedClientProvider = providerClass.cast(current);
260+
}
261+
}
262+
return authorizedClientProvider;
263+
}
264+
265+
private static void assertAuthorizedClientProviderIsNull(OAuth2AuthorizedClientProvider authorizedClientProvider) {
266+
if (authorizedClientProvider != null) {
267+
// @formatter:off
268+
throw new BeanInitializationException(String.format(
269+
"Unable to create an %s bean. Expected one bean of type %s, but found multiple. " +
270+
"Please consider defining only a single bean of this type, or define an %s bean yourself.",
271+
OAuth2AuthorizedClientManager.class.getName(),
272+
authorizedClientProvider.getClass().getName(),
273+
OAuth2AuthorizedClientManager.class.getName()));
274+
// @formatter:on
275+
}
276+
}
277+
278+
private <T> String[] getBeanNamesForType(Class<T> beanClass) {
279+
return BeanFactoryUtils.beanNamesForTypeIncludingAncestors(this.beanFactory, beanClass, false, false);
280+
}
281+
282+
private <T> T getBeanOfType(ResolvableType resolvableType) {
283+
ObjectProvider<T> objectProvider = this.beanFactory.getBeanProvider(resolvableType, true);
284+
return objectProvider.getIfAvailable();
285+
}
286+
287+
}

0 commit comments

Comments
 (0)