Skip to content

Commit acc35ae

Browse files
author
Steve Riesenberg
committedOct 18, 2022
Add DelegatingSecurityContextRepository
Issue gh-12023
1 parent c75ca10 commit acc35ae

File tree

2 files changed

+255
-0
lines changed

2 files changed

+255
-0
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.web.context;
18+
19+
import java.util.Arrays;
20+
import java.util.List;
21+
22+
import javax.servlet.http.HttpServletRequest;
23+
import javax.servlet.http.HttpServletResponse;
24+
25+
import org.springframework.security.core.context.DeferredSecurityContext;
26+
import org.springframework.security.core.context.SecurityContext;
27+
import org.springframework.util.Assert;
28+
29+
/**
30+
* @author Steve Riesenberg
31+
* @author Josh Cummings
32+
* @since 5.8
33+
*/
34+
public final class DelegatingSecurityContextRepository implements SecurityContextRepository {
35+
36+
private final List<SecurityContextRepository> delegates;
37+
38+
public DelegatingSecurityContextRepository(SecurityContextRepository... delegates) {
39+
this(Arrays.asList(delegates));
40+
}
41+
42+
public DelegatingSecurityContextRepository(List<SecurityContextRepository> delegates) {
43+
Assert.notEmpty(delegates, "delegates cannot be empty");
44+
this.delegates = delegates;
45+
}
46+
47+
@Override
48+
public SecurityContext loadContext(HttpRequestResponseHolder requestResponseHolder) {
49+
return loadContext(requestResponseHolder.getRequest()).get();
50+
}
51+
52+
@Override
53+
public DeferredSecurityContext loadDeferredContext(HttpServletRequest request) {
54+
DeferredSecurityContext deferredSecurityContext = null;
55+
for (SecurityContextRepository delegate : this.delegates) {
56+
if (deferredSecurityContext == null) {
57+
deferredSecurityContext = delegate.loadDeferredContext(request);
58+
}
59+
else {
60+
DeferredSecurityContext next = delegate.loadDeferredContext(request);
61+
deferredSecurityContext = new DelegatingDeferredSecurityContext(deferredSecurityContext, next);
62+
}
63+
}
64+
return deferredSecurityContext;
65+
}
66+
67+
@Override
68+
public void saveContext(SecurityContext context, HttpServletRequest request, HttpServletResponse response) {
69+
for (SecurityContextRepository delegate : this.delegates) {
70+
delegate.saveContext(context, request, response);
71+
}
72+
}
73+
74+
@Override
75+
public boolean containsContext(HttpServletRequest request) {
76+
for (SecurityContextRepository delegate : this.delegates) {
77+
if (delegate.containsContext(request)) {
78+
return true;
79+
}
80+
}
81+
return false;
82+
}
83+
84+
static final class DelegatingDeferredSecurityContext implements DeferredSecurityContext {
85+
86+
private final DeferredSecurityContext previous;
87+
88+
private final DeferredSecurityContext next;
89+
90+
DelegatingDeferredSecurityContext(DeferredSecurityContext previous, DeferredSecurityContext next) {
91+
this.previous = previous;
92+
this.next = next;
93+
}
94+
95+
@Override
96+
public SecurityContext get() {
97+
SecurityContext securityContext = this.previous.get();
98+
if (!this.previous.isGenerated()) {
99+
return securityContext;
100+
}
101+
return this.next.get();
102+
}
103+
104+
@Override
105+
public boolean isGenerated() {
106+
return this.previous.isGenerated() && this.next.isGenerated();
107+
}
108+
109+
}
110+
111+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.security.web.context;
18+
19+
import java.util.ArrayList;
20+
import java.util.List;
21+
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.params.ParameterizedTest;
25+
import org.junit.jupiter.params.provider.CsvSource;
26+
27+
import org.springframework.mock.web.MockHttpServletRequest;
28+
import org.springframework.mock.web.MockHttpServletResponse;
29+
import org.springframework.security.authentication.TestingAuthenticationToken;
30+
import org.springframework.security.core.context.DeferredSecurityContext;
31+
import org.springframework.security.core.context.SecurityContext;
32+
import org.springframework.security.core.context.SecurityContextHolderStrategy;
33+
import org.springframework.security.core.context.SecurityContextImpl;
34+
35+
import static org.assertj.core.api.Assertions.assertThat;
36+
import static org.mockito.BDDMockito.given;
37+
import static org.mockito.Mockito.mock;
38+
import static org.mockito.Mockito.verify;
39+
import static org.mockito.Mockito.verifyNoInteractions;
40+
import static org.mockito.Mockito.verifyNoMoreInteractions;
41+
42+
/**
43+
* Tests for {@link DelegatingSecurityContextRepository}.
44+
*
45+
* @author Steve Riesenberg
46+
* @since 5.8
47+
*/
48+
public class DelegatingSecurityContextRepositoryTests {
49+
50+
private MockHttpServletRequest request;
51+
52+
private MockHttpServletResponse response;
53+
54+
private SecurityContextHolderStrategy strategy;
55+
56+
private SecurityContext securityContext;
57+
58+
@BeforeEach
59+
public void setUp() {
60+
this.request = new MockHttpServletRequest();
61+
this.response = new MockHttpServletResponse();
62+
this.strategy = mock(SecurityContextHolderStrategy.class);
63+
this.securityContext = mock(SecurityContext.class);
64+
}
65+
66+
@ParameterizedTest
67+
@CsvSource({ "0,false", "1,false", "2,false", "-1,true" })
68+
public void loadDeferredContextWhenIsGeneratedThenReturnsSecurityContext(int expectedIndex, boolean isGenerated) {
69+
SecurityContext actualSecurityContext = new SecurityContextImpl(
70+
new TestingAuthenticationToken("user", "password"));
71+
SecurityContext emptySecurityContext = new SecurityContextImpl();
72+
given(this.strategy.createEmptyContext()).willReturn(emptySecurityContext);
73+
List<SecurityContextRepository> delegates = new ArrayList<>();
74+
for (int i = 0; i < 3; i++) {
75+
SecurityContext context = (i == expectedIndex) ? actualSecurityContext : null;
76+
SecurityContextRepository repository = mock(SecurityContextRepository.class);
77+
SupplierDeferredSecurityContext supplier = new SupplierDeferredSecurityContext(() -> context,
78+
this.strategy);
79+
given(repository.loadDeferredContext(this.request)).willReturn(supplier);
80+
delegates.add(repository);
81+
}
82+
83+
DelegatingSecurityContextRepository repository = new DelegatingSecurityContextRepository(delegates);
84+
DeferredSecurityContext deferredSecurityContext = repository.loadDeferredContext(this.request);
85+
SecurityContext expectedSecurityContext = (isGenerated) ? emptySecurityContext : actualSecurityContext;
86+
assertThat(deferredSecurityContext.get()).isEqualTo(expectedSecurityContext);
87+
assertThat(deferredSecurityContext.isGenerated()).isEqualTo(isGenerated);
88+
89+
for (SecurityContextRepository delegate : delegates) {
90+
verify(delegate).loadDeferredContext(this.request);
91+
verifyNoMoreInteractions(delegate);
92+
}
93+
}
94+
95+
@Test
96+
public void saveContextAlwaysCallsDelegates() {
97+
List<SecurityContextRepository> delegates = new ArrayList<>();
98+
for (int i = 0; i < 3; i++) {
99+
SecurityContextRepository repository = mock(SecurityContextRepository.class);
100+
delegates.add(repository);
101+
}
102+
103+
DelegatingSecurityContextRepository repository = new DelegatingSecurityContextRepository(delegates);
104+
repository.saveContext(this.securityContext, this.request, this.response);
105+
for (SecurityContextRepository delegate : delegates) {
106+
verify(delegate).saveContext(this.securityContext, this.request, this.response);
107+
verifyNoMoreInteractions(delegate);
108+
}
109+
}
110+
111+
@Test
112+
public void containsContextWhenAllDelegatesReturnFalseThenReturnsFalse() {
113+
List<SecurityContextRepository> delegates = new ArrayList<>();
114+
for (int i = 0; i < 3; i++) {
115+
SecurityContextRepository repository = mock(SecurityContextRepository.class);
116+
given(repository.containsContext(this.request)).willReturn(false);
117+
delegates.add(repository);
118+
}
119+
120+
DelegatingSecurityContextRepository repository = new DelegatingSecurityContextRepository(delegates);
121+
assertThat(repository.containsContext(this.request)).isFalse();
122+
for (SecurityContextRepository delegate : delegates) {
123+
verify(delegate).containsContext(this.request);
124+
verifyNoMoreInteractions(delegate);
125+
}
126+
}
127+
128+
@Test
129+
public void containsContextWhenFirstDelegatesReturnTrueThenReturnsTrue() {
130+
List<SecurityContextRepository> delegates = new ArrayList<>();
131+
for (int i = 0; i < 3; i++) {
132+
SecurityContextRepository repository = mock(SecurityContextRepository.class);
133+
given(repository.containsContext(this.request)).willReturn(true);
134+
delegates.add(repository);
135+
}
136+
137+
DelegatingSecurityContextRepository repository = new DelegatingSecurityContextRepository(delegates);
138+
assertThat(repository.containsContext(this.request)).isTrue();
139+
verify(delegates.get(0)).containsContext(this.request);
140+
verifyNoInteractions(delegates.get(1));
141+
verifyNoInteractions(delegates.get(2));
142+
}
143+
144+
}

0 commit comments

Comments
 (0)
Please sign in to comment.