Skip to content

Commit 023632b

Browse files
vpavickostya05983
authored andcommitted
Allow InMemoryOAuth2AuthorizedClientService to be constructed with a Map
Fixes spring-projectsgh-5994
1 parent 25f82e9 commit 023632b

File tree

5 files changed

+231
-27
lines changed

5 files changed

+231
-27
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2017 the original author or authors.
2+
* Copyright 2002-2018 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,7 +20,6 @@
2020
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
2121
import org.springframework.util.Assert;
2222

23-
import java.util.Base64;
2423
import java.util.Map;
2524
import java.util.concurrent.ConcurrentHashMap;
2625

@@ -29,15 +28,16 @@
2928
* {@link OAuth2AuthorizedClient Authorized Client(s)} in-memory.
3029
*
3130
* @author Joe Grandja
31+
* @author Vedran Pavic
3232
* @since 5.0
3333
* @see OAuth2AuthorizedClientService
3434
* @see OAuth2AuthorizedClient
3535
* @see ClientRegistration
3636
* @see Authentication
3737
*/
3838
public final class InMemoryOAuth2AuthorizedClientService implements OAuth2AuthorizedClientService {
39-
private final Map<String, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();
4039
private final ClientRegistrationRepository clientRegistrationRepository;
40+
private Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();
4141

4242
/**
4343
* Constructs an {@code InMemoryOAuth2AuthorizedClientService} using the provided parameters.
@@ -49,23 +49,33 @@ public InMemoryOAuth2AuthorizedClientService(ClientRegistrationRepository client
4949
this.clientRegistrationRepository = clientRegistrationRepository;
5050
}
5151

52+
/**
53+
* Sets the map of authorized clients to use.
54+
* @param authorizedClients the map of authorized clients
55+
*/
56+
public void setAuthorizedClients(Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients) {
57+
Assert.notNull(authorizedClients, "authorizedClients cannot be null");
58+
this.authorizedClients = authorizedClients;
59+
}
60+
5261
@Override
62+
@SuppressWarnings("unchecked")
5363
public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, String principalName) {
5464
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
5565
Assert.hasText(principalName, "principalName cannot be empty");
5666
ClientRegistration registration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
5767
if (registration == null) {
5868
return null;
5969
}
60-
return (T) this.authorizedClients.get(this.getIdentifier(registration, principalName));
70+
return (T) this.authorizedClients.get(OAuth2AuthorizedClientId.create(registration, principalName));
6171
}
6272

6373
@Override
6474
public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
6575
Assert.notNull(authorizedClient, "authorizedClient cannot be null");
6676
Assert.notNull(principal, "principal cannot be null");
67-
this.authorizedClients.put(this.getIdentifier(
68-
authorizedClient.getClientRegistration(), principal.getName()), authorizedClient);
77+
this.authorizedClients.put(OAuth2AuthorizedClientId.create(authorizedClient.getClientRegistration(),
78+
principal.getName()), authorizedClient);
6979
}
7080

7181
@Override
@@ -74,12 +84,8 @@ public void removeAuthorizedClient(String clientRegistrationId, String principal
7484
Assert.hasText(principalName, "principalName cannot be empty");
7585
ClientRegistration registration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
7686
if (registration != null) {
77-
this.authorizedClients.remove(this.getIdentifier(registration, principalName));
87+
this.authorizedClients.remove(OAuth2AuthorizedClientId.create(registration, principalName));
7888
}
7989
}
8090

81-
private String getIdentifier(ClientRegistration registration, String principalName) {
82-
String identifier = "[" + registration.getRegistrationId() + "][" + principalName + "]";
83-
return Base64.getEncoder().encodeToString(identifier.getBytes());
84-
}
8591
}

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
*/
1616
package org.springframework.security.oauth2.client;
1717

18-
import java.util.Base64;
1918
import java.util.Map;
2019
import java.util.concurrent.ConcurrentHashMap;
2120

