diff --git a/gradle/dependency-management.gradle b/gradle/dependency-management.gradle index 003c3b7c5..560f53007 100644 --- a/gradle/dependency-management.gradle +++ b/gradle/dependency-management.gradle @@ -32,5 +32,6 @@ dependencyManagement { dependency "com.squareup.okhttp3:mockwebserver:3.14.9" dependency "com.squareup.okhttp3:okhttp:3.14.9" dependency "com.jayway.jsonpath:json-path:2.4.0" + dependency "org.hsqldb:hsqldb:2.5.+" } } diff --git a/oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle b/oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle index 6f053db2e..6d678e29b 100644 --- a/oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle +++ b/oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle @@ -10,6 +10,8 @@ dependencies { compile 'com.nimbusds:nimbus-jose-jwt' compile 'com.fasterxml.jackson.core:jackson-databind' + optional 'org.springframework:spring-jdbc' + testCompile 'org.springframework.security:spring-security-test' testCompile 'org.springframework:spring-webmvc' testCompile 'junit:junit' @@ -17,6 +19,8 @@ dependencies { testCompile 'org.mockito:mockito-core' testCompile 'com.jayway.jsonpath:json-path' + testRuntime 'org.hsqldb:hsqldb' + provided 'javax.servlet:javax.servlet-api' } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationService.java new file mode 100644 index 000000000..ced7b8875 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationService.java @@ -0,0 +1,554 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.dao.DataRetrievalFailureException; +import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.PreparedStatementSetter; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.SqlParameterValue; +import org.springframework.jdbc.support.lob.DefaultLobHandler; +import org.springframework.jdbc.support.lob.LobCreator; +import org.springframework.jdbc.support.lob.LobHandler; +import org.springframework.lang.Nullable; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken2; +import org.springframework.security.oauth2.core.OAuth2TokenType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +import java.nio.charset.StandardCharsets; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; +import java.sql.Types; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +/** + * A JDBC implementation of an {@link OAuth2AuthorizationService} that uses a + *

+ * {@link JdbcOperations} for {@link OAuth2Authorization} persistence. + * + *

+ * NOTE: This {@code OAuth2AuthorizationService} depends on the table definition + * described in + * "classpath:org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql" and + * therefore MUST be defined in the database schema. + * + * @author Ovidiu Popa + * @see OAuth2AuthorizationService + * @see OAuth2Authorization + * @see JdbcOperations + * @see RowMapper + * @since 0.1.2 + */ +public final class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationService { + + // @formatter:off + private static final String COLUMN_NAMES = "id, " + + "registered_client_id, " + + "principal_name, " + + "authorization_grant_type, " + + "attributes, " + + "state, " + + "authorization_code_value, " + + "authorization_code_issued_at, " + + "authorization_code_expires_at," + + "authorization_code_metadata," + + "access_token_value," + + "access_token_issued_at," + + "access_token_expires_at," + + "access_token_metadata," + + "access_token_type," + + "access_token_scopes," + + "oidc_id_token_value," + + "oidc_id_token_issued_at," + + "oidc_id_token_expires_at," + + "oidc_id_token_metadata," + + "refresh_token_value," + + "refresh_token_issued_at," + + "refresh_token_expires_at," + + "refresh_token_metadata"; + // @formatter:on + + private static final String TABLE_NAME = "oauth2_authorization"; + + private static final String PK_FILTER = "id = ?"; + private static final String UNKNOWN_TOKEN_TYPE_FILTER = "state = ? OR authorization_code_value = ? OR " + + "access_token_value = ? OR " + + "refresh_token_value = ?"; + + private static final String STATE_FILTER = "state = ?"; + private static final String AUTHORIZATION_CODE_FILTER = "authorization_code_value = ?"; + private static final String ACCESS_TOKEN_FILTER = "access_token_value = ?"; + private static final String REFRESH_TOKEN_FILTER = "refresh_token_value = ?"; + + // @formatter:off + private static final String LOAD_AUTHORIZATION_SQL = "SELECT " + COLUMN_NAMES + + " FROM " + TABLE_NAME + + " WHERE "; + // @formatter:on + + // @formatter:off + private static final String SAVE_AUTHORIZATION_SQL = "INSERT INTO " + TABLE_NAME + + " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?,?, ?, ?, ?, ?, ?, ?, ?,?, ?, ?, ?, ?, ?, ?, ?)"; + // @formatter:on + + // @formatter:off + private static final String UPDATE_AUTHORIZATION_SQL = "UPDATE " + TABLE_NAME + + " SET registered_client_id = ?, principal_name = ?, authorization_grant_type = ?, attributes = ?, state = ?," + + " authorization_code_value = ?, authorization_code_issued_at = ?, authorization_code_expires_at = ?, authorization_code_metadata = ?," + + " access_token_value = ?, access_token_issued_at = ?, access_token_expires_at = ?, access_token_metadata = ?, access_token_type = ?, access_token_scopes = ?," + + " oidc_id_token_value = ?, oidc_id_token_issued_at = ?, oidc_id_token_expires_at = ?, oidc_id_token_metadata = ?," + + " refresh_token_value = ?, refresh_token_issued_at = ?, refresh_token_expires_at = ?, refresh_token_metadata = ?" + + " WHERE " + PK_FILTER; + // @formatter:on + + private static final String REMOVE_AUTHORIZATION_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER; + + private final JdbcOperations jdbcOperations; + private final LobHandler lobHandler; + private RowMapper authorizationRowMapper; + private Function> authorizationParametersMapper; + + /** + * Constructs a {@code JdbcOAuth2AuthorizationService} using the provided parameters. + * + * @param jdbcOperations the JDBC operations + * @param registeredClientRepository the registered client repository + */ + public JdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations, + RegisteredClientRepository registeredClientRepository) { + this(jdbcOperations, registeredClientRepository, new DefaultLobHandler()); + } + + /** + * Constructs a {@code JdbcOAuth2AuthorizationService} using the provided parameters. + * + * @param jdbcOperations the JDBC operations + * @param registeredClientRepository the registered client repository + * @param lobHandler the handler for large binary fields and large text fields + */ + public JdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations, + RegisteredClientRepository registeredClientRepository, LobHandler lobHandler) { + this(jdbcOperations, registeredClientRepository, lobHandler, new ObjectMapper()); + } + + /** + * Constructs a {@code JdbcOAuth2AuthorizationService} using the provided parameters. + * + * @param jdbcOperations the JDBC operations + * @param registeredClientRepository the registered client repository + * @param lobHandler the handler for large binary fields and large text fields + * @param objectMapper the object mapper + */ + public JdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations, + RegisteredClientRepository registeredClientRepository, + LobHandler lobHandler, ObjectMapper objectMapper) { + Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + Assert.notNull(lobHandler, "lobHandler cannot be null"); + Assert.notNull(objectMapper, "objectMapper cannot be null"); + this.jdbcOperations = jdbcOperations; + this.lobHandler = lobHandler; + OAuth2AuthorizationRowMapper authorizationRowMapper = new OAuth2AuthorizationRowMapper(registeredClientRepository, objectMapper); + authorizationRowMapper.setLobHandler(lobHandler); + this.authorizationRowMapper = authorizationRowMapper; + this.authorizationParametersMapper = new OAuth2AuthorizationParametersMapper(objectMapper); + } + + + @Override + public void save(OAuth2Authorization authorization) { + Assert.notNull(authorization, "authorization cannot be null"); + + OAuth2Authorization existingAuthorization = findById(authorization.getId()); + if (existingAuthorization == null) { + insertAuthorization(authorization); + } else { + updateAuthorization(authorization); + } + } + + private void updateAuthorization(OAuth2Authorization authorization) { + List parameters = this.authorizationParametersMapper.apply(authorization); + SqlParameterValue id = parameters.remove(0); + parameters.add(id); + try (LobCreator lobCreator = this.lobHandler.getLobCreator()) { + PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator, + parameters.toArray()); + this.jdbcOperations.update(UPDATE_AUTHORIZATION_SQL, pss); + } + } + + private void insertAuthorization(OAuth2Authorization authorization) { + List parameters = this.authorizationParametersMapper.apply(authorization); + try (LobCreator lobCreator = this.lobHandler.getLobCreator()) { + PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator, + parameters.toArray()); + this.jdbcOperations.update(SAVE_AUTHORIZATION_SQL, pss); + } + } + + @Override + public void remove(OAuth2Authorization authorization) { + Assert.notNull(authorization, "authorization cannot be null"); + SqlParameterValue[] parameters = new SqlParameterValue[]{ + new SqlParameterValue(Types.VARCHAR, authorization.getId()) + }; + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); + this.jdbcOperations.update(REMOVE_AUTHORIZATION_SQL, pss); + } + + @Nullable + @Override + public OAuth2Authorization findById(String id) { + Assert.hasText(id, "id cannot be empty"); + List parameters = new ArrayList<>(); + parameters.add(new SqlParameterValue(Types.VARCHAR, id)); + return findBy(PK_FILTER, parameters); + } + + @Nullable + @Override + public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType) { + Assert.hasText(token, "token cannot be empty"); + List parameters = new ArrayList<>(); + if (tokenType == null) { + parameters.add(new SqlParameterValue(Types.VARCHAR, token)); + parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8))); + parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8))); + parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8))); + return findBy(UNKNOWN_TOKEN_TYPE_FILTER, parameters); + } else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) { + parameters.add(new SqlParameterValue(Types.VARCHAR, token)); + return findBy(STATE_FILTER, parameters); + } else if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) { + parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8))); + return findBy(AUTHORIZATION_CODE_FILTER, parameters); + } else if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) { + parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8))); + return findBy(ACCESS_TOKEN_FILTER, parameters); + } else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) { + parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8))); + return findBy(REFRESH_TOKEN_FILTER, parameters); + } + return null; + } + + private OAuth2Authorization findBy(String filter, List parameters) { + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); + List result = this.jdbcOperations.query(LOAD_AUTHORIZATION_SQL + filter, pss, this.authorizationRowMapper); + return !result.isEmpty() ? result.get(0) : null; + } + + /** + * Sets the {@link RowMapper} used for mapping the current row in + * {@code java.sql.ResultSet} to {@link OAuth2Authorization}. The default is + * {@link OAuth2AuthorizationRowMapper}. + * + * @param authorizationRowMapper the {@link RowMapper} used for mapping the current + * row in {@code ResultSet} to {@link OAuth2Authorization} + */ + public void setAuthorizationRowMapper(RowMapper authorizationRowMapper) { + Assert.notNull(authorizationRowMapper, "authorizationRowMapper cannot be null"); + this.authorizationRowMapper = authorizationRowMapper; + } + + /** + * Sets the {@code Function} used for mapping {@link OAuth2Authorization} to + * a {@code List} of {@link SqlParameterValue}. The default is + * {@link OAuth2AuthorizationParametersMapper}. + * + * @param authorizationParametersMapper the {@code Function} used for mapping + * {@link OAuth2Authorization} to a {@code List} of {@link SqlParameterValue} + */ + public void setAuthorizationParametersMapper( + Function> authorizationParametersMapper) { + Assert.notNull(authorizationParametersMapper, "authorizationParametersMapper cannot be null"); + this.authorizationParametersMapper = authorizationParametersMapper; + } + + /** + * The default {@link RowMapper} that maps the current row in + * {@code java.sql.ResultSet} to {@link OAuth2Authorization}. + */ + public static class OAuth2AuthorizationRowMapper implements RowMapper { + + private final RegisteredClientRepository registeredClientRepository; + private final ObjectMapper objectMapper; + private LobHandler lobHandler = new DefaultLobHandler(); + + + public OAuth2AuthorizationRowMapper(RegisteredClientRepository registeredClientRepository, ObjectMapper objectMapper) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + Assert.notNull(objectMapper, "objectMapper cannot be null"); + this.registeredClientRepository = registeredClientRepository; + this.objectMapper = objectMapper; + } + + @Override + @SuppressWarnings("unchecked") + public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException { + try { + String registeredClientId = rs.getString("registered_client_id"); + RegisteredClient registeredClient = this.registeredClientRepository + .findById(registeredClientId); + if (registeredClient == null) { + throw new DataRetrievalFailureException( + "The RegisteredClient with id '" + registeredClientId + "' it was not found in the RegisteredClientRepository."); + } + + OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient); + String id = rs.getString("id"); + String principalName = rs.getString("principal_name"); + String authorizationGrantType = rs.getString("authorization_grant_type"); + Map attributes = this.objectMapper.readValue(rs.getString("attributes"), Map.class); + + builder.id(id) + .principalName(principalName) + .authorizationGrantType(new AuthorizationGrantType(authorizationGrantType)) + .attributes(attrs -> attrs.putAll(attributes)); + + String state = rs.getString("state"); + if (StringUtils.hasText(state)) { + builder.attribute(OAuth2ParameterNames.STATE, state); + } + + String tokenValue; + Instant tokenIssuedAt; + Instant tokenExpiresAt; + byte[] authorizationCodeValue = this.lobHandler.getBlobAsBytes(rs, "authorization_code_value"); + + if (authorizationCodeValue != null) { + tokenValue = new String(authorizationCodeValue, + StandardCharsets.UTF_8); + tokenIssuedAt = rs.getTimestamp("authorization_code_issued_at").toInstant(); + tokenExpiresAt = rs.getTimestamp("authorization_code_expires_at").toInstant(); + Map authorizationCodeMetadata = this.objectMapper.readValue(rs.getString("authorization_code_metadata"), Map.class); + + OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode( + tokenValue, tokenIssuedAt, tokenExpiresAt); + builder + .token(authorizationCode, (metadata) -> metadata.putAll(authorizationCodeMetadata)); + } + + byte[] accessTokenValue = this.lobHandler.getBlobAsBytes(rs, "access_token_value"); + if (accessTokenValue != null) { + tokenValue = new String(accessTokenValue, + StandardCharsets.UTF_8); + tokenIssuedAt = rs.getTimestamp("access_token_issued_at").toInstant(); + tokenExpiresAt = rs.getTimestamp("access_token_expires_at").toInstant(); + Map accessTokenMetadata = this.objectMapper.readValue(rs.getString("access_token_metadata"), Map.class); + OAuth2AccessToken.TokenType tokenType = null; + if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(rs.getString("access_token_type"))) { + tokenType = OAuth2AccessToken.TokenType.BEARER; + } + + Set scopes = Collections.emptySet(); + String accessTokenScopes = rs.getString("access_token_scopes"); + if (accessTokenScopes != null) { + scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes); + } + OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, tokenIssuedAt, tokenExpiresAt, scopes); + builder + .token(accessToken, (metadata) -> metadata.putAll(accessTokenMetadata)); + } + + byte[] oidcIdTokenValue = this.lobHandler.getBlobAsBytes(rs, "oidc_id_token_value"); + + if (oidcIdTokenValue != null) { + tokenValue = new String(oidcIdTokenValue, + StandardCharsets.UTF_8); + tokenIssuedAt = rs.getTimestamp("oidc_id_token_issued_at").toInstant(); + tokenExpiresAt = rs.getTimestamp("oidc_id_token_expires_at").toInstant(); + Map oidcTokenMetadata = this.objectMapper.readValue(rs.getString("oidc_id_token_metadata"), Map.class); + + OidcIdToken oidcToken = new OidcIdToken( + tokenValue, tokenIssuedAt, tokenExpiresAt, (Map) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME)); + builder + .token(oidcToken, (metadata) -> metadata.putAll(oidcTokenMetadata)); + } + + byte[] refreshTokenValue = this.lobHandler.getBlobAsBytes(rs, "refresh_token_value"); + if (refreshTokenValue != null) { + tokenValue = new String(refreshTokenValue, + StandardCharsets.UTF_8); + tokenIssuedAt = rs.getTimestamp("refresh_token_issued_at").toInstant(); + tokenExpiresAt = null; + Timestamp refreshTokenExpiresAt = rs.getTimestamp("refresh_token_expires_at"); + if (refreshTokenExpiresAt != null) { + tokenExpiresAt = refreshTokenExpiresAt.toInstant(); + } + Map refreshTokenMetadata = this.objectMapper.readValue(rs.getString("refresh_token_metadata"), Map.class); + + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken2( + tokenValue, tokenIssuedAt, tokenExpiresAt); + builder + .token(refreshToken, (metadata) -> metadata.putAll(refreshTokenMetadata)); + } + return builder.build(); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException(e.getMessage(), e); + } + } + + public final void setLobHandler(LobHandler lobHandler) { + Assert.notNull(lobHandler, "lobHandler cannot be null"); + this.lobHandler = lobHandler; + } + } + + /** + * The default {@code Function} that maps {@link OAuth2Authorization} to a + * {@code List} of {@link SqlParameterValue}. + */ + public static class OAuth2AuthorizationParametersMapper implements Function> { + private final ObjectMapper objectMapper; + + public OAuth2AuthorizationParametersMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "objectMapper cannot be null"); + this.objectMapper = objectMapper; + } + + @Override + public List apply(OAuth2Authorization authorization) { + + try { + List parameters = new ArrayList<>(); + parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getId())); + parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getRegisteredClientId())); + parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getPrincipalName())); + parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getAuthorizationGrantType().getValue())); + + String attributes = this.objectMapper.writeValueAsString(authorization.getAttributes()); + parameters.add(new SqlParameterValue(Types.VARCHAR, attributes)); + + String state = null; + String authorizationState = authorization.getAttribute(OAuth2ParameterNames.STATE); + if (StringUtils.hasText(authorizationState)) { + state = authorizationState; + } + parameters.add(new SqlParameterValue(Types.VARCHAR, state)); + + OAuth2Authorization.Token authorizationCode = + authorization.getToken(OAuth2AuthorizationCode.class); + List authorizationCodeSqlParameters = toSqlParameterList(authorizationCode); + parameters.addAll(authorizationCodeSqlParameters); + + OAuth2Authorization.Token accessToken = + authorization.getToken(OAuth2AccessToken.class); + List accessTokenSqlParameters = toSqlParameterList(accessToken); + parameters.addAll(accessTokenSqlParameters); + String accessTokenType = null; + String accessTokenScopes = null; + + if (accessToken != null) { + accessTokenType = accessToken.getToken().getTokenType().getValue(); + if (!CollectionUtils.isEmpty(accessToken.getToken().getScopes())) { + accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getToken().getScopes(), ","); + } + } + + parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenType)); + parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenScopes)); + OAuth2Authorization.Token oidcIdToken = authorization.getToken(OidcIdToken.class); + List oidcTokenSqlParameters = toSqlParameterList(oidcIdToken); + parameters.addAll(oidcTokenSqlParameters); + + OAuth2Authorization.Token refreshToken = authorization.getRefreshToken(); + + List refreshTokenSqlParameters = toSqlParameterList(refreshToken); + parameters.addAll(refreshTokenSqlParameters); + return parameters; + } catch (JsonProcessingException e) { + throw new IllegalArgumentException(e.getMessage(), e); + } + + } + + private List toSqlParameterList(OAuth2Authorization.Token token) throws JsonProcessingException { + List parameters = new ArrayList<>(); + byte[] tokenValue = null; + Timestamp tokenIssuedAt = null; + Timestamp tokenExpiresAt = null; + String codeMetadata = null; + if (token != null) { + + tokenValue = token.getToken().getTokenValue().getBytes(StandardCharsets.UTF_8); + if (token.getToken().getIssuedAt() != null) { + tokenIssuedAt = Timestamp.from(token.getToken().getIssuedAt()); + } + + if (token.getToken().getExpiresAt() != null) { + tokenExpiresAt = Timestamp.from(token.getToken().getExpiresAt()); + } + codeMetadata = this.objectMapper.writeValueAsString(token.getMetadata()); + } + parameters.add(new SqlParameterValue(Types.BLOB, tokenValue)); + parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenIssuedAt)); + parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenExpiresAt)); + parameters.add(new SqlParameterValue(Types.VARCHAR, codeMetadata)); + return parameters; + } + } + + private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter { + + protected final LobCreator lobCreator; + + private LobCreatorArgumentPreparedStatementSetter(LobCreator lobCreator, Object[] args) { + super(args); + this.lobCreator = lobCreator; + } + + @Override + protected void doSetValue(PreparedStatement ps, int parameterPosition, Object argValue) throws SQLException { + if (argValue instanceof SqlParameterValue) { + SqlParameterValue paramValue = (SqlParameterValue) argValue; + if (paramValue.getSqlType() == Types.BLOB) { + if (paramValue.getValue() != null) { + Assert.isInstanceOf(byte[].class, paramValue.getValue(), + "Value of blob parameter must be byte[]"); + } + byte[] valueBytes = (byte[]) paramValue.getValue(); + this.lobCreator.setBlobAsBytes(ps, parameterPosition, valueBytes); + return; + } + } + super.doSetValue(ps, parameterPosition, argValue); + } + + } +} diff --git a/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql b/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql new file mode 100644 index 000000000..c426c7682 --- /dev/null +++ b/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql @@ -0,0 +1,27 @@ +CREATE TABLE oauth2_authorization ( + id varchar(100) NOT NULL, + registered_client_id varchar(100) NOT NULL, + principal_name varchar(200) NOT NULL, + authorization_grant_type varchar(100) NOT NULL, + attributes varchar(1000) DEFAULT NULL, + state varchar(1000) DEFAULT NULL, + authorization_code_value blob DEFAULT NULL, + authorization_code_issued_at timestamp DEFAULT NULL, + authorization_code_expires_at timestamp DEFAULT NULL, + authorization_code_metadata varchar(1000) DEFAULT NULL, + access_token_value blob DEFAULT NULL, + access_token_issued_at timestamp DEFAULT NULL, + access_token_expires_at timestamp DEFAULT NULL, + access_token_metadata varchar(1000) DEFAULT NULL, + access_token_type varchar(100) DEFAULT NULL, + access_token_scopes varchar(1000) DEFAULT NULL, + oidc_id_token_value blob DEFAULT NULL, + oidc_id_token_issued_at timestamp DEFAULT NULL, + oidc_id_token_expires_at timestamp DEFAULT NULL, + oidc_id_token_metadata varchar(1000) DEFAULT NULL, + refresh_token_value blob DEFAULT NULL, + refresh_token_issued_at timestamp DEFAULT NULL, + refresh_token_expires_at timestamp DEFAULT NULL, + refresh_token_metadata varchar(1000) DEFAULT NULL, + PRIMARY KEY (id) +); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationServiceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationServiceTests.java new file mode 100644 index 000000000..9dd4c42e5 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationServiceTests.java @@ -0,0 +1,397 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; +import org.springframework.jdbc.support.lob.DefaultLobHandler; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken2; +import org.springframework.security.oauth2.core.OAuth2TokenType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link JdbcOAuth2AuthorizationService}. + * + * @author Ovidiu Popa + */ +public class JdbcOAuth2AuthorizationServiceTests { + private static final String OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql"; + private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); + private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE); + private static final String ID = "id"; + private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); + private static final String PRINCIPAL_NAME = "principal"; + private static final AuthorizationGrantType AUTHORIZATION_GRANT_TYPE = AuthorizationGrantType.AUTHORIZATION_CODE; + private static final OAuth2AuthorizationCode AUTHORIZATION_CODE = new OAuth2AuthorizationCode( + "code", Instant.now().truncatedTo(ChronoUnit.MILLIS), Instant.now().plus(5, ChronoUnit.MINUTES).truncatedTo(ChronoUnit.MILLIS)); + + private EmbeddedDatabase db; + private JdbcOperations jdbcOperations; + private RegisteredClientRepository registeredClientRepository; + private JdbcOAuth2AuthorizationService authorizationService; + + + @Before + public void setUp() { + this.db = createDb(); + this.registeredClientRepository = mock(RegisteredClientRepository.class); + this.jdbcOperations = new JdbcTemplate(this.db); + this.authorizationService = new JdbcOAuth2AuthorizationService(this.jdbcOperations, this.registeredClientRepository); + } + + @After + public void tearDown() { + this.db.shutdown(); + } + + @Test + public void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> new JdbcOAuth2AuthorizationService(null, this.registeredClientRepository)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("jdbcOperations cannot be null"); + // @formatter:on + } + + @Test + public void constructorWhenRegisteredClientRepositoryIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> new JdbcOAuth2AuthorizationService(this.jdbcOperations, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("registeredClientRepository cannot be null"); + // @formatter:on + } + + @Test + public void constructorWhenLobHandlerIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> new JdbcOAuth2AuthorizationService(this.jdbcOperations, this.registeredClientRepository, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("lobHandler cannot be null"); + // @formatter:on + } + + @Test + public void constructorWhenObjectMapperIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> new JdbcOAuth2AuthorizationService(this.jdbcOperations, this.registeredClientRepository, new DefaultLobHandler(), null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("objectMapper cannot be null"); + // @formatter:on + } + + @Test + public void setAuthorizationRowMapperWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> this.authorizationService.setAuthorizationRowMapper(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationRowMapper cannot be null"); + // @formatter:on + } + + @Test + public void setAuthorizationParametersMapperWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> this.authorizationService.setAuthorizationParametersMapper(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationParametersMapper cannot be null"); + // @formatter:on + } + + @Test + public void saveWhenAuthorizationNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> this.authorizationService.save(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorization cannot be null"); + // @formatter:on + } + + @Test + public void saveWhenAuthorizationNewThenSaved() { + when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId()))) + .thenReturn(REGISTERED_CLIENT); + OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) + .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) + .token(AUTHORIZATION_CODE) + .build(); + this.authorizationService.save(expectedAuthorization); + + OAuth2Authorization authorization = this.authorizationService.findById(ID); + assertThat(authorization).isEqualTo(expectedAuthorization); + } + + @Test + public void saveWhenAuthorizationExistsThenUpdated() { + when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId()))) + .thenReturn(REGISTERED_CLIENT); + OAuth2Authorization originalAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) + .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) + .token(AUTHORIZATION_CODE) + .build(); + this.authorizationService.save(originalAuthorization); + + OAuth2Authorization authorization = this.authorizationService.findById( + originalAuthorization.getId()); + assertThat(authorization).isEqualTo(originalAuthorization); + + OAuth2Authorization updatedAuthorization = OAuth2Authorization.from(authorization) + .attribute("custom-name-1", "custom-value-1") + .build(); + this.authorizationService.save(updatedAuthorization); + + authorization = this.authorizationService.findById( + updatedAuthorization.getId()); + assertThat(authorization).isEqualTo(updatedAuthorization); + assertThat(authorization).isNotEqualTo(originalAuthorization); + } + + @Test + public void saveLoadAuthorizationWhenCustomStrategiesSetThenCalled() throws Exception { + when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId()))) + .thenReturn(REGISTERED_CLIENT); + OAuth2Authorization originalAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) + .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) + .token(AUTHORIZATION_CODE) + .build(); + ObjectMapper objectMapper = new ObjectMapper(); + JdbcOAuth2AuthorizationService.OAuth2AuthorizationRowMapper authorizationRowMapper = spy( + new JdbcOAuth2AuthorizationService.OAuth2AuthorizationRowMapper( + this.registeredClientRepository, objectMapper)); + this.authorizationService.setAuthorizationRowMapper(authorizationRowMapper); + JdbcOAuth2AuthorizationService.OAuth2AuthorizationParametersMapper authorizationParametersMapper = spy( + new JdbcOAuth2AuthorizationService.OAuth2AuthorizationParametersMapper(objectMapper)); + this.authorizationService.setAuthorizationParametersMapper(authorizationParametersMapper); + + this.authorizationService.save(originalAuthorization); + OAuth2Authorization authorization = this.authorizationService.findById( + originalAuthorization.getId()); + assertThat(authorization).isEqualTo(originalAuthorization); + verify(authorizationRowMapper).mapRow(any(), anyInt()); + verify(authorizationParametersMapper).apply(any()); + } + + @Test + public void removeWhenAuthorizationNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> this.authorizationService.remove(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorization cannot be null"); + // @formatter:on + } + + @Test + public void removeWhenAuthorizationProvidedThenRemoved() { + when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId()))) + .thenReturn(REGISTERED_CLIENT); + OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) + .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) + .token(AUTHORIZATION_CODE) + .build(); + + this.authorizationService.save(expectedAuthorization); + OAuth2Authorization authorization = this.authorizationService.findByToken( + AUTHORIZATION_CODE.getTokenValue(), AUTHORIZATION_CODE_TOKEN_TYPE); + assertThat(authorization).isEqualTo(expectedAuthorization); + + this.authorizationService.remove(expectedAuthorization); + authorization = this.authorizationService.findByToken( + AUTHORIZATION_CODE.getTokenValue(), AUTHORIZATION_CODE_TOKEN_TYPE); + assertThat(authorization).isNull(); + } + + @Test + public void findByIdWhenIdNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> this.authorizationService.findById(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("id cannot be empty"); + // @formatter:on + } + + @Test + public void findByIdWhenIdEmptyThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> this.authorizationService.findById(" ")) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("id cannot be empty"); + // @formatter:on + } + + @Test + public void findByTokenWhenTokenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> this.authorizationService.findByToken(null, AUTHORIZATION_CODE_TOKEN_TYPE)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("token cannot be empty"); + // @formatter:on + } + + @Test + public void findByTokenWhenStateExistsThenFound() { + when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId()))) + .thenReturn(REGISTERED_CLIENT); + String state = "state"; + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) + .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) + .attribute(OAuth2ParameterNames.STATE, state) + .build(); + this.authorizationService.save(authorization); + + OAuth2Authorization result = this.authorizationService.findByToken( + state, STATE_TOKEN_TYPE); + assertThat(authorization).isEqualTo(result); + result = this.authorizationService.findByToken(state, null); + assertThat(authorization).isEqualTo(result); + } + + @Test + public void findByTokenWhenAuthorizationCodeExistsThenFound() { + when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId()))) + .thenReturn(REGISTERED_CLIENT); + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) + .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) + .token(AUTHORIZATION_CODE) + .build(); + this.authorizationService.save(authorization); + + OAuth2Authorization result = this.authorizationService.findByToken( + AUTHORIZATION_CODE.getTokenValue(), AUTHORIZATION_CODE_TOKEN_TYPE); + assertThat(authorization).isEqualTo(result); + result = this.authorizationService.findByToken(AUTHORIZATION_CODE.getTokenValue(), null); + assertThat(authorization).isEqualTo(result); + } + + @Test + public void findByTokenWhenAccessTokenExistsThenFound() { + when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId()))) + .thenReturn(REGISTERED_CLIENT); + OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + "access-token", Instant.now().minusSeconds(60).truncatedTo(ChronoUnit.MILLIS), Instant.now().truncatedTo(ChronoUnit.MILLIS)); + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) + .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) + .token(AUTHORIZATION_CODE) + .accessToken(accessToken) + .build(); + this.authorizationService.save(authorization); + + OAuth2Authorization result = this.authorizationService.findByToken( + accessToken.getTokenValue(), OAuth2TokenType.ACCESS_TOKEN); + assertThat(authorization).isEqualTo(result); + result = this.authorizationService.findByToken(accessToken.getTokenValue(), null); + assertThat(authorization).isEqualTo(result); + } + + @Test + public void findByTokenWhenRefreshTokenExistsThenFound() { + when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId()))) + .thenReturn(REGISTERED_CLIENT); + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken2("refresh-token", + Instant.now().truncatedTo(ChronoUnit.MILLIS), + Instant.now().plus(5, ChronoUnit.MINUTES).truncatedTo(ChronoUnit.MILLIS)); + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) + .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) + .refreshToken(refreshToken) + .build(); + this.authorizationService.save(authorization); + + OAuth2Authorization result = this.authorizationService.findByToken( + refreshToken.getTokenValue(), OAuth2TokenType.REFRESH_TOKEN); + assertThat(authorization).isEqualTo(result); + result = this.authorizationService.findByToken(refreshToken.getTokenValue(), null); + assertThat(authorization).isEqualTo(result); + } + + @Test + public void findByTokenWhenWrongTokenTypeThenNotFound() { + OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now().truncatedTo(ChronoUnit.MILLIS)); + OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT) + .id(ID) + .principalName(PRINCIPAL_NAME) + .authorizationGrantType(AUTHORIZATION_GRANT_TYPE) + .refreshToken(refreshToken) + .build(); + this.authorizationService.save(authorization); + + OAuth2Authorization result = this.authorizationService.findByToken( + refreshToken.getTokenValue(), OAuth2TokenType.ACCESS_TOKEN); + assertThat(result).isNull(); + } + + @Test + public void findByTokenWhenTokenDoesNotExistThenNull() { + OAuth2Authorization result = this.authorizationService.findByToken( + "access-token", OAuth2TokenType.ACCESS_TOKEN); + assertThat(result).isNull(); + } + + private static EmbeddedDatabase createDb() { + return createDb(OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE); + } + + private static EmbeddedDatabase createDb(String schema) { + // @formatter:off + return new EmbeddedDatabaseBuilder() + .generateUniqueName(true) + .setType(EmbeddedDatabaseType.HSQL) + .setScriptEncoding("UTF-8") + .addScript(schema) + .build(); + // @formatter:on + } +}