diff --git a/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client.sql b/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql similarity index 96% rename from oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client.sql rename to oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql index 0b4a9925..aaa2fa2d 100644 --- a/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client.sql +++ b/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql @@ -11,4 +11,5 @@ CREATE TABLE oauth2_registered_client ( scopes varchar(1000) NOT NULL, client_settings varchar(1000) DEFAULT NULL, token_settings varchar(1000) DEFAULT NULL, - PRIMARY KEY (id)); + PRIMARY KEY (id) +); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java index eb5cc7bc..97c40a1e 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java @@ -15,20 +15,37 @@ */ package org.springframework.security.oauth2.server.authorization.client; -import java.io.InputStream; -import java.nio.charset.Charset; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Timestamp; import java.time.Duration; import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcOperations; import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.PreparedStatementSetter; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.SqlParameterValue; import org.springframework.jdbc.datasource.DriverManagerDataSource; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; -import org.springframework.util.StreamUtils; +import org.springframework.util.StringUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -37,51 +54,38 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException * JDBC-backed registered client repository tests * * @author Rafal Lewczuk + * @author Steve Riesenberg * @since 0.1.2 */ public class JdbcRegisteredClientRepositoryTests { - private final String SCRIPT = "/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client.sql"; + private static final String REGISTERED_CLIENT_SCHEMA_SQL_RESOURCE = "/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql"; + private static final String CUSTOM_REGISTERED_CLIENT_SCHEMA_SQL_RESOURCE = "/org/springframework/security/oauth2/server/authorization/client/custom-oauth2-registered-client-schema.sql"; private DriverManagerDataSource dataSource; - private JdbcRegisteredClientRepository clients; + private JdbcRegisteredClientRepository registeredClientRepository; - private RegisteredClient registration; + private RegisteredClient registeredClient; - private JdbcTemplate jdbc; + private EmbeddedDatabase db; + + private JdbcOperations jdbcOperations; @Before public void setup() throws Exception { - this.dataSource = new DriverManagerDataSource(); - this.dataSource.setDriverClassName("org.hsqldb.jdbcDriver"); - this.dataSource.setUrl("jdbc:hsqldb:mem:oauthtest"); - this.dataSource.setUsername("sa"); - this.dataSource.setPassword(""); + this.db = createDb(REGISTERED_CLIENT_SCHEMA_SQL_RESOURCE); + this.jdbcOperations = new JdbcTemplate(this.db); - this.jdbc = new JdbcTemplate(this.dataSource); + this.registeredClientRepository = new JdbcRegisteredClientRepository(this.jdbcOperations); + this.registeredClient = TestRegisteredClients.registeredClient().build(); - // execute scripts - try (InputStream is = JdbcRegisteredClientRepositoryTests.class.getResourceAsStream(SCRIPT)) { - assertThat(is).isNotNull().describedAs("Cannot open resource file: " + SCRIPT); - String ddls = StreamUtils.copyToString(is, Charset.defaultCharset()); - for (String ddl : ddls.split(";\n")) { - if (!ddl.trim().isEmpty()) { - this.jdbc.execute(ddl.trim()); - } - } - } - - this.clients = new JdbcRegisteredClientRepository(this.jdbc); - this.registration = TestRegisteredClients.registeredClient().build(); - - this.clients.save(this.registration); + this.registeredClientRepository.save(this.registeredClient); } @After public void destroyDatabase() { - this.jdbc.update("truncate table oauth2_registered_client"); - new JdbcTemplate(this.dataSource).execute("SHUTDOWN"); + this.db.shutdown(); } @Test @@ -97,7 +101,7 @@ public class JdbcRegisteredClientRepositoryTests { public void whenSetNullRegisteredClientRowMapperThenThrow() { // @formatter:off assertThatIllegalArgumentException() - .isThrownBy(() -> this.clients.setRegisteredClientRowMapper(null)) + .isThrownBy(() -> this.registeredClientRepository.setRegisteredClientRowMapper(null)) .withMessage("registeredClientRowMapper cannot be null"); // @formatter:on } @@ -106,20 +110,20 @@ public class JdbcRegisteredClientRepositoryTests { public void whenSetNullRegisteredClientParameterMapperThenThrow() { // @formatter:off assertThatIllegalArgumentException() - .isThrownBy(() -> this.clients.setRegisteredClientParametersMapper(null)) + .isThrownBy(() -> this.registeredClientRepository.setRegisteredClientParametersMapper(null)) .withMessage("registeredClientParameterMapper cannot be null"); // @formatter:on } @Test public void findByIdWhenFoundThenFound() { - String id = this.registration.getId(); - assertRegisteredClientIsEqualTo(this.clients.findById(id), this.registration); + String id = this.registeredClient.getId(); + assertRegisteredClientIsEqualTo(this.registeredClientRepository.findById(id), this.registeredClient); } @Test public void findByIdWhenNotFoundThenNull() { - RegisteredClient client = this.clients.findById("noooope"); + RegisteredClient client = this.registeredClientRepository.findById("noooope"); assertThat(client).isNull(); } @@ -127,20 +131,20 @@ public class JdbcRegisteredClientRepositoryTests { public void findByIdWhenNullThenThrowIllegalArgumentException() { // @formatter:off assertThatIllegalArgumentException() - .isThrownBy(() -> this.clients.findById(null)) + .isThrownBy(() -> this.registeredClientRepository.findById(null)) .withMessage("id cannot be empty"); // @formatter:on } @Test public void findByClientIdWhenFoundThenFound() { - String id = this.registration.getClientId(); - assertRegisteredClientIsEqualTo(this.clients.findByClientId(id), this.registration); + String id = this.registeredClient.getClientId(); + assertRegisteredClientIsEqualTo(this.registeredClientRepository.findByClientId(id), this.registeredClient); } @Test public void findByClientIdWhenNotFoundThenNull() { - RegisteredClient client = this.clients.findByClientId("noooope"); + RegisteredClient client = this.registeredClientRepository.findByClientId("noooope"); assertThat(client).isNull(); } @@ -148,7 +152,7 @@ public class JdbcRegisteredClientRepositoryTests { public void findByClientIdWhenNullThenThrowIllegalArgumentException() { // @formatter:off assertThatIllegalArgumentException() - .isThrownBy(() -> this.clients.findByClientId(null)) + .isThrownBy(() -> this.registeredClientRepository.findByClientId(null)) .withMessage("clientId cannot be empty"); // @formatter:on } @@ -156,58 +160,58 @@ public class JdbcRegisteredClientRepositoryTests { @Test public void saveWhenNullThenThrowIllegalArgumentException() { assertThatIllegalArgumentException() - .isThrownBy(() -> this.clients.save(null)) + .isThrownBy(() -> this.registeredClientRepository.save(null)) .withMessageContaining("registeredClient cannot be null"); } @Test public void saveWhenExistingIdThenThrowIllegalArgumentException() { RegisteredClient registeredClient = createRegisteredClient( - this.registration.getId(), "client-id-2", "client-secret-2"); + this.registeredClient.getId(), "client-id-2", "client-secret-2"); assertThatIllegalArgumentException() - .isThrownBy(() -> this.clients.save(registeredClient)) + .isThrownBy(() -> this.registeredClientRepository.save(registeredClient)) .withMessage("Registered client must be unique. Found duplicate identifier: " + registeredClient.getId()); } @Test public void saveWhenExistingClientIdThenThrowIllegalArgumentException() { RegisteredClient registeredClient = createRegisteredClient( - "client-2", this.registration.getClientId(), "client-secret-2"); + "client-2", this.registeredClient.getClientId(), "client-secret-2"); assertThatIllegalArgumentException() - .isThrownBy(() -> this.clients.save(registeredClient)) + .isThrownBy(() -> this.registeredClientRepository.save(registeredClient)) .withMessage("Registered client must be unique. Found duplicate client identifier: " + registeredClient.getClientId()); } @Test public void saveWhenExistingClientSecretThenSuccess() { RegisteredClient registeredClient = createRegisteredClient( - "client-2", "client-id-2", this.registration.getClientSecret()); - this.clients.save(registeredClient); - RegisteredClient savedClient = this.clients.findById(registeredClient.getId()); + "client-2", "client-id-2", this.registeredClient.getClientSecret()); + this.registeredClientRepository.save(registeredClient); + RegisteredClient savedClient = this.registeredClientRepository.findById(registeredClient.getId()); assertRegisteredClientIsEqualTo(savedClient, registeredClient); } @Test public void saveWhenSavedAndFindByIdThenFound() { RegisteredClient registeredClient = createRegisteredClient(); - this.clients.save(registeredClient); - RegisteredClient savedClient = this.clients.findById(registeredClient.getId()); + this.registeredClientRepository.save(registeredClient); + RegisteredClient savedClient = this.registeredClientRepository.findById(registeredClient.getId()); assertRegisteredClientIsEqualTo(savedClient, registeredClient); } @Test public void saveWhenSavedAndFindByClientIdThenFound() { RegisteredClient registeredClient = createRegisteredClient(); - this.clients.save(registeredClient); - RegisteredClient savedClient = this.clients.findByClientId(registeredClient.getClientId()); + this.registeredClientRepository.save(registeredClient); + RegisteredClient savedClient = this.registeredClientRepository.findByClientId(registeredClient.getClientId()); assertRegisteredClientIsEqualTo(savedClient, registeredClient); } @Test public void saveWhenPublicClientSavedAndFindByClientIdThenFound() { RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build(); - this.clients.save(registeredClient); - RegisteredClient savedClient = this.clients.findByClientId(registeredClient.getClientId()); + this.registeredClientRepository.save(registeredClient); + RegisteredClient savedClient = this.registeredClientRepository.findByClientId(registeredClient.getClientId()); assertRegisteredClientIsEqualTo(savedClient, registeredClient); } @@ -217,9 +221,9 @@ public class JdbcRegisteredClientRepositoryTests { .id("1").clientId("a").build(); RegisteredClient registeredClient2 = TestRegisteredClients.registeredPublicClient() .id("2").clientId("b").build(); - this.clients.save(registeredClient1); - this.clients.save(registeredClient2); - RegisteredClient savedClient = this.clients.findByClientId(registeredClient2.getClientId()); + this.registeredClientRepository.save(registeredClient1); + this.registeredClientRepository.save(registeredClient2); + RegisteredClient savedClient = this.registeredClientRepository.findByClientId(registeredClient2.getClientId()); assertRegisteredClientIsEqualTo(savedClient, registeredClient2); } @@ -243,13 +247,28 @@ public class JdbcRegisteredClientRepositoryTests { }) .build(); - this.clients.save(client); + this.registeredClientRepository.save(client); - RegisteredClient retrievedClient = this.clients.findById(client.getId()); + RegisteredClient retrievedClient = this.registeredClientRepository.findById(client.getId()); assertRegisteredClientIsEqualTo(retrievedClient, client); } + @Test + public void tableDefinitionWhenCustomThenAbleToOverride() { + EmbeddedDatabase db = createDb(CUSTOM_REGISTERED_CLIENT_SCHEMA_SQL_RESOURCE); + CustomJdbcRegisteredClientRepository registeredClientRepository = + new CustomJdbcRegisteredClientRepository(new JdbcTemplate(db)); + registeredClientRepository.save(this.registeredClient); + RegisteredClient foundClient1 = registeredClientRepository.findById(this.registeredClient.getId()); + assertThat(foundClient1).isNotNull(); + assertRegisteredClientIsEqualTo(foundClient1, this.registeredClient); + RegisteredClient foundClient2 = registeredClientRepository.findByClientId(this.registeredClient.getClientId()); + assertThat(foundClient2).isNotNull(); + assertRegisteredClientIsEqualTo(foundClient2, this.registeredClient); + db.shutdown(); + } + private void assertRegisteredClientIsEqualTo(RegisteredClient rc, RegisteredClient ref) { assertThat(rc).isNotNull(); assertThat(rc.getId()).isEqualTo(ref.getId()); @@ -282,11 +301,21 @@ public class JdbcRegisteredClientRepositoryTests { assertThat(rc.getTokenSettings().refreshTokenTimeToLive()).isEqualTo(ref.getTokenSettings().refreshTokenTimeToLive()); } + private static EmbeddedDatabase createDb(String schema) { + // @formatter:off + return new EmbeddedDatabaseBuilder() + .generateUniqueName(true) + .setType(EmbeddedDatabaseType.HSQL) + .setScriptEncoding("UTF-8") + .addScript(schema) + .build(); + // @formatter:on + } + private static RegisteredClient createRegisteredClient() { return createRegisteredClient("client-2", "client-id-2", "client-secret-2"); } - private static RegisteredClient createRegisteredClient(String id, String clientId, String clientSecret) { // @formatter:off return RegisteredClient.withId(id) @@ -300,4 +329,125 @@ public class JdbcRegisteredClientRepositoryTests { // @formatter:on } + private static final class CustomJdbcRegisteredClientRepository extends JdbcRegisteredClientRepository { + + private static final String COLUMN_NAMES = "id, " + + "clientId, " + + "clientIdIssuedAt, " + + "clientSecret, " + + "clientSecretExpiresAt, " + + "clientName, " + + "clientAuthenticationMethods, " + + "authorizationGrantTypes, " + + "redirectUris, " + + "scopes, " + + "clientSettings," + + "tokenSettings"; + + private static final String TABLE_NAME = "oauth2RegisteredClient"; + + private static final String LOAD_REGISTERED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + " FROM " + TABLE_NAME + " WHERE "; + + private static final String INSERT_REGISTERED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + + " (" + COLUMN_NAMES + ") values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"; + + public CustomJdbcRegisteredClientRepository(JdbcOperations jdbcOperations) { + super(jdbcOperations); + setRegisteredClientRowMapper(new CustomRegisteredClientRowMapper()); + } + + @Override + public void save(RegisteredClient registeredClient) { + List parameters = getRegisteredClientParametersMapper().apply(registeredClient); + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); + getJdbcOperations().update(INSERT_REGISTERED_CLIENT_SQL, pss); + } + + @Override + public RegisteredClient findById(String id) { + return findBy("id = ?", id); + } + + @Override + public RegisteredClient findByClientId(String clientId) { + return findBy("clientId = ?", clientId); + } + + private RegisteredClient findBy(String filter, Object... args) { + List result = getJdbcOperations() + .query(LOAD_REGISTERED_CLIENT_SQL + filter, getRegisteredClientRowMapper(), args); + return !result.isEmpty() ? result.get(0) : null; + } + + private static final class CustomRegisteredClientRowMapper implements RowMapper { + + private static final Map AUTHORIZATION_GRANT_TYPE_MAP; + private static final Map CLIENT_AUTHENTICATION_METHOD_MAP; + + private final ObjectMapper objectMapper = new ObjectMapper(); + + @Override + public RegisteredClient mapRow(ResultSet rs, int rowNum) throws SQLException { + Set clientScopes = StringUtils.commaDelimitedListToSet(rs.getString("scopes")); + Set authGrantTypes = StringUtils.commaDelimitedListToSet(rs.getString("authorizationGrantTypes")); + Set clientAuthMethods = StringUtils.commaDelimitedListToSet(rs.getString("clientAuthenticationMethods")); + Set redirectUris = StringUtils.commaDelimitedListToSet(rs.getString("redirectUris")); + Timestamp clientIssuedAt = rs.getTimestamp("clientIdIssuedAt"); + Timestamp clientSecretExpiresAt = rs.getTimestamp("clientSecretExpiresAt"); + String clientSecret = rs.getString("clientSecret"); + RegisteredClient.Builder builder = RegisteredClient + .withId(rs.getString("id")) + .clientId(rs.getString("clientId")) + .clientIdIssuedAt(clientIssuedAt != null ? clientIssuedAt.toInstant() : null) + .clientSecret(clientSecret) + .clientSecretExpiresAt(clientSecretExpiresAt != null ? clientSecretExpiresAt.toInstant() : null) + .clientName(rs.getString("clientName")) + .authorizationGrantTypes((grantTypes) -> authGrantTypes.forEach(authGrantType -> + grantTypes.add(AUTHORIZATION_GRANT_TYPE_MAP.get(authGrantType)))) + .clientAuthenticationMethods((authenticationMethods) -> clientAuthMethods.forEach(clientAuthMethod -> + authenticationMethods.add(CLIENT_AUTHENTICATION_METHOD_MAP.get(clientAuthMethod)))) + .redirectUris((uris) -> uris.addAll(redirectUris)) + .scopes((scopes) -> scopes.addAll(clientScopes)); + + RegisteredClient registeredClient = builder.build(); + registeredClient.getClientSettings().settings().putAll(parseMap(rs.getString("clientSettings"))); + registeredClient.getTokenSettings().settings().putAll(parseMap(rs.getString("tokenSettings"))); + + return registeredClient; + } + + private Map parseMap(String data) { + try { + return this.objectMapper.readValue(data, new TypeReference>() {}); + } catch (Exception ex) { + throw new IllegalArgumentException(ex.getMessage(), ex); + } + } + + static { + Map am = new HashMap<>(); + for (AuthorizationGrantType a : Arrays.asList( + AuthorizationGrantType.AUTHORIZATION_CODE, + AuthorizationGrantType.REFRESH_TOKEN, + AuthorizationGrantType.CLIENT_CREDENTIALS, + AuthorizationGrantType.PASSWORD, + AuthorizationGrantType.IMPLICIT)) { + am.put(a.getValue(), a); + } + AUTHORIZATION_GRANT_TYPE_MAP = Collections.unmodifiableMap(am); + + Map cm = new HashMap<>(); + for (ClientAuthenticationMethod c : Arrays.asList( + ClientAuthenticationMethod.NONE, + ClientAuthenticationMethod.BASIC, + ClientAuthenticationMethod.POST)) { + cm.put(c.getValue(), c); + } + CLIENT_AUTHENTICATION_METHOD_MAP = Collections.unmodifiableMap(cm); + } + + } + + } + } diff --git a/oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/client/custom-oauth2-registered-client-schema.sql b/oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/client/custom-oauth2-registered-client-schema.sql new file mode 100644 index 00000000..28e727f4 --- /dev/null +++ b/oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/client/custom-oauth2-registered-client-schema.sql @@ -0,0 +1,15 @@ +CREATE TABLE oauth2RegisteredClient ( + id varchar(100) NOT NULL, + clientId varchar(100) NOT NULL, + clientIdIssuedAt timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL, + clientSecret varchar(200) DEFAULT NULL, + clientSecretExpiresAt timestamp DEFAULT NULL, + clientName varchar(200), + clientAuthenticationMethods varchar(1000) NOT NULL, + authorizationGrantTypes varchar(1000) NOT NULL, + redirectUris varchar(1000) NOT NULL, + scopes varchar(1000) NOT NULL, + clientSettings varchar(1000) DEFAULT NULL, + tokenSettings varchar(1000) DEFAULT NULL, + PRIMARY KEY (id) +);