@@ -31,14 +30,15 @@
3130
* {@link OAuth2AuthorizedClient Authorized Client(s)} in-memory.
3231
*
3332
* @author Rob Winch
33+
* @author Vedran Pavic
3434
* @since 5.1
3535
* @see OAuth2AuthorizedClientService
3636
* @see OAuth2AuthorizedClient
3737
* @see ClientRegistration
3838
* @see Authentication
3939
*/
4040
public final class InMemoryReactiveOAuth2AuthorizedClientService implements ReactiveOAuth2AuthorizedClientService {
41-
private final Map<String, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();
41+
private final Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();;
4242
private final ReactiveClientRegistrationRepository clientRegistrationRepository;
4343

4444
/**
@@ -52,10 +52,12 @@ public InMemoryReactiveOAuth2AuthorizedClientService(ReactiveClientRegistrationR
5252
}
5353

5454
@Override
55+
@SuppressWarnings("unchecked")
5556
public <T extends OAuth2AuthorizedClient> Mono<T> loadAuthorizedClient(String clientRegistrationId, String principalName) {
5657
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
5758
Assert.hasText(principalName, "principalName cannot be empty");
58-
return (Mono<T>) getIdentifier(clientRegistrationId, principalName)
59+
return (Mono<T>) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
60+
.map(clientRegistration -> OAuth2AuthorizedClientId.create(clientRegistration, principalName))
5961
.flatMap(identifier -> Mono.justOrEmpty(this.authorizedClients.get(identifier)));
6062
}
6163

@@ -64,7 +66,8 @@ public Mono<Void> saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient,
6466
Assert.notNull(authorizedClient, "authorizedClient cannot be null");
6567
Assert.notNull(principal, "principal cannot be null");
6668
return Mono.fromRunnable(() -> {
67-
String identifier = this.getIdentifier(authorizedClient.getClientRegistration(), principal.getName());
69+
OAuth2AuthorizedClientId identifier = OAuth2AuthorizedClientId.create(
70+
authorizedClient.getClientRegistration(), principal.getName());
6871
this.authorizedClients.put(identifier, authorizedClient);
6972
});
7073
}
@@ -73,18 +76,10 @@ public Mono<Void> saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient,
7376
public Mono<Void> removeAuthorizedClient(String clientRegistrationId, String principalName) {
7477
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
7578
Assert.hasText(principalName, "principalName cannot be empty");
76-
return this.getIdentifier(clientRegistrationId, principalName)
77-
.doOnNext(identifier -> this.authorizedClients.remove(identifier))
78-
.then(Mono.empty());
79-
}
80-
81-
private Mono<String> getIdentifier(String clientRegistrationId, String principalName) {
8279
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
83-
.map(registration -> getIdentifier(registration, principalName));
80+
.map(clientRegistration -> OAuth2AuthorizedClientId.create(clientRegistration, principalName))
81+
.doOnNext(this.authorizedClients::remove)
82+
.then(Mono.empty());
8483
}
8584

86-
private String getIdentifier(ClientRegistration registration, String principalName) {
87-
String identifier = "[" + registration.getRegistrationId() + "][" + principalName + "]";
88-
return Base64.getEncoder().encodeToString(identifier.getBytes());
89-
}
9085
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright 2002-2018 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.oauth2.client;
18+
19+
import java.io.Serializable;
20+
import java.util.Objects;
21+
22+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
23+
import org.springframework.util.Assert;
24+
25+
/**
26+
* The identifier for {@link OAuth2AuthorizedClient}.
27+
*
28+
* @author Vedran Pavic
29+
* @since 5.2
30+
* @see OAuth2AuthorizedClient
31+
* @see OAuth2AuthorizedClientService
32+
*/
33+
public final class OAuth2AuthorizedClientId implements Serializable {
34+
35+
private final String clientRegistrationId;
36+
37+
private final String principalName;
38+
39+
private OAuth2AuthorizedClientId(String clientRegistrationId, String principalName) {
40+
Assert.notNull(clientRegistrationId, "clientRegistrationId cannot be null");
41+
Assert.notNull(principalName, "principalName cannot be null");
42+
this.clientRegistrationId = clientRegistrationId;
43+
this.principalName = principalName;
44+
}
45+
46+
/**
47+
* Factory method for creating new {@link OAuth2AuthorizedClientId} using
48+
* {@link ClientRegistration} and principal name.
49+
* @param clientRegistration the client registration
50+
* @param principalName the principal name
51+
* @return the new authorized client id
52+
*/
53+
public static OAuth2AuthorizedClientId create(ClientRegistration clientRegistration,
54+
String principalName) {
55+
return new OAuth2AuthorizedClientId(clientRegistration.getRegistrationId(),
56+
principalName);
57+
}
58+
59+
@Override
60+
public boolean equals(Object obj) {
61+
if (this == obj) {
62+
return true;
63+
}
64+
if (obj == null || getClass() != obj.getClass()) {
65+
return false;
66+
}
67+
OAuth2AuthorizedClientId that = (OAuth2AuthorizedClientId) obj;
68+
return Objects.equals(this.clientRegistrationId, that.clientRegistrationId)
69+
&& Objects.equals(this.principalName, that.principalName);
70+
}
71+
72+
@Override
73+
public int hashCode() {
74+
return Objects.hash(this.clientRegistrationId, this.principalName);
75+
}
76+
77+
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2017 the original author or authors.
2+
* Copyright 2002-2018 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -15,7 +15,11 @@
1515
*/
1616
package org.springframework.security.oauth2.client;
1717

18+
import java.util.Collections;
19+
import java.util.Map;
20+
1821
import org.junit.Test;
22+
1923
import org.springframework.security.core.Authentication;
2024
import org.springframework.security.oauth2.client.registration.ClientRegistration;
2125
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
@@ -24,13 +28,17 @@
2428
import org.springframework.security.oauth2.core.OAuth2AccessToken;
2529

2630
import static org.assertj.core.api.Assertions.assertThat;
31+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
32+
import static org.mockito.ArgumentMatchers.eq;
33+
import static org.mockito.BDDMockito.given;
2734
import static org.mockito.Mockito.mock;
2835
import static org.mockito.Mockito.when;
2936

3037
/**
3138
* Tests for {@link InMemoryOAuth2AuthorizedClientService}.
3239
*
3340
* @author Joe Grandja
41+
* @author Vedran Pavic
3442
*/
3543
public class InMemoryOAuth2AuthorizedClientServiceTests {
3644
private String principalName1 = "principal-1";
@@ -57,6 +65,30 @@ public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArg
5765
new InMemoryOAuth2AuthorizedClientService(null);
5866
}
5967

68+
@Test
69+
public void constructorWhenAuthorizedClientsIsNullThenIllegalArgumentException() {
70+
assertThatExceptionOfType(IllegalArgumentException.class)
71+
.isThrownBy(() -> this.authorizedClientService.setAuthorizedClients(null))
72+
.withMessage("authorizedClients cannot be null");
73+
}
74+
75+
@Test
76+
public void constructorWhenAuthorizedClientsIsEmptyMapThenRepositoryUsingSuppliedAuthorizedClients() {
77+
String registrationId = this.registration3.getRegistrationId();
78+
79+
Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = Collections.singletonMap(
80+
OAuth2AuthorizedClientId.create(this.registration3, this.principalName1),
81+
mock(OAuth2AuthorizedClient.class));
82+
ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
83+
given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3);
84+
85+
InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
86+
this.clientRegistrationRepository);
87+
authorizedClientService.setAuthorizedClients(authorizedClients);
88+
assertThat((OAuth2AuthorizedClient) authorizedClientService.loadAuthorizedClient(
89+
registrationId, this.principalName1)).isNotNull();
90+
}
91+
6092
@Test(expected = IllegalArgumentException.class)
6193
public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
6294
this.authorizedClientService.loadAuthorizedClient(null, this.principalName1);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Copyright 2002-2018 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.oauth2.client;
18+
19+
import org.junit.Test;
20+
21+
import org.springframework.security.oauth2.client.registration.ClientRegistration;
22+
import org.springframework.security.oauth2.core.AuthorizationGrantType;
23+
24+
import static org.assertj.core.api.Assertions.assertThat;
25+
26+
/**
27+
* Tests for {@link OAuth2AuthorizedClientId}.
28+
*
29+
* @author Vedran Pavic
30+
*/
31+
public class OAuth2AuthorizedClientIdTests {
32+
33+
@Test
34+
public void equalsWhenSameRegistrationIdAndPrincipalThenShouldReturnTrue() {
35+
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
36+
"test-principal");
37+
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
38+
"test-principal");
39+
assertThat(id1.equals(id2)).isTrue();
40+
}
41+
42+
@Test
43+
public void equalsWhenDifferentRegistrationIdAndSamePrincipalThenShouldReturnFalse() {
44+
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client1"),
45+
"test-principal");
46+
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client2"),
47+
"test-principal");
48+
assertThat(id1.equals(id2)).isFalse();
49+
}
50+
51+
@Test
52+
public void equalsWhenSameRegistrationIdAndDifferentPrincipalThenShouldReturnFalse() {
53+
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
54+
"test-principal1");
55+
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
56+
"test-principal2");
57+
assertThat(id1.equals(id2)).isFalse();
58+
}
59+
60+
@Test
61+
public void hashCodeWhenSameRegistrationIdAndPrincipalThenShouldReturnSame() {
62+
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
63+
"test-principal");
64+
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
65+
"test-principal");
66+
assertThat(id1.hashCode()).isEqualTo(id2.hashCode());
67+
}
68+
69+
@Test
70+
public void hashCodeWhenDifferentRegistrationIdAndSamePrincipalThenShouldNotReturnSame() {
71+
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client1"),
72+
"test-principal");
73+
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client2"),
74+
"test-principal");
75+
assertThat(id1.hashCode()).isNotEqualTo(id2.hashCode());
76+
}
77+
78+
@Test
79+
public void hashCodeWhenSameRegistrationIdAndDifferentPrincipalThenShouldNotReturnSame() {
80+
OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
81+
"test-principal1");
82+
OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
83+
"test-principal2");
84+
assertThat(id1.hashCode()).isNotEqualTo(id2.hashCode());
85+
}
86+
87+
private static ClientRegistration testClientRegistration(String registrationId) {
88+
return ClientRegistration.withRegistrationId(registrationId).clientId("id").clientSecret("secret")
89+
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
90+
.redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}")
91+
.authorizationUri("http://example.com/authorize").tokenUri("http://example.com/token").build();
92+
}
93+
94+
}

0 commit comments

Comments
 (0)