Skip to content

Commit c6f1b24

Browse files
committed
Saml2AuthenticationToken Takes RelyingPartyRegistration
Now that the credentials are split, it's a bit cleaner to keep all the details with RelyingPartyRegistration instead of trying to split them out into different contructor parameters Issue spring-projectsgh-8777
1 parent 4713aec commit c6f1b24

File tree

4 files changed

+129
-73
lines changed

4 files changed

+129
-73
lines changed

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java

Lines changed: 25 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
*/
1616
package org.springframework.security.saml2.provider.service.authentication;
1717

18-
import java.security.cert.X509Certificate;
1918
import java.time.Duration;
2019
import java.time.Instant;
2120
import java.util.ArrayList;
@@ -24,7 +23,6 @@
2423
import java.util.HashMap;
2524
import java.util.HashSet;
2625
import java.util.LinkedHashMap;
27-
import java.util.LinkedList;
2826
import java.util.List;
2927
import java.util.Map;
3028
import java.util.Set;
@@ -38,7 +36,6 @@
3836
import org.opensaml.core.xml.XMLObject;
3937
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
4038
import org.opensaml.core.xml.io.Marshaller;
41-
4239
import org.opensaml.core.xml.schema.XSAny;
4340
import org.opensaml.core.xml.schema.XSBoolean;
4441
import org.opensaml.core.xml.schema.XSBooleanValue;
@@ -95,6 +92,7 @@
9592
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
9693
import org.springframework.security.saml2.Saml2Exception;
9794
import org.springframework.security.saml2.credentials.Saml2X509Credential;
95+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
9896
import org.springframework.util.Assert;
9997
import org.springframework.util.StringUtils;
10098

