16
16
17
17
package org .springframework .security .saml2 .provider .service .registration ;
18
18
19
- import java .io .IOException ;
20
- import java .sql .PreparedStatement ;
21
19
import java .sql .ResultSet ;
22
20
import java .sql .SQLException ;
23
21
import java .sql .Types ;
24
- import java .util .ArrayList ;
25
22
import java .util .Collection ;
26
23
import java .util .Iterator ;
27
24
import java .util .List ;
28
- import java .util .function .Function ;
25
+ import java .util .function .Consumer ;
29
26
30
27
import org .apache .commons .logging .Log ;
31
28
import org .apache .commons .logging .LogFactory ;
32
- import org .slf4j .Logger ;
33
- import org .slf4j .LoggerFactory ;
34
29
import org .springframework .core .log .LogMessage ;
35
30
import org .springframework .core .serializer .DefaultDeserializer ;
36
- import org .springframework .core .serializer .DefaultSerializer ;
37
31
import org .springframework .core .serializer .Deserializer ;
38
- import org .springframework .core .serializer .Serializer ;
39
32
import org .springframework .jdbc .core .ArgumentPreparedStatementSetter ;
40
33
import org .springframework .jdbc .core .JdbcOperations ;
41
34
import org .springframework .jdbc .core .PreparedStatementSetter ;
44
37
import org .springframework .security .saml2 .core .Saml2X509Credential ;
45
38
import org .springframework .security .saml2 .provider .service .registration .RelyingPartyRegistration .AssertingPartyDetails ;
46
39
import org .springframework .util .Assert ;
40
+ import org .springframework .util .StringUtils ;
47
41
48
42
/**
49
43
* A JDBC implementation of {@link AssertingPartyMetadataRepository}.
@@ -58,13 +52,9 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
58
52
private RowMapper <AssertingPartyMetadata > assertingPartyMetadataRowMapper =
59
53
new AssertingPartyMetadataRowMapper (ResultSet ::getBytes );
60
54
61
- private Function <AssertingPartyMetadata , List <SqlParameterValue >> assertingPartyMetadataParametersMapper =
62
- new AssertingPartyMetadataParametersMapper ();
63
-
64
- private final SetBytes setBytes = PreparedStatement ::setBytes ;
65
-
66
55
// @formatter:off
67
56
static final String COLUMN_NAMES = "entity_id, "
57
+ + "metadata_uri, "
68
58
+ "singlesignon_url, "
69
59
+ "singlesignon_binding, "
70
60
+ "singlesignon_sign_request, "
@@ -87,26 +77,6 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
87
77
88
78
private static final String LOAD_ALL_SQL = "SELECT " + COLUMN_NAMES
89
79
+ " FROM " + TABLE_NAME ;
90
-
91
- private static final String SAVE_SQL = "INSERT INTO " + TABLE_NAME + " ("
92
- + COLUMN_NAMES
93
- + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)" ;
94
- // @formatter:on
95
-
96
- private static final String DELETE_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + ENTITY_ID_FILTER ;
97
-
98
- // @formatter:off
99
- private static final String UPDATE_SQL = "UPDATE " + TABLE_NAME
100
- + " SET singlesignon_url = ?, " +
101
- "singlesignon_binding = ?, " +
102
- "singlesignon_sign_request = ?, " +
103
- "signing_algorithms = ?, " +
104
- "verification_credentials = ?, " +
105
- "encryption_credentials = ?, " +
106
- "singlelogout_url = ? ," +
107
- "singlelogout_response_url = ?, " +
108
- "singlelogout_binding = ?"
109
- + " WHERE " + ENTITY_ID_FILTER ;
110
80
// @formatter:on
111
81
112
82
/**
@@ -134,41 +104,6 @@ public void setAssertingPartyMetadataRowMapper(
134
104
this .assertingPartyMetadataRowMapper = assertingPartyMetadataRowMapper ;
135
105
}
136
106
137
- public void setAssertingPartyMetadataParametersMapper (Function <AssertingPartyMetadata , List <SqlParameterValue >> assertingPartyMetadataParametersMapper ) {
138
- Assert .notNull (assertingPartyMetadataParametersMapper , "assertingPartyMetadataParametersMapper cannot be null" );
139
- this .assertingPartyMetadataParametersMapper = assertingPartyMetadataParametersMapper ;
140
- }
141
-
142
- public void save (AssertingPartyMetadata metadata ) {
143
- Assert .notNull (metadata , "metadata cannot be null" );
144
- int rows = update (metadata );
145
- if (rows == 0 ) {
146
- insert (metadata );
147
- }
148
- }
149
-
150
- private void insert (AssertingPartyMetadata metadata ) {
151
- List <SqlParameterValue > parameters = this .assertingPartyMetadataParametersMapper .apply (metadata );
152
- PreparedStatementSetter pss = new BlobArgumentPreparedStatementSetter (this .setBytes , parameters .toArray ());
153
- this .jdbcOperations .update (SAVE_SQL , pss );
154
- }
155
-
156
- private int update (AssertingPartyMetadata metadata ) {
157
- List <SqlParameterValue > parameters = this .assertingPartyMetadataParametersMapper .apply (metadata );
158
- SqlParameterValue credentialId = parameters .remove (0 );
159
- parameters .add (credentialId );
160
- PreparedStatementSetter pss = new BlobArgumentPreparedStatementSetter (this .setBytes , parameters .toArray ());
161
- return this .jdbcOperations .update (UPDATE_SQL , pss );
162
- }
163
-
164
- public void delete (String entityId ) {
165
- Assert .notNull (entityId , "entityId cannot be null" );
166
- SqlParameterValue [] parameters = new SqlParameterValue []{
167
- new SqlParameterValue (Types .VARCHAR , entityId ),};
168
- PreparedStatementSetter pss = new ArgumentPreparedStatementSetter (parameters );
169
- this .jdbcOperations .update (DELETE_SQL , pss );
170
- }
171
-
172
107
@ Override
173
108
public AssertingPartyMetadata findByEntityId (String entityId ) {
174
109
Assert .hasText (entityId , "entityId cannot be empty" );
@@ -187,75 +122,6 @@ public Iterator<AssertingPartyMetadata> iterator() {
187
122
return result .iterator ();
188
123
}
189
124
190
- private static class AssertingPartyMetadataParametersMapper
191
- implements Function <AssertingPartyMetadata , List <SqlParameterValue >> {
192
-
193
- private final Logger logger = LoggerFactory .getLogger (AssertingPartyMetadataParametersMapper .class );
194
-
195
- private final Serializer <Object > serializer = new DefaultSerializer ();
196
-
197
- @ Override
198
- public List <SqlParameterValue > apply (AssertingPartyMetadata record ) {
199
- List <SqlParameterValue > parameters = new ArrayList <>();
200
- parameters .add (new SqlParameterValue (Types .VARCHAR , record .getEntityId ()));
201
- parameters .add (new SqlParameterValue (Types .VARCHAR , record .getSingleSignOnServiceLocation ()));
202
- parameters .add (new SqlParameterValue (Types .VARCHAR , record .getSingleSignOnServiceBinding ().getUrn ()));
203
- parameters .add (new SqlParameterValue (Types .BOOLEAN , record .getWantAuthnRequestsSigned ()));
204
- try {
205
- parameters .add (new SqlParameterValue (Types .BLOB ,
206
- this .serializer .serializeToByteArray (record .getSigningAlgorithms ())));
207
- } catch (IOException ex ) {
208
- this .logger .debug ("Failed to serialize signing algorithms" , ex );
209
- throw new IllegalArgumentException (ex );
210
- }
211
- try {
212
- parameters .add (new SqlParameterValue (Types .BLOB ,
213
- this .serializer .serializeToByteArray (record .getVerificationX509Credentials ())));
214
- } catch (IOException ex ) {
215
- this .logger .debug ("Failed to serialize verification credentials" , ex );
216
- throw new IllegalArgumentException (ex );
217
- }
218
- try {
219
- parameters .add (new SqlParameterValue (Types .BLOB ,
220
- this .serializer .serializeToByteArray (record .getEncryptionX509Credentials ())));
221
- } catch (IOException ex ) {
222
- this .logger .debug ("Failed to serialize encryption credentials" , ex );
223
- throw new IllegalArgumentException (ex );
224
- }
225
- parameters .add (new SqlParameterValue (Types .VARCHAR , record .getSingleLogoutServiceLocation ()));
226
- parameters .add (new SqlParameterValue (Types .VARCHAR , record .getSingleLogoutServiceResponseLocation ()));
227
- parameters .add (new SqlParameterValue (Types .VARCHAR , record .getSingleLogoutServiceBinding ().getUrn ()));
228
- return parameters ;
229
- }
230
- }
231
-
232
- private static final class BlobArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
233
-
234
- private final SetBytes setBytes ;
235
-
236
- private BlobArgumentPreparedStatementSetter (SetBytes setBytes , Object [] args ) {
237
- super (args );
238
- this .setBytes = setBytes ;
239
- }
240
-
241
- @ Override
242
- protected void doSetValue (PreparedStatement ps , int parameterPosition , Object argValue ) throws SQLException {
243
- if (argValue instanceof SqlParameterValue paramValue ) {
244
- if (paramValue .getSqlType () == Types .BLOB ) {
245
- if (paramValue .getValue () != null ) {
246
- Assert .isInstanceOf (byte [].class , paramValue .getValue (),
247
- "Value of blob parameter must be byte[]" );
248
- }
249
- byte [] valueBytes = (byte []) paramValue .getValue ();
250
- this .setBytes .setBytes (ps , parameterPosition , valueBytes );
251
- return ;
252
- }
253
- }
254
- super .doSetValue (ps , parameterPosition , argValue );
255
- }
256
-
257
- }
258
-
259
125
/**
260
126
* The default {@link RowMapper} that maps the current row in
261
127
* {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}.
@@ -275,61 +141,68 @@ private final static class AssertingPartyMetadataRowMapper implements RowMapper<
275
141
@ Override
276
142
public AssertingPartyMetadata mapRow (ResultSet rs , int rowNum ) throws SQLException {
277
143
String entityId = rs .getString ("entity_id" );
144
+ String metadataUri = rs .getString ("metadata_uri" );
278
145
String singleSignOnUrl = rs .getString ("singlesignon_url" );
279
- Saml2MessageBinding singleSignOnBinding = Saml2MessageBinding
280
- .from (rs .getString ("singlesignon_binding" ));
146
+ Saml2MessageBinding singleSignOnBinding = Saml2MessageBinding .from (rs .getString ("singlesignon_binding" ));
281
147
boolean singleSignOnSignRequest = rs .getBoolean ("singlesignon_sign_request" );
282
- List <String > signingAlgorithms ;
283
- try {
284
- signingAlgorithms = (List <String >) deserializer .deserializeFromByteArray (
285
- this .getBytes .getBytes (rs , "signing_algorithms" ));
286
- } catch (IOException ex ) {
287
- this .logger .debug (
288
- LogMessage .format ("Verification credentials of %s could not be parsed." , entityId ), ex );
289
- return null ;
290
- }
291
- Collection <Saml2X509Credential > verificationCredentials ;
292
- try {
293
- verificationCredentials = (Collection <Saml2X509Credential >) deserializer .deserializeFromByteArray (
294
- this .getBytes .getBytes (rs , "verification_credentials" ));
295
- } catch (IOException ex ) {
296
- this .logger .debug (
297
- LogMessage .format ("Verification credentials of %s could not be parsed." , entityId ), ex );
298
- return null ;
299
- }
300
- Collection <Saml2X509Credential > encryptionCredentials ;
148
+ String singleLogoutUrl = rs .getString ("singlelogout_url" );
149
+ String singleLogoutResponseUrl = rs .getString ("singlelogout_response_url" );
150
+ Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding .from (rs .getString ("singlelogout_binding" ));
151
+ byte [] signingAlgorithmsBytes = this .getBytes .getBytes (rs , "signing_algorithms" );
152
+ byte [] verificationCredentialsBytes = this .getBytes .getBytes (rs , "verification_credentials" );
153
+ byte [] encryptionCredentialsBytes = this .getBytes .getBytes (rs , "encryption_credentials" );
154
+
155
+ boolean usingMetadata = StringUtils .hasText (metadataUri );
156
+ AssertingPartyMetadata .Builder <?> builder = (!usingMetadata ) ? new AssertingPartyDetails .Builder ().entityId (entityId )
157
+ : createBuilderUsingMetadata (entityId , metadataUri );
301
158
try {
302
- encryptionCredentials = (Collection <Saml2X509Credential >) deserializer .deserializeFromByteArray (
303
- this .getBytes .getBytes (rs , "encryption_credentials" ));
304
- } catch (IOException ex ) {
159
+ if (signingAlgorithmsBytes != null ) {
160
+ List <String > signingAlgorithms = (List <String >) deserializer .deserializeFromByteArray (signingAlgorithmsBytes );
161
+ builder .signingAlgorithms (algorithms -> algorithms .addAll (signingAlgorithms ));
162
+ }
163
+ if (verificationCredentialsBytes != null ) {
164
+ Collection <Saml2X509Credential > verificationCredentials = (Collection <Saml2X509Credential >) deserializer .deserializeFromByteArray (verificationCredentialsBytes );
165
+ builder .verificationX509Credentials (credentials -> credentials .addAll (verificationCredentials ));
166
+ }
167
+ if (encryptionCredentialsBytes != null ) {
168
+ Collection <Saml2X509Credential > encryptionCredentials = (Collection <Saml2X509Credential >) deserializer .deserializeFromByteArray (encryptionCredentialsBytes );
169
+ builder .encryptionX509Credentials (credentials -> credentials .addAll (encryptionCredentials ));
170
+ }
171
+ } catch (Exception ex ) {
305
172
this .logger .debug (
306
- LogMessage .format ("Encryption credentials of %s could not be parsed. " , entityId ), ex );
173
+ LogMessage .format ("Parsing serialized credentials for entity %s failed " , entityId ), ex );
307
174
return null ;
308
175
}
309
- String singleLogoutUrl = rs .getString ("singlelogout_url" );
310
- String singleLogoutResponseUrl = rs .getString ("singlelogout_response_url" );
311
- Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding
312
- .from (rs .getString ("singlelogout_binding" ));
313
176
314
- return new AssertingPartyDetails .Builder ()
315
- .entityId (entityId )
316
- .wantAuthnRequestsSigned (singleSignOnSignRequest )
317
- .signingAlgorithms (algorithms -> algorithms .addAll (signingAlgorithms ))
318
- .verificationX509Credentials (credentials -> credentials .addAll (verificationCredentials ))
319
- .encryptionX509Credentials (credentials -> credentials .addAll (encryptionCredentials ))
320
- .singleSignOnServiceLocation (singleSignOnUrl )
321
- .singleSignOnServiceBinding (singleSignOnBinding )
322
- .singleLogoutServiceLocation (singleLogoutUrl )
323
- .singleLogoutServiceBinding (singleLogoutBinding )
324
- .singleLogoutServiceResponseLocation (singleLogoutResponseUrl )
325
- .build ();
177
+ applyingWhenNonNull (singleSignOnUrl , builder ::singleSignOnServiceLocation );
178
+ applyingWhenNonNull (singleSignOnBinding , builder ::singleSignOnServiceBinding );
179
+ applyingWhenNonNull (singleSignOnSignRequest , builder ::wantAuthnRequestsSigned );
180
+ applyingWhenNonNull (singleLogoutUrl , builder ::singleLogoutServiceLocation );
181
+ applyingWhenNonNull (singleLogoutResponseUrl , builder ::singleLogoutServiceResponseLocation );
182
+ applyingWhenNonNull (singleLogoutBinding , builder ::singleLogoutServiceBinding );
183
+ return builder .build ();
326
184
}
327
- }
328
185
329
- private interface SetBytes {
186
+ private <T > void applyingWhenNonNull (T value , Consumer <T > consumer ) {
187
+ if (value != null ) {
188
+ consumer .accept (value );
189
+ }
190
+ }
330
191
331
- void setBytes (PreparedStatement ps , int index , byte [] bytes ) throws SQLException ;
192
+ private AssertingPartyMetadata .Builder <?> createBuilderUsingMetadata (String entityId , String metadataUri ) {
193
+ Collection <AssertingPartyMetadata .Builder <?>> candidates = AssertingPartyMetadata
194
+ .collectionFromMetadataLocation (metadataUri );
195
+ for (AssertingPartyMetadata .Builder <?> candidate : candidates ) {
196
+ if (entityId == null || entityId .equals (getEntityId (candidate ))) {
197
+ return candidate ;
198
+ }
199
+ }
200
+ throw new IllegalStateException ("No asserting party metadata with Entity ID '" + entityId + "' found" );
201
+ }
332
202
203
+ private Object getEntityId (AssertingPartyMetadata .Builder <?> candidate ) {
204
+ return candidate .build ().getEntityId ();
205
+ }
333
206
}
334
207
335
208
private interface GetBytes {
0 commit comments