Add test to override schema for JdbcRegisteredClientRepository

This commit is contained in:
Steve Riesenberg
2021-06-29 15:25:07 -05:00
parent 3318874da1
commit 623736d640
3 changed files with 227 additions and 61 deletions

View File

@@ -11,4 +11,5 @@ CREATE TABLE oauth2_registered_client (
scopes varchar(1000) NOT NULL,
client_settings varchar(1000) DEFAULT NULL,
token_settings varchar(1000) DEFAULT NULL,
PRIMARY KEY (id));
PRIMARY KEY (id)
);

View File

@@ -15,20 +15,37 @@
*/
package org.springframework.security.oauth2.server.authorization.client;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.jdbc.datasource.DriverManagerDataSource;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.util.StreamUtils;
import org.springframework.util.StringUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -37,51 +54,38 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
* JDBC-backed registered client repository tests
*
* @author Rafal Lewczuk
* @author Steve Riesenberg
* @since 0.1.2
*/
public class JdbcRegisteredClientRepositoryTests {
private final String SCRIPT = "/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client.sql";
private static final String REGISTERED_CLIENT_SCHEMA_SQL_RESOURCE = "/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql";
private static final String CUSTOM_REGISTERED_CLIENT_SCHEMA_SQL_RESOURCE = "/org/springframework/security/oauth2/server/authorization/client/custom-oauth2-registered-client-schema.sql";
private DriverManagerDataSource dataSource;
private JdbcRegisteredClientRepository clients;
private JdbcRegisteredClientRepository registeredClientRepository;
private RegisteredClient registration;
private RegisteredClient registeredClient;
private JdbcTemplate jdbc;
private EmbeddedDatabase db;
private JdbcOperations jdbcOperations;
@Before
public void setup() throws Exception {
this.dataSource = new DriverManagerDataSource();
this.dataSource.setDriverClassName("org.hsqldb.jdbcDriver");
this.dataSource.setUrl("jdbc:hsqldb:mem:oauthtest");
this.dataSource.setUsername("sa");
this.dataSource.setPassword("");
this.db = createDb(REGISTERED_CLIENT_SCHEMA_SQL_RESOURCE);
this.jdbcOperations = new JdbcTemplate(this.db);
this.jdbc = new JdbcTemplate(this.dataSource);
this.registeredClientRepository = new JdbcRegisteredClientRepository(this.jdbcOperations);
this.registeredClient = TestRegisteredClients.registeredClient().build();
// execute scripts
try (InputStream is = JdbcRegisteredClientRepositoryTests.class.getResourceAsStream(SCRIPT)) {
assertThat(is).isNotNull().describedAs("Cannot open resource file: " + SCRIPT);
String ddls = StreamUtils.copyToString(is, Charset.defaultCharset());
for (String ddl : ddls.split(";\n")) {
if (!ddl.trim().isEmpty()) {
this.jdbc.execute(ddl.trim());
}
}
}
this.clients = new JdbcRegisteredClientRepository(this.jdbc);
this.registration = TestRegisteredClients.registeredClient().build();
this.clients.save(this.registration);
this.registeredClientRepository.save(this.registeredClient);
}
@After
public void destroyDatabase() {
this.jdbc.update("truncate table oauth2_registered_client");
new JdbcTemplate(this.dataSource).execute("SHUTDOWN");
this.db.shutdown();
}
@Test
@@ -97,7 +101,7 @@ public class JdbcRegisteredClientRepositoryTests {
public void whenSetNullRegisteredClientRowMapperThenThrow() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.clients.setRegisteredClientRowMapper(null))
.isThrownBy(() -> this.registeredClientRepository.setRegisteredClientRowMapper(null))
.withMessage("registeredClientRowMapper cannot be null");
// @formatter:on
}
@@ -106,20 +110,20 @@ public class JdbcRegisteredClientRepositoryTests {
public void whenSetNullRegisteredClientParameterMapperThenThrow() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.clients.setRegisteredClientParametersMapper(null))
.isThrownBy(() -> this.registeredClientRepository.setRegisteredClientParametersMapper(null))
.withMessage("registeredClientParameterMapper cannot be null");
// @formatter:on
}
@Test
public void findByIdWhenFoundThenFound() {
String id = this.registration.getId();
assertRegisteredClientIsEqualTo(this.clients.findById(id), this.registration);
String id = this.registeredClient.getId();
assertRegisteredClientIsEqualTo(this.registeredClientRepository.findById(id), this.registeredClient);
}
@Test
public void findByIdWhenNotFoundThenNull() {
RegisteredClient client = this.clients.findById("noooope");
RegisteredClient client = this.registeredClientRepository.findById("noooope");
assertThat(client).isNull();
}
@@ -127,20 +131,20 @@ public class JdbcRegisteredClientRepositoryTests {
public void findByIdWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.clients.findById(null))
.isThrownBy(() -> this.registeredClientRepository.findById(null))
.withMessage("id cannot be empty");
// @formatter:on
}
@Test
public void findByClientIdWhenFoundThenFound() {
String id = this.registration.getClientId();
assertRegisteredClientIsEqualTo(this.clients.findByClientId(id), this.registration);
String id = this.registeredClient.getClientId();
assertRegisteredClientIsEqualTo(this.registeredClientRepository.findByClientId(id), this.registeredClient);
}
@Test
public void findByClientIdWhenNotFoundThenNull() {
RegisteredClient client = this.clients.findByClientId("noooope");
RegisteredClient client = this.registeredClientRepository.findByClientId("noooope");
assertThat(client).isNull();
}
@@ -148,7 +152,7 @@ public class JdbcRegisteredClientRepositoryTests {
public void findByClientIdWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.clients.findByClientId(null))
.isThrownBy(() -> this.registeredClientRepository.findByClientId(null))
.withMessage("clientId cannot be empty");
// @formatter:on
}
@@ -156,58 +160,58 @@ public class JdbcRegisteredClientRepositoryTests {
@Test
public void saveWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> this.clients.save(null))
.isThrownBy(() -> this.registeredClientRepository.save(null))
.withMessageContaining("registeredClient cannot be null");
}
@Test
public void saveWhenExistingIdThenThrowIllegalArgumentException() {
RegisteredClient registeredClient = createRegisteredClient(
this.registration.getId(), "client-id-2", "client-secret-2");
this.registeredClient.getId(), "client-id-2", "client-secret-2");
assertThatIllegalArgumentException()
.isThrownBy(() -> this.clients.save(registeredClient))
.isThrownBy(() -> this.registeredClientRepository.save(registeredClient))
.withMessage("Registered client must be unique. Found duplicate identifier: " + registeredClient.getId());
}
@Test
public void saveWhenExistingClientIdThenThrowIllegalArgumentException() {
RegisteredClient registeredClient = createRegisteredClient(
"client-2", this.registration.getClientId(), "client-secret-2");
"client-2", this.registeredClient.getClientId(), "client-secret-2");
assertThatIllegalArgumentException()
.isThrownBy(() -> this.clients.save(registeredClient))
.isThrownBy(() -> this.registeredClientRepository.save(registeredClient))
.withMessage("Registered client must be unique. Found duplicate client identifier: " + registeredClient.getClientId());
}
@Test
public void saveWhenExistingClientSecretThenSuccess() {
RegisteredClient registeredClient = createRegisteredClient(
"client-2", "client-id-2", this.registration.getClientSecret());
this.clients.save(registeredClient);
RegisteredClient savedClient = this.clients.findById(registeredClient.getId());
"client-2", "client-id-2", this.registeredClient.getClientSecret());
this.registeredClientRepository.save(registeredClient);
RegisteredClient savedClient = this.registeredClientRepository.findById(registeredClient.getId());
assertRegisteredClientIsEqualTo(savedClient, registeredClient);
}
@Test
public void saveWhenSavedAndFindByIdThenFound() {
RegisteredClient registeredClient = createRegisteredClient();
this.clients.save(registeredClient);
RegisteredClient savedClient = this.clients.findById(registeredClient.getId());
this.registeredClientRepository.save(registeredClient);
RegisteredClient savedClient = this.registeredClientRepository.findById(registeredClient.getId());
assertRegisteredClientIsEqualTo(savedClient, registeredClient);
}
@Test
public void saveWhenSavedAndFindByClientIdThenFound() {
RegisteredClient registeredClient = createRegisteredClient();
this.clients.save(registeredClient);
RegisteredClient savedClient = this.clients.findByClientId(registeredClient.getClientId());
this.registeredClientRepository.save(registeredClient);
RegisteredClient savedClient = this.registeredClientRepository.findByClientId(registeredClient.getClientId());
assertRegisteredClientIsEqualTo(savedClient, registeredClient);
}
@Test
public void saveWhenPublicClientSavedAndFindByClientIdThenFound() {
RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build();
this.clients.save(registeredClient);
RegisteredClient savedClient = this.clients.findByClientId(registeredClient.getClientId());
this.registeredClientRepository.save(registeredClient);
RegisteredClient savedClient = this.registeredClientRepository.findByClientId(registeredClient.getClientId());
assertRegisteredClientIsEqualTo(savedClient, registeredClient);
}
@@ -217,9 +221,9 @@ public class JdbcRegisteredClientRepositoryTests {
.id("1").clientId("a").build();
RegisteredClient registeredClient2 = TestRegisteredClients.registeredPublicClient()
.id("2").clientId("b").build();
this.clients.save(registeredClient1);
this.clients.save(registeredClient2);
RegisteredClient savedClient = this.clients.findByClientId(registeredClient2.getClientId());
this.registeredClientRepository.save(registeredClient1);
this.registeredClientRepository.save(registeredClient2);
RegisteredClient savedClient = this.registeredClientRepository.findByClientId(registeredClient2.getClientId());
assertRegisteredClientIsEqualTo(savedClient, registeredClient2);
}
@@ -243,13 +247,28 @@ public class JdbcRegisteredClientRepositoryTests {
})
.build();
this.clients.save(client);
this.registeredClientRepository.save(client);
RegisteredClient retrievedClient = this.clients.findById(client.getId());
RegisteredClient retrievedClient = this.registeredClientRepository.findById(client.getId());
assertRegisteredClientIsEqualTo(retrievedClient, client);
}
@Test
public void tableDefinitionWhenCustomThenAbleToOverride() {
EmbeddedDatabase db = createDb(CUSTOM_REGISTERED_CLIENT_SCHEMA_SQL_RESOURCE);
CustomJdbcRegisteredClientRepository registeredClientRepository =
new CustomJdbcRegisteredClientRepository(new JdbcTemplate(db));
registeredClientRepository.save(this.registeredClient);
RegisteredClient foundClient1 = registeredClientRepository.findById(this.registeredClient.getId());
assertThat(foundClient1).isNotNull();
assertRegisteredClientIsEqualTo(foundClient1, this.registeredClient);
RegisteredClient foundClient2 = registeredClientRepository.findByClientId(this.registeredClient.getClientId());
assertThat(foundClient2).isNotNull();
assertRegisteredClientIsEqualTo(foundClient2, this.registeredClient);
db.shutdown();
}
private void assertRegisteredClientIsEqualTo(RegisteredClient rc, RegisteredClient ref) {
assertThat(rc).isNotNull();
assertThat(rc.getId()).isEqualTo(ref.getId());
@@ -282,11 +301,21 @@ public class JdbcRegisteredClientRepositoryTests {
assertThat(rc.getTokenSettings().refreshTokenTimeToLive()).isEqualTo(ref.getTokenSettings().refreshTokenTimeToLive());
}
private static EmbeddedDatabase createDb(String schema) {
// @formatter:off
return new EmbeddedDatabaseBuilder()
.generateUniqueName(true)
.setType(EmbeddedDatabaseType.HSQL)
.setScriptEncoding("UTF-8")
.addScript(schema)
.build();
// @formatter:on
}
private static RegisteredClient createRegisteredClient() {
return createRegisteredClient("client-2", "client-id-2", "client-secret-2");
}
private static RegisteredClient createRegisteredClient(String id, String clientId, String clientSecret) {
// @formatter:off
return RegisteredClient.withId(id)
@@ -300,4 +329,125 @@ public class JdbcRegisteredClientRepositoryTests {
// @formatter:on
}
private static final class CustomJdbcRegisteredClientRepository extends JdbcRegisteredClientRepository {
private static final String COLUMN_NAMES = "id, "
+ "clientId, "
+ "clientIdIssuedAt, "
+ "clientSecret, "
+ "clientSecretExpiresAt, "
+ "clientName, "
+ "clientAuthenticationMethods, "
+ "authorizationGrantTypes, "
+ "redirectUris, "
+ "scopes, "
+ "clientSettings,"
+ "tokenSettings";
private static final String TABLE_NAME = "oauth2RegisteredClient";
private static final String LOAD_REGISTERED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + " FROM " + TABLE_NAME + " WHERE ";
private static final String INSERT_REGISTERED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME
+ " (" + COLUMN_NAMES + ") values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
public CustomJdbcRegisteredClientRepository(JdbcOperations jdbcOperations) {
super(jdbcOperations);
setRegisteredClientRowMapper(new CustomRegisteredClientRowMapper());
}
@Override
public void save(RegisteredClient registeredClient) {
List<SqlParameterValue> parameters = getRegisteredClientParametersMapper().apply(registeredClient);
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
getJdbcOperations().update(INSERT_REGISTERED_CLIENT_SQL, pss);
}
@Override
public RegisteredClient findById(String id) {
return findBy("id = ?", id);
}
@Override
public RegisteredClient findByClientId(String clientId) {
return findBy("clientId = ?", clientId);
}
private RegisteredClient findBy(String filter, Object... args) {
List<RegisteredClient> result = getJdbcOperations()
.query(LOAD_REGISTERED_CLIENT_SQL + filter, getRegisteredClientRowMapper(), args);
return !result.isEmpty() ? result.get(0) : null;
}
private static final class CustomRegisteredClientRowMapper implements RowMapper<RegisteredClient> {
private static final Map<String, AuthorizationGrantType> AUTHORIZATION_GRANT_TYPE_MAP;
private static final Map<String, ClientAuthenticationMethod> CLIENT_AUTHENTICATION_METHOD_MAP;
private final ObjectMapper objectMapper = new ObjectMapper();
@Override
public RegisteredClient mapRow(ResultSet rs, int rowNum) throws SQLException {
Set<String> clientScopes = StringUtils.commaDelimitedListToSet(rs.getString("scopes"));
Set<String> authGrantTypes = StringUtils.commaDelimitedListToSet(rs.getString("authorizationGrantTypes"));
Set<String> clientAuthMethods = StringUtils.commaDelimitedListToSet(rs.getString("clientAuthenticationMethods"));
Set<String> redirectUris = StringUtils.commaDelimitedListToSet(rs.getString("redirectUris"));
Timestamp clientIssuedAt = rs.getTimestamp("clientIdIssuedAt");
Timestamp clientSecretExpiresAt = rs.getTimestamp("clientSecretExpiresAt");
String clientSecret = rs.getString("clientSecret");
RegisteredClient.Builder builder = RegisteredClient
.withId(rs.getString("id"))
.clientId(rs.getString("clientId"))
.clientIdIssuedAt(clientIssuedAt != null ? clientIssuedAt.toInstant() : null)
.clientSecret(clientSecret)
.clientSecretExpiresAt(clientSecretExpiresAt != null ? clientSecretExpiresAt.toInstant() : null)
.clientName(rs.getString("clientName"))
.authorizationGrantTypes((grantTypes) -> authGrantTypes.forEach(authGrantType ->
grantTypes.add(AUTHORIZATION_GRANT_TYPE_MAP.get(authGrantType))))
.clientAuthenticationMethods((authenticationMethods) -> clientAuthMethods.forEach(clientAuthMethod ->
authenticationMethods.add(CLIENT_AUTHENTICATION_METHOD_MAP.get(clientAuthMethod))))
.redirectUris((uris) -> uris.addAll(redirectUris))
.scopes((scopes) -> scopes.addAll(clientScopes));
RegisteredClient registeredClient = builder.build();
registeredClient.getClientSettings().settings().putAll(parseMap(rs.getString("clientSettings")));
registeredClient.getTokenSettings().settings().putAll(parseMap(rs.getString("tokenSettings")));
return registeredClient;
}
private Map<String, Object> parseMap(String data) {
try {
return this.objectMapper.readValue(data, new TypeReference<Map<String, Object>>() {});
} catch (Exception ex) {
throw new IllegalArgumentException(ex.getMessage(), ex);
}
}
static {
Map<String, AuthorizationGrantType> am = new HashMap<>();
for (AuthorizationGrantType a : Arrays.asList(
AuthorizationGrantType.AUTHORIZATION_CODE,
AuthorizationGrantType.REFRESH_TOKEN,
AuthorizationGrantType.CLIENT_CREDENTIALS,
AuthorizationGrantType.PASSWORD,
AuthorizationGrantType.IMPLICIT)) {
am.put(a.getValue(), a);
}
AUTHORIZATION_GRANT_TYPE_MAP = Collections.unmodifiableMap(am);
Map<String, ClientAuthenticationMethod> cm = new HashMap<>();
for (ClientAuthenticationMethod c : Arrays.asList(
ClientAuthenticationMethod.NONE,
ClientAuthenticationMethod.BASIC,
ClientAuthenticationMethod.POST)) {
cm.put(c.getValue(), c);
}
CLIENT_AUTHENTICATION_METHOD_MAP = Collections.unmodifiableMap(cm);
}
}
}
}

View File

@@ -0,0 +1,15 @@
CREATE TABLE oauth2RegisteredClient (
id varchar(100) NOT NULL,
clientId varchar(100) NOT NULL,
clientIdIssuedAt timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
clientSecret varchar(200) DEFAULT NULL,
clientSecretExpiresAt timestamp DEFAULT NULL,
clientName varchar(200),
clientAuthenticationMethods varchar(1000) NOT NULL,
authorizationGrantTypes varchar(1000) NOT NULL,
redirectUris varchar(1000) NOT NULL,
scopes varchar(1000) NOT NULL,
clientSettings varchar(1000) DEFAULT NULL,
tokenSettings varchar(1000) DEFAULT NULL,
PRIMARY KEY (id)
);