@@ -218,9 +216,9 @@ public Authentication authenticate(Authentication authentication) throws Authent
218216
try {
219217
Saml2AuthenticationToken token = (Saml2AuthenticationToken) authentication;
220218
Response response = parse(token.getSaml2Response());
221-
List<Assertion> validAssertions = validateResponse(token, response);
219+
List<Assertion> validAssertions = validateResponse(token.getRelyingPartyRegistration(), response);
222220
Assertion assertion = validAssertions.get(0);
223-
String username = getUsername(token, assertion);
221+
String username = getUsername(token.getRelyingPartyRegistration(), assertion);
224222
Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
225223
return new Saml2Authentication(
226224
new SimpleSaml2AuthenticatedPrincipal(username, attributes), token.getSaml2Response(),
@@ -259,7 +257,7 @@ private Response parse(String response) throws Saml2Exception, Saml2Authenticati
259257

260258
}
261259

262-
private List<Assertion> validateResponse(Saml2AuthenticationToken token, Response response)
260+
private List<Assertion> validateResponse(RelyingPartyRegistration relyingParty, Response response)
263261
throws Saml2AuthenticationException {
264262

265263
List<Assertion> validAssertions = new ArrayList<>();
@@ -270,7 +268,7 @@ private List<Assertion> validateResponse(Saml2AuthenticationToken token, Respons
270268

271269
List<Assertion> assertions = new ArrayList<>(response.getAssertions());
272270
for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) {
273-
Assertion assertion = decrypt(token, encryptedAssertion);
271+
Assertion assertion = decrypt(relyingParty, encryptedAssertion);
274272
assertions.add(assertion);
275273
}
276274
if (assertions.isEmpty()) {
@@ -282,7 +280,7 @@ private List<Assertion> validateResponse(Saml2AuthenticationToken token, Respons
282280
"Please either sign the response or all of the assertions.");
283281
}
284282

285-
SignatureTrustEngine signatureTrustEngine = buildSignatureTrustEngine(token);
283+
SignatureTrustEngine signatureTrustEngine = buildSignatureTrustEngine(relyingParty);
286284

287285
Map<String, Saml2AuthenticationException> validationExceptions = new HashMap<>();
288286
if (response.isSigned()) {
@@ -310,18 +308,18 @@ private List<Assertion> validateResponse(Saml2AuthenticationToken token, Respons
310308
}
311309

312310
String destination = response.getDestination();
313-
if (StringUtils.hasText(destination) && !destination.equals(token.getRecipientUri())) {
311+
if (StringUtils.hasText(destination) && !destination.equals(relyingParty.getAssertionConsumerServiceLocation())) {
314312
String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]";
315313
validationExceptions.put(INVALID_DESTINATION, authException(INVALID_DESTINATION, message));
316314
}
317315

318-
if (!StringUtils.hasText(issuer) || !issuer.equals(token.getIdpEntityId())) {
316+
if (!StringUtils.hasText(issuer) || !issuer.equals(relyingParty.getProviderDetails().getEntityId())) {
319317
String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
320318
validationExceptions.put(INVALID_ISSUER, authException(INVALID_ISSUER, message));
321319
}
322320

323321
SAML20AssertionValidator validator = buildSamlAssertionValidator(signatureTrustEngine);
324-
ValidationContext context = buildValidationContext(token, response);
322+
ValidationContext context = buildValidationContext(relyingParty, response);
325323

326324
if (logger.isDebugEnabled()) {
327325
logger.debug("Validating " + assertions.size() + " assertions");
@@ -375,12 +373,12 @@ private boolean isSigned(Response samlResponse, List<Assertion> assertions) {
375373
return true;
376374
}
377375

378-
private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken token) {
376+
private SignatureTrustEngine buildSignatureTrustEngine(RelyingPartyRegistration relyingParty) {
379377
Set<Credential> credentials = new HashSet<>();
380-
for (X509Certificate key : getVerificationCertificates(token)) {
381-
BasicX509Credential cred = new BasicX509Credential(key);
378+
for (Saml2X509Credential credential : relyingParty.getProviderDetails().getVerificationX509Credentials()) {
379+
BasicX509Credential cred = new BasicX509Credential(credential.getCertificate());
382380
cred.setUsageType(UsageType.SIGNING);
383-
cred.setEntityId(token.getIdpEntityId());
381+
cred.setEntityId(relyingParty.getProviderDetails().getEntityId());
384382
credentials.add(cred);
385383
}
386384
CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
@@ -390,13 +388,15 @@ private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken
390388
);
391389
}
392390

393-
private ValidationContext buildValidationContext(Saml2AuthenticationToken token, Response response) {
391+
private ValidationContext buildValidationContext(RelyingPartyRegistration relyingParty, Response response) {
394392
Map<String, Object> validationParams = new HashMap<>();
395393
validationParams.put(SIGNATURE_REQUIRED, !response.isSigned());
396394
validationParams.put(CLOCK_SKEW, this.responseTimeValidationSkew.toMillis());
397-
validationParams.put(COND_VALID_AUDIENCES, singleton(token.getLocalSpEntityId()));
398-
if (StringUtils.hasText(token.getRecipientUri())) {
399-
validationParams.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, singleton(token.getRecipientUri()));
395+
validationParams.put(COND_VALID_AUDIENCES, singleton(relyingParty.getEntityId()));
396+
String assertionConsumerServiceLocation = relyingParty.getAssertionConsumerServiceLocation();
397+
if (StringUtils.hasText(assertionConsumerServiceLocation)) {
398+
validationParams.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS,
399+
singleton(assertionConsumerServiceLocation));
400400
}
401401
return new ValidationContext(validationParams);
402402
}
@@ -422,11 +422,11 @@ private Assertion validateAssertion(Assertion assertion,
422422
return assertion;
423423
}
424424

425-
private Assertion decrypt(Saml2AuthenticationToken token, EncryptedAssertion assertion)
425+
private Assertion decrypt(RelyingPartyRegistration relyingParty, EncryptedAssertion assertion)
426426
throws Saml2AuthenticationException {
427427

428428
Saml2AuthenticationException last = null;
429-
List<Saml2X509Credential> decryptionCredentials = getDecryptionCredentials(token);
429+
Collection<Saml2X509Credential> decryptionCredentials = relyingParty.getDecryptionX509Credentials();
430430
if (decryptionCredentials.isEmpty()) {
431431
throw authException(DECRYPTION_ERROR, "No valid decryption credentials found.");
432432
}
@@ -450,27 +450,7 @@ private Decrypter getDecrypter(Saml2X509Credential key) {
450450
return decrypter;
451451
}
452452

453-
private List<Saml2X509Credential> getDecryptionCredentials(Saml2AuthenticationToken token) {
454-
List<Saml2X509Credential> result = new LinkedList<>();
455-
for (Saml2X509Credential c : token.getX509Credentials()) {
456-
if (c.isDecryptionCredential()) {
457-
result.add(c);
458-
}
459-
}
460-
return result;
461-
}
462-
463-
private List<X509Certificate> getVerificationCertificates(Saml2AuthenticationToken token) {
464-
List<X509Certificate> result = new LinkedList<>();
465-
for (Saml2X509Credential c : token.getX509Credentials()) {
466-
if (c.isSignatureVerficationCredential()) {
467-
result.add(c.getCertificate());
468-
}
469-
}
470-
return result;
471-
}
472-
473-
private String getUsername(Saml2AuthenticationToken token, Assertion assertion)
453+
private String getUsername(RelyingPartyRegistration relyingParty, Assertion assertion)
474454
throws Saml2AuthenticationException {
475455

476456
String username = null;
@@ -482,7 +462,7 @@ private String getUsername(Saml2AuthenticationToken token, Assertion assertion)
482462
username = subject.getNameID().getValue();
483463
}
484464
else if (subject.getEncryptedID() != null) {
485-
NameID nameId = decrypt(token, subject.getEncryptedID());
465+
NameID nameId = decrypt(relyingParty, subject.getEncryptedID());
486466
username = nameId.getValue();
487467
}
488468
if (username == null) {
@@ -491,11 +471,11 @@ else if (subject.getEncryptedID() != null) {
491471
return username;
492472
}
493473

494-
private NameID decrypt(Saml2AuthenticationToken token, EncryptedID assertion)
474+
private NameID decrypt(RelyingPartyRegistration relyingParty, EncryptedID assertion)
495475
throws Saml2AuthenticationException {
496476

497477
Saml2AuthenticationException last = null;
498-
List<Saml2X509Credential> decryptionCredentials = getDecryptionCredentials(token);
478+
Collection<Saml2X509Credential> decryptionCredentials = relyingParty.getDecryptionX509Credentials();
499479
if (decryptionCredentials.isEmpty()) {
500480
throw authException(DECRYPTION_ERROR, "No valid decryption credentials found.");
501481
}

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticationToken.java

Lines changed: 85 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616

1717
package org.springframework.security.saml2.provider.service.authentication;
1818

19+
import java.util.Collections;
20+
import java.util.List;
21+
1922
import org.springframework.security.authentication.AbstractAuthenticationToken;
2023
import org.springframework.security.saml2.credentials.Saml2X509Credential;
21-
22-
import java.util.List;
24+
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
25+
import org.springframework.util.Assert;
2326

2427
/**
2528
* Represents an incoming SAML 2.0 response containing an assertion that has not been validated.
@@ -28,11 +31,56 @@
2831
*/
2932
public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
3033

34+
private final RelyingPartyRegistration relyingPartyRegistration;
3135
private final String saml2Response;
32-
private final String recipientUri;
33-
private String idpEntityId;
34-
private String localSpEntityId;
35-
private List<Saml2X509Credential> credentials;
36+
37+
/**
38+
* Construct a {@link Saml2AuthenticationToken} with the provided parameters
39+
*
40+
* <p>
41+
* Note that the {@link RelyingPartyRegistration} should have any placeholders resolved to be included
42+
* in this token. This can be achieved with {@link RelyingPartyRegistration#withRelyingPartyRegistration}:
43+
*
44+
* <pre>
45+
* RelyingPartyRegistration resolved = withRelyingPartyRegistration(unresolved)
46+
* .entityId(...)
47+
* .assertionConsumerServiceLocation(...)
48+
* .build();
49+
* Saml2AuthenticationToken token = new Saml2AuthenticationToken(resolved, saml2Response);
50+
* </pre>
51+
*
52+
* @param relyingPartyRegistration The {@link RelyingPartyRegistration} associated with this token
53+
* @param saml2Response The serialized SAML 2.0 Response associated with this token
54+
* @since 5.4
55+
*/
56+
public Saml2AuthenticationToken(RelyingPartyRegistration relyingPartyRegistration, String saml2Response) {
57+
super(Collections.emptyList());
58+
Assert.isTrue(isResolved(relyingPartyRegistration.getAssertionConsumerServiceLocation()),
59+
"relyingPartyRegistration must have its placeholders resolved for inclusion in this token");
60+
Assert.isTrue(isResolved(relyingPartyRegistration.getEntityId()),
61+
"relyingPartyRegistration must have its placeholders resolved for inclusion in this token");
62+
this.relyingPartyRegistration = relyingPartyRegistration;
63+
this.saml2Response = saml2Response;
64+
}
65+
66+
private static boolean isResolved(String template) {
67+
if (template.contains("{registrationId}")) {
68+
return false;
69+
}
70+
if (template.contains("{baseUrl}")) {
71+
return false;
72+
}
73+
if (template.contains("{baseScheme}")) {
74+
return false;
75+
}
76+
if (template.contains("{baseHost}")) {
77+
return false;
78+
}
79+
if (template.contains("{basePort}")) {
80+
return false;
81+
}
82+
return true;
83+
}
3684

3785
/**
3886
* Creates an authentication token from an incoming SAML 2 Response object
@@ -41,18 +89,31 @@ public class Saml2AuthenticationToken extends AbstractAuthenticationToken {
4189
* @param idpEntityId the entity ID of the asserting entity
4290
* @param localSpEntityId the configured local SP, the relying party, entity ID
4391
* @param credentials the credentials configured for signature verification and decryption
92+
* @deprecated Use {@link Saml2AuthenticationToken(RelyingPartyRegistration, String)} instead
4493
*/
94+
@Deprecated
4595
public Saml2AuthenticationToken(String saml2Response,
4696
String recipientUri,
4797
String idpEntityId,
4898
String localSpEntityId,
4999
List<Saml2X509Credential> credentials) {
50-
super(null);
51-
this.saml2Response = saml2Response;
52-
this.recipientUri = recipientUri;
53-
this.idpEntityId = idpEntityId;
54-
this.localSpEntityId = localSpEntityId;
55-
this.credentials = credentials;
100+
this(RelyingPartyRegistration.withRegistrationId(localSpEntityId)
101+
.entityId(localSpEntityId)
102+
.credentials(c -> c.addAll(credentials))
103+
.assertionConsumerServiceLocation(recipientUri)
104+
.providerDetails(ap -> ap.entityId(idpEntityId))
105+
.build(),
106+
saml2Response);
107+
}
108+
109+
/**
110+
* Get the {@link RelyingPartyRegistration} associated with this authentication token
111+
*
112+
* @return the {@link RelyingPartyRegistration} associated with this authentication token
113+
* @since 5.4
114+
*/
115+
public RelyingPartyRegistration getRelyingPartyRegistration() {
116+
return this.relyingPartyRegistration;
56117
}
57118

58119
/**
@@ -84,25 +145,31 @@ public String getSaml2Response() {
84145
/**
85146
* Returns the URI that the SAML 2 Response object came in on
86147
* @return URI as a string
148+
* @deprecated Use {@link #getRelyingPartyRegistration().getAssertionConsumerServiceUrlTemplate()} instead
87149
*/
150+
@Deprecated
88151
public String getRecipientUri() {
89-
return this.recipientUri;
152+
return this.relyingPartyRegistration.getAssertionConsumerServiceLocation();
90153
}
91154

92155
/**
93156
* Returns the configured entity ID of the receiving relying party, SP
94157
* @return an entityID for the configured local relying party
158+
* @deprecated Use {@link #getRelyingPartyRegistration().getEntityId()} instead
95159
*/
160+
@Deprecated
96161
public String getLocalSpEntityId() {
97-
return this.localSpEntityId;
162+
return this.relyingPartyRegistration.getEntityId();
98163
}
99164

100165
/**
101166
* Returns all the credentials associated with the relying party configuraiton
102167
* @return
168+
* @deprecated Use {@link #getRelyingPartyRegistration()} instead
103169
*/
170+
@Deprecated
104171
public List<Saml2X509Credential> getX509Credentials() {
105-
return this.credentials;
172+
return this.relyingPartyRegistration.getCredentials();
106173
}
107174

108175
/**
@@ -126,8 +193,10 @@ public void setAuthenticated(boolean authenticated) {
126193
/**
127194
* Returns the configured IDP, asserting party, entity ID
128195
* @return a string representing the entity ID
196+
* @deprecated Use {@link #getRelyingPartyRegistration().getProviderDetails.getEntityId()} instead
129197
*/
198+
@Deprecated
130199
public String getIdpEntityId() {
131-
return this.idpEntityId;
200+
return this.relyingPartyRegistration.getProviderDetails().getEntityId();
132201
}
133202
}

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import static java.nio.charset.StandardCharsets.UTF_8;
3737
import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.RELYING_PARTY_REGISTRATION_NOT_FOUND;
38+
import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
3839
import static org.springframework.util.StringUtils.hasText;
3940

4041
/**
@@ -90,22 +91,23 @@ public Authentication attemptAuthentication(HttpServletRequest request, HttpServ
9091

9192
String responseXml = inflateIfRequired(request, b);
9293
String registrationId = this.matcher.matcher(request).getVariables().get("registrationId");
93-
RelyingPartyRegistration rp =
94+
RelyingPartyRegistration relyingPartyRegistration =
9495
this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
95-
if (rp == null) {
96+
if (relyingPartyRegistration == null) {
9697
Saml2Error saml2Error = new Saml2Error(RELYING_PARTY_REGISTRATION_NOT_FOUND,
9798
"Relying Party Registration not found with ID: " + registrationId);
9899
throw new Saml2AuthenticationException(saml2Error);
99100
}
100101
String applicationUri = Saml2ServletUtils.getApplicationUri(request);
101-
String localSpEntityId = Saml2ServletUtils.resolveUrlTemplate(rp.getEntityId(), applicationUri, rp);
102-
final Saml2AuthenticationToken authentication = new Saml2AuthenticationToken(
103-
responseXml,
104-
request.getRequestURL().toString(),
105-
rp.getProviderDetails().getEntityId(),
106-
localSpEntityId,
107-
rp.getCredentials()
108-
);
102+
String localSpEntityId = Saml2ServletUtils.resolveUrlTemplate(
103+
relyingPartyRegistration.getEntityId(), applicationUri, relyingPartyRegistration);
104+
RelyingPartyRegistration resolvedRelyingPartyRegistration =
105+
withRelyingPartyRegistration(relyingPartyRegistration)
106+
.assertionConsumerServiceLocation(request.getRequestURL().toString())
107+
.entityId(localSpEntityId)
108+
.build();
109+
Saml2AuthenticationToken authentication =
110+
new Saml2AuthenticationToken(resolvedRelyingPartyRegistration, responseXml);
109111
return getAuthenticationManager().authenticate(authentication);
110112
}
111113

0 commit comments

Comments
 (0)