Skip to content

Support nested username attribute in DefaultOAuth2User #14265

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,9 +16,9 @@

package org.springframework.security.oauth2.client.userinfo;

import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.converter.Converter;
Expand Down Expand Up @@ -76,6 +76,9 @@ public class DefaultOAuth2UserService implements OAuth2UserService<OAuth2UserReq

private Converter<OAuth2UserRequest, RequestEntity<?>> requestEntityConverter = new OAuth2UserRequestEntityConverter();

private Converter<OAuth2UserRequest, Converter<Map<String, Object>, Map<String, Object>>> attributesConverter = (
request) -> (attributes) -> attributes;

private RestOperations restOperations;

public DefaultOAuth2UserService() {
Expand All @@ -87,35 +90,39 @@ public DefaultOAuth2UserService() {
@Override
public OAuth2User loadUser(OAuth2UserRequest userRequest) throws OAuth2AuthenticationException {
Assert.notNull(userRequest, "userRequest cannot be null");
if (!StringUtils
.hasText(userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri())) {
OAuth2Error oauth2Error = new OAuth2Error(MISSING_USER_INFO_URI_ERROR_CODE,
"Missing required UserInfo Uri in UserInfoEndpoint for Client Registration: "
+ userRequest.getClientRegistration().getRegistrationId(),
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
String userNameAttributeName = userRequest.getClientRegistration()
.getProviderDetails()
.getUserInfoEndpoint()
.getUserNameAttributeName();
if (!StringUtils.hasText(userNameAttributeName)) {
OAuth2Error oauth2Error = new OAuth2Error(MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE,
"Missing required \"user name\" attribute name in UserInfoEndpoint for Client Registration: "
+ userRequest.getClientRegistration().getRegistrationId(),
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
String userNameAttributeName = getUserNameAttributeName(userRequest);
RequestEntity<?> request = this.requestEntityConverter.convert(userRequest);
ResponseEntity<Map<String, Object>> response = getResponse(userRequest, request);
Map<String, Object> userAttributes = response.getBody();
Set<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OAuth2UserAuthority(userAttributes));
OAuth2AccessToken token = userRequest.getAccessToken();
for (String authority : token.getScopes()) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
}
return new DefaultOAuth2User(authorities, userAttributes, userNameAttributeName);
Map<String, Object> attributes = this.attributesConverter.convert(userRequest).convert(response.getBody());
Collection<GrantedAuthority> authorities = getAuthorities(token, attributes);
return new DefaultOAuth2User(authorities, attributes, userNameAttributeName);
}

/**
* Use this strategy to adapt user attributes into a format understood by Spring
* Security; by default, the original attributes are preserved.
*
* <p>
* This can be helpful, for example, if the user attribute is nested. Since Spring
* Security needs the username attribute to be at the top level, you can use this
* method to do:
*
* <pre>
* DefaultOAuth2UserService userService = new DefaultOAuth2UserService();
* userService.setAttributesConverter((userRequest) -> (attributes) ->
* Map&lt;String, Object&gt; userObject = (Map&lt;String, Object&gt;) attributes.get("user");
* attributes.put("user-name", userObject.get("user-name"));
* return attributes;
* });
* </pre>
* @param attributesConverter the attribute adaptation strategy to use
* @since 6.3
*/
public void setAttributesConverter(
Converter<OAuth2UserRequest, Converter<Map<String, Object>, Map<String, Object>>> attributesConverter) {
Assert.notNull(attributesConverter, "attributesConverter cannot be null");
this.attributesConverter = attributesConverter;
}

private ResponseEntity<Map<String, Object>> getResponse(OAuth2UserRequest userRequest, RequestEntity<?> request) {
Expand Down Expand Up @@ -157,6 +164,38 @@ private ResponseEntity<Map<String, Object>> getResponse(OAuth2UserRequest userRe
}
}

private String getUserNameAttributeName(OAuth2UserRequest userRequest) {
if (!StringUtils
.hasText(userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint().getUri())) {
OAuth2Error oauth2Error = new OAuth2Error(MISSING_USER_INFO_URI_ERROR_CODE,
"Missing required UserInfo Uri in UserInfoEndpoint for Client Registration: "
+ userRequest.getClientRegistration().getRegistrationId(),
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
String userNameAttributeName = userRequest.getClientRegistration()
.getProviderDetails()
.getUserInfoEndpoint()
.getUserNameAttributeName();
if (!StringUtils.hasText(userNameAttributeName)) {
OAuth2Error oauth2Error = new OAuth2Error(MISSING_USER_NAME_ATTRIBUTE_ERROR_CODE,
"Missing required \"user name\" attribute name in UserInfoEndpoint for Client Registration: "
+ userRequest.getClientRegistration().getRegistrationId(),
null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
}
return userNameAttributeName;
}

private Collection<GrantedAuthority> getAuthorities(OAuth2AccessToken token, Map<String, Object> attributes) {
Collection<GrantedAuthority> authorities = new LinkedHashSet<>();
authorities.add(new OAuth2UserAuthority(attributes));
for (String authority : token.getScopes()) {
authorities.add(new SimpleGrantedAuthority("SCOPE_" + authority));
}
return authorities;
}

/**
* Sets the {@link Converter} used for converting the {@link OAuth2UserRequest} to a
* {@link RequestEntity} representation of the UserInfo Request.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,6 +26,7 @@
import reactor.core.publisher.Mono;

import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
Expand Down Expand Up @@ -78,6 +79,9 @@ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserServi
private static final ParameterizedTypeReference<Map<String, String>> STRING_STRING_MAP = new ParameterizedTypeReference<Map<String, String>>() {
};

private Converter<OAuth2UserRequest, Converter<Map<String, Object>, Map<String, Object>>> attributesConverter = (
request) -> (attributes) -> attributes;

private WebClient webClient = WebClient.create();

@Override
Expand Down Expand Up @@ -123,7 +127,8 @@ public Mono<OAuth2User> loadUser(OAuth2UserRequest userRequest) throws OAuth2Aut
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
})
)
.bodyToMono(DefaultReactiveOAuth2UserService.STRING_OBJECT_MAP);
.bodyToMono(DefaultReactiveOAuth2UserService.STRING_OBJECT_MAP)
.mapNotNull((attributes) -> this.attributesConverter.convert(userRequest).convert(attributes));
return userAttributes.map((attrs) -> {
GrantedAuthority authority = new OAuth2UserAuthority(attrs);
Set<GrantedAuthority> authorities = new HashSet<>();
Expand Down Expand Up @@ -184,6 +189,32 @@ private WebClient.RequestHeadersSpec<?> getRequestHeaderSpec(OAuth2UserRequest u
// @formatter:on
}

/**
* Use this strategy to adapt user attributes into a format understood by Spring
* Security; by default, the original attributes are preserved.
*
* <p>
* This can be helpful, for example, if the user attribute is nested. Since Spring
* Security needs the username attribute to be at the top level, you can use this
* method to do:
*
* <pre>
* DefaultReactiveOAuth2UserService userService = new DefaultReactiveOAuth2UserService();
* userService.setAttributesConverter((userRequest) -> (attributes) ->
* Map&lt;String, Object&gt; userObject = (Map&lt;String, Object&gt;) attributes.get("user");
* attributes.put("user-name", userObject.get("user-name"));
* return attributes;
* });
* </pre>
* @param attributesConverter the attribute adaptation strategy to use
* @since 6.3
*/
public void setAttributesConverter(
Converter<OAuth2UserRequest, Converter<Map<String, Object>, Map<String, Object>>> attributesConverter) {
Assert.notNull(attributesConverter, "attributesConverter cannot be null");
this.attributesConverter = attributesConverter;
}

/**
* Sets the {@link WebClient} used for retrieving the user endpoint
* @param webClient the client to use
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@

package org.springframework.security.oauth2.client.oidc.userinfo;

import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
Expand All @@ -24,6 +25,8 @@
import java.util.Map;
import java.util.function.Function;

import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -32,13 +35,17 @@
import reactor.core.publisher.Mono;

import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.userinfo.DefaultReactiveOAuth2UserService;
import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
import org.springframework.security.oauth2.core.AuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
Expand Down Expand Up @@ -203,8 +210,62 @@ public void loadUserWhenTokenDoesNotContainScopesThenNoScopeAuthorities() {
assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes());
}

@Test
public void loadUserWhenNestedUserInfoSuccessThenReturnUser() throws IOException {
// @formatter:off
String userInfoResponse = "{\n"
+ " \"user\": {\"user-name\": \"user1\"},\n"
+ " \"sub\" : \"" + this.idToken.getSubject() + "\",\n"
+ " \"first-name\": \"first\",\n"
+ " \"last-name\": \"last\",\n"
+ " \"middle-name\": \"middle\",\n"
+ " \"address\": \"address\",\n"
+ " \"email\": \"[email protected]\"\n"
+ "}\n";
// @formatter:on
try (MockWebServer server = new MockWebServer()) {
server.start();
enqueueApplicationJsonBody(server, userInfoResponse);
String userInfoUri = server.url("/user").toString();
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration()
.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.userNameAttributeName("user-name")
.build();
OidcReactiveOAuth2UserService userService = new OidcReactiveOAuth2UserService();
DefaultReactiveOAuth2UserService oAuth2UserService = new DefaultReactiveOAuth2UserService();
oAuth2UserService.setAttributesConverter((request) -> (attributes) -> {
Map<String, Object> user = (Map<String, Object>) attributes.get("user");
attributes.put("user-name", user.get("user-name"));
return attributes;
});
userService.setOauth2UserService(oAuth2UserService);
OAuth2User user = userService
.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken))
.block();
assertThat(user.getName()).isEqualTo("user1");
assertThat(user.getAttributes()).hasSize(13);
assertThat(((Map<?, ?>) user.getAttribute("user")).get("user-name")).isEqualTo("user1");
assertThat((String) user.getAttribute("first-name")).isEqualTo("first");
assertThat((String) user.getAttribute("last-name")).isEqualTo("last");
assertThat((String) user.getAttribute("middle-name")).isEqualTo("middle");
assertThat((String) user.getAttribute("address")).isEqualTo("address");
assertThat((String) user.getAttribute("email")).isEqualTo("[email protected]");
assertThat(user.getAuthorities()).hasSize(2);
assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OAuth2UserAuthority.class);
OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next();
assertThat(userAuthority.getAuthority()).isEqualTo("OIDC_USER");
assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes());
}
}

private OidcUserRequest userRequest() {
return new OidcUserRequest(this.registration.build(), this.accessToken, this.idToken);
}

private void enqueueApplicationJsonBody(MockWebServer server, String json) {
server.enqueue(
new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json));
}

}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2022 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -52,6 +52,8 @@
import org.springframework.security.oauth2.core.oidc.TestOidcIdTokens;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.oidc.user.OidcUserAuthority;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
Expand Down Expand Up @@ -492,6 +494,49 @@ public void loadUserWhenTokenDoesNotContainScopesAndUserInfoUriThenUserInfoReque
assertThat(user.getUserInfo()).isNotNull();
}

@Test
public void loadUserWhenNestedUserInfoSuccessThenReturnUser() {
// @formatter:off
String userInfoResponse = "{\n"
+ " \"user\": {\"user-name\": \"user1\"},\n"
+ " \"sub\" : \"subject1\",\n"
+ " \"first-name\": \"first\",\n"
+ " \"last-name\": \"last\",\n"
+ " \"middle-name\": \"middle\",\n"
+ " \"address\": \"address\",\n"
+ " \"email\": \"[email protected]\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.userNameAttributeName("user-name")
.build();
OidcUserService userService = new OidcUserService();
DefaultOAuth2UserService oAuth2UserService = new DefaultOAuth2UserService();
oAuth2UserService.setAttributesConverter((request) -> (attributes) -> {
Map<String, Object> user = (Map<String, Object>) attributes.get("user");
attributes.put("user-name", user.get("user-name"));
return attributes;
});
userService.setOauth2UserService(oAuth2UserService);
OAuth2User user = userService.loadUser(new OidcUserRequest(clientRegistration, this.accessToken, this.idToken));
assertThat(user.getName()).isEqualTo("user1");
assertThat(user.getAttributes()).hasSize(9);
assertThat(((Map<?, ?>) user.getAttribute("user")).get("user-name")).isEqualTo("user1");
assertThat((String) user.getAttribute("first-name")).isEqualTo("first");
assertThat((String) user.getAttribute("last-name")).isEqualTo("last");
assertThat((String) user.getAttribute("middle-name")).isEqualTo("middle");
assertThat((String) user.getAttribute("address")).isEqualTo("address");
assertThat((String) user.getAttribute("email")).isEqualTo("[email protected]");
assertThat(user.getAuthorities()).hasSize(3);
assertThat(user.getAuthorities().iterator().next()).isInstanceOf(OAuth2UserAuthority.class);
OAuth2UserAuthority userAuthority = (OAuth2UserAuthority) user.getAuthorities().iterator().next();
assertThat(userAuthority.getAuthority()).isEqualTo("OIDC_USER");
assertThat(userAuthority.getAttributes()).isEqualTo(user.getAttributes());
}

private MockResponse jsonResponse(String json) {
// @formatter:off
return new MockResponse()
Expand Down
Loading