Skip to content

Commit 6510274

Browse files
committed
Request Cache supports matchingRequestParameterName
Closes gh-7157 gh-11453
1 parent 459003e commit 6510274

File tree

8 files changed

+294
-22
lines changed

8 files changed

+294
-22
lines changed

Diff for: web/src/main/java/org/springframework/security/web/jackson2/DefaultSavedRequestMixin.java

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.security.web.jackson2;
1818

1919
import com.fasterxml.jackson.annotation.JsonAutoDetect;
20+
import com.fasterxml.jackson.annotation.JsonInclude;
2021
import com.fasterxml.jackson.annotation.JsonTypeInfo;
2122
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
2223

@@ -43,4 +44,7 @@
4344
@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE)
4445
abstract class DefaultSavedRequestMixin {
4546

47+
@JsonInclude(JsonInclude.Include.NON_NULL)
48+
String matchingRequestParameterName;
49+
4650
}

Diff for: web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java

+32-2
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,15 @@ public class DefaultSavedRequest implements SavedRequest {
9797

9898
private final int serverPort;
9999

100-
@SuppressWarnings("unchecked")
100+
private final String matchingRequestParameterName;
101+
101102
public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver) {
103+
this(request, portResolver, null);
104+
}
105+
106+
@SuppressWarnings("unchecked")
107+
public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver,
108+
String matchingRequestParameterName) {
102109
Assert.notNull(request, "Request required");
103110
Assert.notNull(portResolver, "PortResolver required");
104111
// Cookies
@@ -131,6 +138,7 @@ public DefaultSavedRequest(HttpServletRequest request, PortResolver portResolver
131138
this.serverName = request.getServerName();
132139
this.contextPath = request.getContextPath();
133140
this.servletPath = request.getServletPath();
141+
this.matchingRequestParameterName = matchingRequestParameterName;
134142
}
135143

136144
/**
@@ -147,6 +155,7 @@ private DefaultSavedRequest(Builder builder) {
147155
this.serverName = builder.serverName;
148156
this.servletPath = builder.servletPath;
149157
this.serverPort = builder.serverPort;
158+
this.matchingRequestParameterName = builder.matchingRequestParameterName;
150159
}
151160

152161
/**
@@ -264,8 +273,9 @@ public List<Cookie> getCookies() {
264273
*/
265274
@Override
266275
public String getRedirectUrl() {
276+
String queryString = createQueryString(this.queryString, this.matchingRequestParameterName);
267277
return UrlUtils.buildFullRequestUrl(this.scheme, this.serverName, this.serverPort, this.requestURI,
268-
this.queryString);
278+
queryString);
269279
}
270280

271281
@Override
@@ -353,6 +363,19 @@ public String toString() {
353363
return "DefaultSavedRequest [" + getRedirectUrl() + "]";
354364
}
355365

