diff --git a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle index 8b40d415d61..b05c1bbd577 100644 --- a/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle +++ b/saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle @@ -106,6 +106,7 @@ dependencies { provided 'jakarta.servlet:jakarta.servlet-api' optional 'com.fasterxml.jackson.core:jackson-databind' + optional 'org.springframework:spring-jdbc' testImplementation 'com.squareup.okhttp3:mockwebserver' testImplementation "org.assertj:assertj-core" @@ -118,6 +119,7 @@ dependencies { testImplementation "org.springframework:spring-test" testRuntimeOnly 'org.junit.platform:junit-platform-launcher' + testRuntimeOnly 'org.hsqldb:hsqldb' } jar { diff --git a/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepository.java b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepository.java new file mode 100644 index 00000000000..745d24072ae --- /dev/null +++ b/saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepository.java @@ -0,0 +1,221 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.registration; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Types; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.function.Function; + +import org.springframework.core.serializer.DefaultDeserializer; +import org.springframework.core.serializer.DefaultSerializer; +import org.springframework.core.serializer.Deserializer; +import org.springframework.core.serializer.Serializer; +import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.PreparedStatementSetter; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.SqlParameterValue; +import org.springframework.security.saml2.core.Saml2X509Credential; +import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.AssertingPartyDetails; +import org.springframework.util.Assert; +import org.springframework.util.function.ThrowingFunction; + +/** + * A JDBC implementation of {@link AssertingPartyMetadataRepository}. + * + * @author Cathy Wang + * @since 7.0 + */ +public final class JdbcAssertingPartyMetadataRepository implements AssertingPartyMetadataRepository { + + private final JdbcOperations jdbcOperations; + + private final RowMapper assertingPartyMetadataRowMapper = new AssertingPartyMetadataRowMapper(); + + private final AssertingPartyMetadataParametersMapper assertingPartyMetadataParametersMapper = new AssertingPartyMetadataParametersMapper(); + + // @formatter:off + static final String COLUMN_NAMES = "entity_id, " + + "singlesignon_url, " + + "singlesignon_binding, " + + "singlesignon_sign_request, " + + "signing_algorithms, " + + "verification_credentials, " + + "encryption_credentials, " + + "singlelogout_url, " + + "singlelogout_response_url, " + + "singlelogout_binding"; + // @formatter:on + + private static final String TABLE_NAME = "saml2_asserting_party_metadata"; + + private static final String ENTITY_ID_FILTER = "entity_id = ?"; + + // @formatter:off + private static final String LOAD_BY_ID_SQL = "SELECT " + COLUMN_NAMES + + " FROM " + TABLE_NAME + + " WHERE " + ENTITY_ID_FILTER; + + private static final String LOAD_ALL_SQL = "SELECT " + COLUMN_NAMES + + " FROM " + TABLE_NAME; + // @formatter:on + + // @formatter:off + private static final String SAVE_CREDENTIAL_RECORD_SQL = "INSERT INTO " + TABLE_NAME + + " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; + // @formatter:on + + // @formatter:off + private static final String UPDATE_CREDENTIAL_RECORD_SQL = "UPDATE " + TABLE_NAME + + " SET singlesignon_url = ?, " + + "singlesignon_binding = ?, " + + "singlesignon_sign_request = ?, " + + "signing_algorithms = ?, " + + "verification_credentials = ?, " + + "encryption_credentials = ?, " + + "singlelogout_url = ?, " + + "singlelogout_response_url = ?, " + + "singlelogout_binding = ?" + + " WHERE " + ENTITY_ID_FILTER; + // @formatter:on + + /** + * Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided + * parameters. + * @param jdbcOperations the JDBC operations + */ + public JdbcAssertingPartyMetadataRepository(JdbcOperations jdbcOperations) { + Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); + this.jdbcOperations = jdbcOperations; + } + + @Override + public AssertingPartyMetadata findByEntityId(String entityId) { + Assert.hasText(entityId, "entityId cannot be empty"); + SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, entityId) }; + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); + List result = this.jdbcOperations.query(LOAD_BY_ID_SQL, pss, + this.assertingPartyMetadataRowMapper); + return !result.isEmpty() ? result.get(0) : null; + } + + @Override + public Iterator iterator() { + List result = this.jdbcOperations.query(LOAD_ALL_SQL, + this.assertingPartyMetadataRowMapper); + return result.iterator(); + } + + /** + * Persist this {@link AssertingPartyMetadata} + * @param metadata the metadata to persist + */ + public void save(AssertingPartyMetadata metadata) { + Assert.notNull(metadata, "metadata cannot be null"); + int rows = updateCredentialRecord(metadata); + if (rows == 0) { + insertCredentialRecord(metadata); + } + } + + private void insertCredentialRecord(AssertingPartyMetadata metadata) { + List parameters = this.assertingPartyMetadataParametersMapper.apply(metadata); + this.jdbcOperations.update(SAVE_CREDENTIAL_RECORD_SQL, parameters.toArray()); + } + + private int updateCredentialRecord(AssertingPartyMetadata metadata) { + List parameters = this.assertingPartyMetadataParametersMapper.apply(metadata); + SqlParameterValue credentialId = parameters.remove(0); + parameters.add(credentialId); + return this.jdbcOperations.update(UPDATE_CREDENTIAL_RECORD_SQL, parameters.toArray()); + } + + /** + * The default {@link RowMapper} that maps the current row in + * {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}. + */ + private static final class AssertingPartyMetadataRowMapper implements RowMapper { + + private final Deserializer deserializer = new DefaultDeserializer(); + + @Override + public AssertingPartyMetadata mapRow(ResultSet rs, int rowNum) throws SQLException { + String entityId = rs.getString("entity_id"); + String singleSignOnUrl = rs.getString("singlesignon_url"); + Saml2MessageBinding singleSignOnBinding = Saml2MessageBinding.from(rs.getString("singlesignon_binding")); + boolean singleSignOnSignRequest = rs.getBoolean("singlesignon_sign_request"); + String singleLogoutUrl = rs.getString("singlelogout_url"); + String singleLogoutResponseUrl = rs.getString("singlelogout_response_url"); + Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding.from(rs.getString("singlelogout_binding")); + List algorithms = List.of(rs.getString("signing_algorithms").split(",")); + byte[] verificationCredentialsBytes = rs.getBytes("verification_credentials"); + byte[] encryptionCredentialsBytes = rs.getBytes("encryption_credentials"); + ThrowingFunction> credentials = ( + bytes) -> (Collection) this.deserializer.deserializeFromByteArray(bytes); + AssertingPartyMetadata.Builder builder = new AssertingPartyDetails.Builder(); + Collection verificationCredentials = credentials.apply(verificationCredentialsBytes); + Collection encryptionCredentials = (encryptionCredentialsBytes != null) + ? credentials.apply(encryptionCredentialsBytes) : List.of(); + + builder.entityId(entityId) + .wantAuthnRequestsSigned(singleSignOnSignRequest) + .singleSignOnServiceLocation(singleSignOnUrl) + .singleSignOnServiceBinding(singleSignOnBinding) + .singleLogoutServiceLocation(singleLogoutUrl) + .singleLogoutServiceBinding(singleLogoutBinding) + .singleLogoutServiceResponseLocation(singleLogoutResponseUrl) + .signingAlgorithms((a) -> a.addAll(algorithms)) + .verificationX509Credentials((c) -> c.addAll(verificationCredentials)) + .encryptionX509Credentials((c) -> c.addAll(encryptionCredentials)); + return builder.build(); + } + + } + + private static class AssertingPartyMetadataParametersMapper + implements Function> { + + private final Serializer serializer = new DefaultSerializer(); + + @Override + public List apply(AssertingPartyMetadata record) { + List parameters = new ArrayList<>(); + + parameters.add(new SqlParameterValue(Types.VARCHAR, record.getEntityId())); + parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceLocation())); + parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceBinding().getUrn())); + parameters.add(new SqlParameterValue(Types.BOOLEAN, record.getWantAuthnRequestsSigned())); + parameters.add(new SqlParameterValue(Types.BLOB, String.join(",", record.getSigningAlgorithms()))); + ThrowingFunction, byte[]> credentials = this.serializer::serializeToByteArray; + parameters + .add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getVerificationX509Credentials()))); + parameters.add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getEncryptionX509Credentials()))); + parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceLocation())); + parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceResponseLocation())); + parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceBinding().getUrn())); + + return parameters; + } + + } + +} diff --git a/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema-postgres.sql b/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema-postgres.sql new file mode 100644 index 00000000000..ffa047fe7b6 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema-postgres.sql @@ -0,0 +1,14 @@ +CREATE TABLE saml2_asserting_party_metadata +( + entity_id VARCHAR(1000) NOT NULL, + singlesignon_url VARCHAR(1000) NOT NULL, + singlesignon_binding VARCHAR(100), + singlesignon_sign_request boolean, + signing_algorithms BYTEA, + verification_credentials BYTEA NOT NULL, + encryption_credentials BYTEA, + singlelogout_url VARCHAR(1000), + singlelogout_response_url VARCHAR(1000), + singlelogout_binding VARCHAR(100), + PRIMARY KEY (entity_id) +); diff --git a/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql b/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql new file mode 100644 index 00000000000..e0897d85693 --- /dev/null +++ b/saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql @@ -0,0 +1,14 @@ +CREATE TABLE saml2_asserting_party_metadata +( + entity_id VARCHAR(1000) NOT NULL, + singlesignon_url VARCHAR(1000) NOT NULL, + singlesignon_binding VARCHAR(100), + singlesignon_sign_request boolean, + signing_algorithms VARCHAR(256) NOT NULL, + verification_credentials blob NOT NULL, + encryption_credentials blob, + singlelogout_url VARCHAR(1000), + singlelogout_response_url VARCHAR(1000), + singlelogout_binding VARCHAR(100), + PRIMARY KEY (entity_id) +); diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepositoryTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepositoryTests.java new file mode 100644 index 00000000000..56f871a664b --- /dev/null +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepositoryTests.java @@ -0,0 +1,148 @@ +/* + * Copyright 2002-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.saml2.provider.service.registration; + +import java.util.Iterator; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link JdbcAssertingPartyMetadataRepository} + */ +class JdbcAssertingPartyMetadataRepositoryTests { + + private static final String SCHEMA_SQL_RESOURCE = "org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql"; + + private EmbeddedDatabase db; + + private JdbcAssertingPartyMetadataRepository repository; + + private JdbcOperations jdbcOperations; + + private final AssertingPartyMetadata metadata = TestRelyingPartyRegistrations.full() + .build() + .getAssertingPartyMetadata(); + + @BeforeEach + void setUp() { + this.db = createDb(); + this.jdbcOperations = new JdbcTemplate(this.db); + this.repository = new JdbcAssertingPartyMetadataRepository(this.jdbcOperations); + } + + @AfterEach + void tearDown() { + this.db.shutdown(); + } + + @Test + void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> new JdbcAssertingPartyMetadataRepository(null)) + .withMessage("jdbcOperations cannot be null"); + // @formatter:on + } + + @Test + void findByEntityIdWhenEntityIdIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.repository.findByEntityId(null)) + .withMessage("entityId cannot be empty"); + // @formatter:on + } + + @Test + void findByEntityIdWhenEntityPresentThenReturns() { + this.repository.save(this.metadata); + + AssertingPartyMetadata found = this.repository.findByEntityId(this.metadata.getEntityId()); + + assertAssertingPartyEquals(found, this.metadata); + } + + @Test + void findByEntityIdWhenNotExistsThenNull() { + AssertingPartyMetadata found = this.repository.findByEntityId("non-existent-entity-id"); + assertThat(found).isNull(); + } + + @Test + void iteratorWhenEnitiesExistThenContains() { + AssertingPartyMetadata second = this.metadata.mutate().entityId("https://example.org/idp").build(); + this.repository.save(this.metadata); + this.repository.save(second); + + Iterator iterator = this.repository.iterator(); + + assertAssertingPartyEquals(iterator.next(), this.metadata); + assertAssertingPartyEquals(iterator.next(), second); + assertThat(iterator.hasNext()).isFalse(); + } + + @Test + void saveWhenExistingThenUpdates() { + this.repository.save(this.metadata); + boolean existing = this.metadata.getWantAuthnRequestsSigned(); + this.repository.save(this.metadata.mutate().wantAuthnRequestsSigned(!existing).build()); + boolean updated = this.repository.findByEntityId(this.metadata.getEntityId()).getWantAuthnRequestsSigned(); + assertThat(existing).isNotEqualTo(updated); + } + + private static EmbeddedDatabase createDb() { + return createDb(SCHEMA_SQL_RESOURCE); + } + + private static EmbeddedDatabase createDb(String schema) { + // @formatter:off + return new EmbeddedDatabaseBuilder() + .generateUniqueName(true) + .setType(EmbeddedDatabaseType.HSQL) + .setScriptEncoding("UTF-8") + .addScript(schema) + .build(); + // @formatter:on + } + + private void assertAssertingPartyEquals(AssertingPartyMetadata found, AssertingPartyMetadata expected) { + assertThat(found).isNotNull(); + assertThat(found.getEntityId()).isEqualTo(expected.getEntityId()); + assertThat(found.getSingleSignOnServiceLocation()).isEqualTo(expected.getSingleSignOnServiceLocation()); + assertThat(found.getSingleSignOnServiceBinding()).isEqualTo(expected.getSingleSignOnServiceBinding()); + assertThat(found.getWantAuthnRequestsSigned()).isEqualTo(expected.getWantAuthnRequestsSigned()); + assertThat(found.getSingleLogoutServiceLocation()).isEqualTo(expected.getSingleLogoutServiceLocation()); + assertThat(found.getSingleLogoutServiceResponseLocation()) + .isEqualTo(expected.getSingleLogoutServiceResponseLocation()); + assertThat(found.getSingleLogoutServiceBinding()).isEqualTo(expected.getSingleLogoutServiceBinding()); + assertThat(found.getSigningAlgorithms()).containsAll(expected.getSigningAlgorithms()); + assertThat(found.getVerificationX509Credentials()).containsAll(expected.getVerificationX509Credentials()); + assertThat(found.getEncryptionX509Credentials()).containsAll(expected.getEncryptionX509Credentials()); + } + +} diff --git a/saml2/saml2-service-provider/src/test/resources/rsa.crt b/saml2/saml2-service-provider/src/test/resources/rsa.crt new file mode 100644 index 00000000000..aa147065ded --- /dev/null +++ b/saml2/saml2-service-provider/src/test/resources/rsa.crt @@ -0,0 +1,23 @@ +-----BEGIN CERTIFICATE----- +MIID1zCCAr+gAwIBAgIUCzQeKBMTO0iHVW3iKmZC41haqCowDQYJKoZIhvcNAQEL +BQAwezELMAkGA1UEBhMCWFgxEjAQBgNVBAgMCVN0YXRlTmFtZTERMA8GA1UEBwwI +Q2l0eU5hbWUxFDASBgNVBAoMC0NvbXBhbnlOYW1lMRswGQYDVQQLDBJDb21wYW55 +U2VjdGlvbk5hbWUxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0yMzA5MjAwODI5MDNa +Fw0zMzA5MTcwODI5MDNaMHsxCzAJBgNVBAYTAlhYMRIwEAYDVQQIDAlTdGF0ZU5h +bWUxETAPBgNVBAcMCENpdHlOYW1lMRQwEgYDVQQKDAtDb21wYW55TmFtZTEbMBkG +A1UECwwSQ29tcGFueVNlY3Rpb25OYW1lMRIwEAYDVQQDDAlsb2NhbGhvc3QwggEi +MA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDUfi4aaCotJZX6OSDjv6fxCCfc +ihSs91Z/mmN+yc1fsxVSs53SIbqUuo+Wzhv34kp8I/r03P9LWVTkFPbeDxAl75Oa +PGggxK55US0Zfy9Hj1BwWIKV3330N61emID1GDEtFKL4yJbJdreQXnIXTBL2o76V +nuV/tYozyZnb07IQ1WhUm5WDxgzM0yFudMynTczCBeZHfvharDtB8PFFhCZXW2/9 +TZVVfW4oOML8EAX3hvnvYBlFl/foxXekZSwq/odOkmWCZavT2+0sburHUlOnPGUh +Qj4tHwpMRczp7VX4ptV1D2UrxsK/2B+s9FK2QSLKQ9JzAYJ6WxQjHcvET9jvAgMB +AAGjUzBRMB0GA1UdDgQWBBQjDr/1E/01pfLPD8uWF7gbaYL0TTAfBgNVHSMEGDAW +gBQjDr/1E/01pfLPD8uWF7gbaYL0TTAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3 +DQEBCwUAA4IBAQAGjUuec0+0XNMCRDKZslbImdCAVsKsEWk6NpnUViDFAxL+KQuC +NW131UeHb9SCzMqRwrY4QI3nAwJQCmilL/hFM3ss4acn3WHu1yci/iKPUKeL1ec5 +kCFUmqX1NpTiVaytZ/9TKEr69SMVqNfQiuW5U1bIIYTqK8xo46WpM6YNNHO3eJK6 +NH0MW79Wx5ryi4i4C6afqYbVbx7tqcmy8CFeNxgZ0bFQ87SiwYXIj77b6sVYbu32 +doykBQgSHLcagWASPQ73m73CWUgo+7+EqSKIQqORbgmTLPmOUh99gFIx7jmjTyHm +NBszx1ZVWuIv3mWmp626Kncyc+LLM9tvgymx +-----END CERTIFICATE-----