15
15
*/
16
16
package org .springframework .security .saml2 .provider .service .authentication ;
17
17
18
- import java .security .cert .X509Certificate ;
19
18
import java .time .Duration ;
20
19
import java .time .Instant ;
21
20
import java .util .ArrayList ;
24
23
import java .util .HashMap ;
25
24
import java .util .HashSet ;
26
25
import java .util .LinkedHashMap ;
27
- import java .util .LinkedList ;
28
26
import java .util .List ;
29
27
import java .util .Map ;
30
28
import java .util .Set ;
38
36
import org .opensaml .core .xml .XMLObject ;
39
37
import org .opensaml .core .xml .config .XMLObjectProviderRegistrySupport ;
40
38
import org .opensaml .core .xml .io .Marshaller ;
41
-
42
39
import org .opensaml .core .xml .schema .XSAny ;
43
40
import org .opensaml .core .xml .schema .XSBoolean ;
44
41
import org .opensaml .core .xml .schema .XSBooleanValue ;
95
92
import org .springframework .security .core .authority .mapping .GrantedAuthoritiesMapper ;
96
93
import org .springframework .security .saml2 .Saml2Exception ;
97
94
import org .springframework .security .saml2 .credentials .Saml2X509Credential ;
95
+ import org .springframework .security .saml2 .provider .service .registration .RelyingPartyRegistration ;
98
96
import org .springframework .util .Assert ;
99
97
import org .springframework .util .StringUtils ;
100
98
@@ -218,9 +216,9 @@ public Authentication authenticate(Authentication authentication) throws Authent
218
216
try {
219
217
Saml2AuthenticationToken token = (Saml2AuthenticationToken ) authentication ;
220
218
Response response = parse (token .getSaml2Response ());
221
- List <Assertion > validAssertions = validateResponse (token , response );
219
+ List <Assertion > validAssertions = validateResponse (token . getRelyingPartyRegistration () , response );
222
220
Assertion assertion = validAssertions .get (0 );
223
- String username = getUsername (token , assertion );
221
+ String username = getUsername (token . getRelyingPartyRegistration () , assertion );
224
222
Map <String , List <Object >> attributes = getAssertionAttributes (assertion );
225
223
return new Saml2Authentication (
226
224
new SimpleSaml2AuthenticatedPrincipal (username , attributes ), token .getSaml2Response (),
@@ -259,7 +257,7 @@ private Response parse(String response) throws Saml2Exception, Saml2Authenticati
259
257
260
258
}
261
259
262
- private List <Assertion > validateResponse (Saml2AuthenticationToken token , Response response )
260
+ private List <Assertion > validateResponse (RelyingPartyRegistration relyingParty , Response response )
263
261
throws Saml2AuthenticationException {
264
262
265
263
List <Assertion > validAssertions = new ArrayList <>();
@@ -270,7 +268,7 @@ private List<Assertion> validateResponse(Saml2AuthenticationToken token, Respons
270
268
271
269
List <Assertion > assertions = new ArrayList <>(response .getAssertions ());
272
270
for (EncryptedAssertion encryptedAssertion : response .getEncryptedAssertions ()) {
273
- Assertion assertion = decrypt (token , encryptedAssertion );
271
+ Assertion assertion = decrypt (relyingParty , encryptedAssertion );
274
272
assertions .add (assertion );
275
273
}
276
274
if (assertions .isEmpty ()) {
@@ -282,7 +280,7 @@ private List<Assertion> validateResponse(Saml2AuthenticationToken token, Respons
282
280
"Please either sign the response or all of the assertions." );
283
281
}
284
282
285
- SignatureTrustEngine signatureTrustEngine = buildSignatureTrustEngine (token );
283
+ SignatureTrustEngine signatureTrustEngine = buildSignatureTrustEngine (relyingParty );
286
284
287
285
Map <String , Saml2AuthenticationException > validationExceptions = new HashMap <>();
288
286
if (response .isSigned ()) {
@@ -310,18 +308,18 @@ private List<Assertion> validateResponse(Saml2AuthenticationToken token, Respons
310
308
}
311
309
312
310
String destination = response .getDestination ();
313
- if (StringUtils .hasText (destination ) && !destination .equals (token . getRecipientUri ())) {
311
+ if (StringUtils .hasText (destination ) && !destination .equals (relyingParty . getAssertionConsumerServiceLocation ())) {
314
312
String message = "Invalid destination [" + destination + "] for SAML response [" + response .getID () + "]" ;
315
313
validationExceptions .put (INVALID_DESTINATION , authException (INVALID_DESTINATION , message ));
316
314
}
317
315
318
- if (!StringUtils .hasText (issuer ) || !issuer .equals (token . getIdpEntityId ())) {
316
+ if (!StringUtils .hasText (issuer ) || !issuer .equals (relyingParty . getProviderDetails (). getEntityId ())) {
319
317
String message = String .format ("Invalid issuer [%s] for SAML response [%s]" , issuer , response .getID ());
320
318
validationExceptions .put (INVALID_ISSUER , authException (INVALID_ISSUER , message ));
321
319
}
322
320
323
321
SAML20AssertionValidator validator = buildSamlAssertionValidator (signatureTrustEngine );
324
- ValidationContext context = buildValidationContext (token , response );
322
+ ValidationContext context = buildValidationContext (relyingParty , response );
325
323
326
324
if (logger .isDebugEnabled ()) {
327
325
logger .debug ("Validating " + assertions .size () + " assertions" );
@@ -375,12 +373,12 @@ private boolean isSigned(Response samlResponse, List<Assertion> assertions) {
375
373
return true ;
376
374
}
377
375
378
- private SignatureTrustEngine buildSignatureTrustEngine (Saml2AuthenticationToken token ) {
376
+ private SignatureTrustEngine buildSignatureTrustEngine (RelyingPartyRegistration relyingParty ) {
379
377
Set <Credential > credentials = new HashSet <>();
380
- for (X509Certificate key : getVerificationCertificates ( token )) {
381
- BasicX509Credential cred = new BasicX509Credential (key );
378
+ for (Saml2X509Credential credential : relyingParty . getProviderDetails (). getVerificationX509Credentials ( )) {
379
+ BasicX509Credential cred = new BasicX509Credential (credential . getCertificate () );
382
380
cred .setUsageType (UsageType .SIGNING );
383
- cred .setEntityId (token . getIdpEntityId ());
381
+ cred .setEntityId (relyingParty . getProviderDetails (). getEntityId ());
384
382
credentials .add (cred );
385
383
}
386
384
CredentialResolver credentialsResolver = new CollectionCredentialResolver (credentials );
@@ -390,13 +388,15 @@ private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken
390
388
);
391
389
}
392
390
393
- private ValidationContext buildValidationContext (Saml2AuthenticationToken token , Response response ) {
391
+ private ValidationContext buildValidationContext (RelyingPartyRegistration relyingParty , Response response ) {
394
392
Map <String , Object > validationParams = new HashMap <>();
395
393
validationParams .put (SIGNATURE_REQUIRED , !response .isSigned ());
396
394
validationParams .put (CLOCK_SKEW , this .responseTimeValidationSkew .toMillis ());
397
- validationParams .put (COND_VALID_AUDIENCES , singleton (token .getLocalSpEntityId ()));
398
- if (StringUtils .hasText (token .getRecipientUri ())) {
399
- validationParams .put (SAML2AssertionValidationParameters .SC_VALID_RECIPIENTS , singleton (token .getRecipientUri ()));
395
+ validationParams .put (COND_VALID_AUDIENCES , singleton (relyingParty .getEntityId ()));
396
+ String assertionConsumerServiceLocation = relyingParty .getAssertionConsumerServiceLocation ();
397
+ if (StringUtils .hasText (assertionConsumerServiceLocation )) {
398
+ validationParams .put (SAML2AssertionValidationParameters .SC_VALID_RECIPIENTS ,
399
+ singleton (assertionConsumerServiceLocation ));
400
400
}
401
401
return new ValidationContext (validationParams );
402
402
}
@@ -422,11 +422,11 @@ private Assertion validateAssertion(Assertion assertion,
422
422
return assertion ;
423
423
}
424
424
425
- private Assertion decrypt (Saml2AuthenticationToken token , EncryptedAssertion assertion )
425
+ private Assertion decrypt (RelyingPartyRegistration relyingParty , EncryptedAssertion assertion )
426
426
throws Saml2AuthenticationException {
427
427
428
428
Saml2AuthenticationException last = null ;
429
- List <Saml2X509Credential > decryptionCredentials = getDecryptionCredentials ( token );
429
+ Collection <Saml2X509Credential > decryptionCredentials = relyingParty . getDecryptionX509Credentials ( );
430
430
if (decryptionCredentials .isEmpty ()) {
431
431
throw authException (DECRYPTION_ERROR , "No valid decryption credentials found." );
432
432
}
@@ -450,27 +450,7 @@ private Decrypter getDecrypter(Saml2X509Credential key) {
450
450
return decrypter ;
451
451
}
452
452
453
- private List <Saml2X509Credential > getDecryptionCredentials (Saml2AuthenticationToken token ) {
454
- List <Saml2X509Credential > result = new LinkedList <>();
455
- for (Saml2X509Credential c : token .getX509Credentials ()) {
456
- if (c .isDecryptionCredential ()) {
457
- result .add (c );
458
- }
459
- }
460
- return result ;
461
- }
462
-
463
- private List <X509Certificate > getVerificationCertificates (Saml2AuthenticationToken token ) {
464
- List <X509Certificate > result = new LinkedList <>();
465
- for (Saml2X509Credential c : token .getX509Credentials ()) {
466
- if (c .isSignatureVerficationCredential ()) {
467
- result .add (c .getCertificate ());
468
- }
469
- }
470
- return result ;
471
- }
472
-
473
- private String getUsername (Saml2AuthenticationToken token , Assertion assertion )
453
+ private String getUsername (RelyingPartyRegistration relyingParty , Assertion assertion )
474
454
throws Saml2AuthenticationException {
475
455
476
456
String username = null ;
@@ -482,7 +462,7 @@ private String getUsername(Saml2AuthenticationToken token, Assertion assertion)
482
462
username = subject .getNameID ().getValue ();
483
463
}
484
464
else if (subject .getEncryptedID () != null ) {
485
- NameID nameId = decrypt (token , subject .getEncryptedID ());
465
+ NameID nameId = decrypt (relyingParty , subject .getEncryptedID ());
486
466
username = nameId .getValue ();
487
467
}
488
468
if (username == null ) {
@@ -491,11 +471,11 @@ else if (subject.getEncryptedID() != null) {
491
471
return username ;
492
472
}
493
473
494
- private NameID decrypt (Saml2AuthenticationToken token , EncryptedID assertion )
474
+ private NameID decrypt (RelyingPartyRegistration relyingParty , EncryptedID assertion )
495
475
throws Saml2AuthenticationException {
496
476
497
477
Saml2AuthenticationException last = null ;
498
- List <Saml2X509Credential > decryptionCredentials = getDecryptionCredentials ( token );
478
+ Collection <Saml2X509Credential > decryptionCredentials = relyingParty . getDecryptionX509Credentials ( );
499
479
if (decryptionCredentials .isEmpty ()) {
500
480
throw authException (DECRYPTION_ERROR , "No valid decryption credentials found." );
501
481
}
0 commit comments