366+
private static String createQueryString(String queryString, String matchingRequestParameterName) {
367+
if (matchingRequestParameterName == null) {
368+
return queryString;
369+
}
370+
if (queryString == null || queryString.length() == 0) {
371+
return matchingRequestParameterName;
372+
}
373+
if (queryString.endsWith("&")) {
374+
return queryString + matchingRequestParameterName;
375+
}
376+
return queryString + "&" + matchingRequestParameterName;
377+
}
378+
356379
/**
357380
* @since 4.2
358381
*/
@@ -388,6 +411,8 @@ public static class Builder {
388411

389412
private int serverPort = 80;
390413

414+
private String matchingRequestParameterName;
415+
391416
public Builder setCookies(List<SavedCookie> cookies) {
392417
this.cookies = cookies;
393418
return this;
@@ -458,6 +483,11 @@ public Builder setServerPort(int serverPort) {
458483
return this;
459484
}
460485

486+
public Builder setMatchingRequestParameterName(String matchingRequestParameterName) {
487+
this.matchingRequestParameterName = matchingRequestParameterName;
488+
return this;
489+
}
490+
461491
public DefaultSavedRequest build() {
462492
DefaultSavedRequest savedRequest = new DefaultSavedRequest(this);
463493
if (!ObjectUtils.isEmpty(this.cookies)) {

Diff for: web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java

+22-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ public class HttpSessionRequestCache implements RequestCache {
5252

5353
private String sessionAttrName = SAVED_REQUEST;
5454

55+
private String matchingRequestParameterName;
56+
5557
/**
5658
* Stores the current request, provided the configuration properties allow it.
5759
*/
@@ -64,7 +66,8 @@ public void saveRequest(HttpServletRequest request, HttpServletResponse response
6466
}
6567
return;
6668
}
67-
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, this.portResolver);
69+
DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, this.portResolver,
70+
this.matchingRequestParameterName);
6871
if (this.createSessionAllowed || request.getSession(false) != null) {
6972
// Store the HTTP request itself. Used by
7073
// AbstractAuthenticationProcessingFilter
@@ -96,6 +99,12 @@ public void removeRequest(HttpServletRequest currentRequest, HttpServletResponse
9699

97100
@Override
98101
public HttpServletRequest getMatchingRequest(HttpServletRequest request, HttpServletResponse response) {
102+
if (this.matchingRequestParameterName != null
103+
&& request.getParameter(this.matchingRequestParameterName) == null) {
104+
this.logger.trace(
105+
"matchingRequestParameterName is required for getMatchingRequest to lookup a value, but not provided");
106+
return null;
107+
}
99108
SavedRequest saved = getRequest(request, response);
100109
if (saved == null) {
101110
this.logger.trace("No saved request");
@@ -161,4 +170,16 @@ public void setSessionAttrName(String sessionAttrName) {
161170
this.sessionAttrName = sessionAttrName;
162171
}
163172

173+
/**
174+
* Specify the name of a query parameter that is added to the URL that specifies the
175+
* request cache should be checked in
176+
* {@link #getMatchingRequest(HttpServletRequest, HttpServletResponse)}
177+
* @param matchingRequestParameterName the parameter name that must be in the request
178+
* for {@link #getMatchingRequest(HttpServletRequest, HttpServletResponse)} to check
179+
* the session.
180+
*/
181+
public void setMatchingRequestParameterName(String matchingRequestParameterName) {
182+
this.matchingRequestParameterName = matchingRequestParameterName;
183+
}
184+
164185
}

Diff for: web/src/main/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCache.java

+66-3
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@
3434
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult;
3535
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
3636
import org.springframework.util.Assert;
37+
import org.springframework.util.MultiValueMap;
3738
import org.springframework.web.server.ServerWebExchange;
3839
import org.springframework.web.server.WebSession;
40+
import org.springframework.web.util.UriComponentsBuilder;
3941

4042
/**
4143
* An implementation of {@link ServerRequestCache} that saves the
@@ -57,6 +59,8 @@ public class WebSessionServerRequestCache implements ServerRequestCache {
5759

5860
private ServerWebExchangeMatcher saveRequestMatcher = createDefaultRequestMacher();
5961

62+
private String matchingRequestParameterName;
63+
6064
/**
6165
* Sets the matcher to determine if the request should be saved. The default is to
6266
* match on any GET request.
@@ -81,19 +85,53 @@ public Mono<Void> saveRequest(ServerWebExchange exchange) {
8185
public Mono<URI> getRedirectUri(ServerWebExchange exchange) {
8286
return exchange.getSession()
8387
.flatMap((session) -> Mono.justOrEmpty(session.<String>getAttribute(this.sessionAttrName)))
84-
.map(URI::create);
88+
.map(this::createRedirectUri);
8589
}
8690

8791
@Override
8892
public Mono<ServerHttpRequest> removeMatchingRequest(ServerWebExchange exchange) {
93+
MultiValueMap<String, String> queryParams = exchange.getRequest().getQueryParams();
94+
if (this.matchingRequestParameterName != null && !queryParams.containsKey(this.matchingRequestParameterName)) {
95+
this.logger.trace(
96+
"matchingRequestParameterName is required for getMatchingRequest to lookup a value, but not provided");
97+
return Mono.empty();
98+
}
99+
ServerHttpRequest request = stripMatchingRequestParameterName(exchange.getRequest());
89100
return exchange.getSession().map(WebSession::getAttributes).filter((attributes) -> {
90-
String requestPath = pathInApplication(exchange.getRequest());
101+
String requestPath = pathInApplication(request);
91102
boolean removed = attributes.remove(this.sessionAttrName, requestPath);
92103
if (removed) {
93104
logger.debug(LogMessage.format("Request removed from WebSession: '%s'", requestPath));
94105
}
95106
return removed;
96-
}).map((attributes) -> exchange.getRequest());
107+
}).map((attributes) -> request);
108+
}
109+
110+
/**
111+
* Specify the name of a query parameter that is added to the URL in
112+
* {@link #getRedirectUri(ServerWebExchange)} and is required for
113+
* {@link #removeMatchingRequest(ServerWebExchange)} to look up the
114+
* {@link ServerHttpRequest}.
115+
* @param matchingRequestParameterName the parameter name that must be in the request
116+
* for {@link #removeMatchingRequest(ServerWebExchange)} to check the session.
117+
*/
118+
public void setMatchingRequestParameterName(String matchingRequestParameterName) {
119+
this.matchingRequestParameterName = matchingRequestParameterName;
120+
}
121+
122+
private ServerHttpRequest stripMatchingRequestParameterName(ServerHttpRequest request) {
123+
if (this.matchingRequestParameterName == null) {
124+
return request;
125+
}
126+
// @formatter:off
127+
URI uri = UriComponentsBuilder.fromUri(request.getURI())
128+
.replaceQueryParam(this.matchingRequestParameterName)
129+
.build()
130+
.toUri();
131+
return request.mutate()
132+
.uri(uri)
133+
.build();
134+
// @formatter:on
97135
}
98136

99137
private static String pathInApplication(ServerHttpRequest request) {
@@ -102,6 +140,18 @@ private static String pathInApplication(ServerHttpRequest request) {
102140
return path + ((query != null) ? "?" + query : "");
103141
}
104142

143+
private URI createRedirectUri(String uri) {
144+
if (this.matchingRequestParameterName == null) {
145+
return URI.create(uri);
146+
}
147+
// @formatter:off
148+
return UriComponentsBuilder.fromUriString(uri)
149+
.queryParam(this.matchingRequestParameterName)
150+
.build()
151+
.toUri();
152+
// @formatter:on
153+
}
154+
105155
private static ServerWebExchangeMatcher createDefaultRequestMacher() {
106156
ServerWebExchangeMatcher get = ServerWebExchangeMatchers.pathMatchers(HttpMethod.GET, "/**");
107157
ServerWebExchangeMatcher notFavicon = new NegatedServerWebExchangeMatcher(
@@ -111,4 +161,17 @@ private static ServerWebExchangeMatcher createDefaultRequestMacher() {
111161
return new AndServerWebExchangeMatcher(get, notFavicon, html);
112162
}
113163

164+
private static String createQueryString(String queryString, String matchingRequestParameterName) {
165+
if (matchingRequestParameterName == null) {
166+
return queryString;
167+
}
168+
if (queryString == null || queryString.length() == 0) {
169+
return matchingRequestParameterName;
170+
}
171+
if (queryString.endsWith("&")) {
172+
return queryString + matchingRequestParameterName;
173+
}
174+
return queryString + "&" + matchingRequestParameterName;
175+
}
176+
114177
}

Diff for: web/src/test/java/org/springframework/security/web/jackson2/DefaultSavedRequestMixinTests.java

+49-16
Original file line numberDiff line numberDiff line change
@@ -55,22 +55,42 @@ public class DefaultSavedRequestMixinTests extends AbstractMixinTests {
5555
// @formatter:on
5656
// @formatter:off
5757
private static final String REQUEST_JSON = "{" +
58-
"\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", "
59-
+ "\"cookies\": " + COOKIES_JSON + ","
60-
+ "\"locales\": [\"java.util.ArrayList\", [\"en\"]], "
61-
+ "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, "
62-
+ "\"parameters\": {\"@class\": \"java.util.TreeMap\"},"
63-
+ "\"contextPath\": \"\", "
64-
+ "\"method\": \"\", "
65-
+ "\"pathInfo\": null, "
66-
+ "\"queryString\": null, "
67-
+ "\"requestURI\": \"\", "
68-
+ "\"requestURL\": \"http://localhost\", "
69-
+ "\"scheme\": \"http\", "
70-
+ "\"serverName\": \"localhost\", "
71-
+ "\"servletPath\": \"\", "
72-
+ "\"serverPort\": 80"
73-
+ "}";
58+
"\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", "
59+
+ "\"cookies\": " + COOKIES_JSON + ","
60+
+ "\"locales\": [\"java.util.ArrayList\", [\"en\"]], "
61+
+ "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, "
62+
+ "\"parameters\": {\"@class\": \"java.util.TreeMap\"},"
63+
+ "\"contextPath\": \"\", "
64+
+ "\"method\": \"\", "
65+
+ "\"pathInfo\": null, "
66+
+ "\"queryString\": null, "
67+
+ "\"requestURI\": \"\", "
68+
+ "\"requestURL\": \"http://localhost\", "
69+
+ "\"scheme\": \"http\", "
70+
+ "\"serverName\": \"localhost\", "
71+
+ "\"servletPath\": \"\", "
72+
+ "\"serverPort\": 80"
73+
+ "}";
74+
// @formatter:on
75+
// @formatter:off
76+
private static final String REQUEST_WITH_MATCHING_REQUEST_PARAM_NAME_JSON = "{" +
77+
"\"@class\": \"org.springframework.security.web.savedrequest.DefaultSavedRequest\", "
78+
+ "\"cookies\": " + COOKIES_JSON + ","
79+
+ "\"locales\": [\"java.util.ArrayList\", [\"en\"]], "
80+
+ "\"headers\": {\"@class\": \"java.util.TreeMap\", \"x-auth-token\": [\"java.util.ArrayList\", [\"12\"]]}, "
81+
+ "\"parameters\": {\"@class\": \"java.util.TreeMap\"},"
82+
+ "\"contextPath\": \"\", "
83+
+ "\"method\": \"\", "
84+
+ "\"pathInfo\": null, "
85+
+ "\"queryString\": null, "
86+
+ "\"requestURI\": \"\", "
87+
+ "\"requestURL\": \"http://localhost\", "
88+
+ "\"scheme\": \"http\", "
89+
+ "\"serverName\": \"localhost\", "
90+
+ "\"servletPath\": \"\", "
91+
+ "\"serverPort\": 80, "
92+
+ "\"matchingRequestParameterName\": \"success\""
93+
+ "}";
7494
// @formatter:on
7595
@Test
7696
public void matchRequestBuildWithConstructorAndBuilder() {
@@ -125,4 +145,17 @@ public void deserializeDefaultSavedRequest() throws IOException {
125145
assertThat(request.getHeaderValues("x-auth-token")).hasSize(1).contains("12");
126146
}
127147

148+
@Test
149+
public void deserializeWhenMatchingRequestParameterNameThenRedirectUrlContainsParam() throws IOException {
150+
DefaultSavedRequest request = (DefaultSavedRequest) this.mapper
151+
.readValue(REQUEST_WITH_MATCHING_REQUEST_PARAM_NAME_JSON, Object.class);
152+
assertThat(request.getRedirectUrl()).isEqualTo("http://localhost?success");
153+
}
154+
155+
@Test
156+
public void deserializeWhenNullMatchingRequestParameterNameThenRedirectUrlDoesNotContainParam() throws IOException {
157+
DefaultSavedRequest request = (DefaultSavedRequest) this.mapper.readValue(REQUEST_JSON, Object.class);
158+
assertThat(request.getRedirectUrl()).isEqualTo("http://localhost");
159+
}
160+
128161
}

0 commit comments

Comments
 (0)