diff --git a/oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle b/oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle index 3bede36f..6f053db2 100644 --- a/oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle +++ b/oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle @@ -16,7 +16,6 @@ dependencies { testCompile 'org.assertj:assertj-core' testCompile 'org.mockito:mockito-core' testCompile 'com.jayway.jsonpath:json-path' - testCompile 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310' provided 'javax.servlet:javax.servlet-api' } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizationServerConfiguration.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizationServerConfiguration.java index 8cdead39..8341a66e 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizationServerConfiguration.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizationServerConfiguration.java @@ -15,6 +15,17 @@ */ package org.springframework.security.config.annotation.web.configuration; +import java.util.HashSet; +import java.util.Set; + +import com.nimbusds.jose.JWSAlgorithm; +import com.nimbusds.jose.jwk.source.JWKSource; +import com.nimbusds.jose.proc.JWSKeySelector; +import com.nimbusds.jose.proc.JWSVerificationKeySelector; +import com.nimbusds.jose.proc.SecurityContext; +import com.nimbusds.jwt.proc.ConfigurableJWTProcessor; +import com.nimbusds.jwt.proc.DefaultJWTProcessor; + import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.core.Ordered; @@ -22,6 +33,8 @@ import org.springframework.core.annotation.Order; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization.OAuth2AuthorizationServerConfigurer; import org.springframework.security.config.annotation.web.configurers.oauth2.server.resource.OAuth2ResourceServerConfigurer; +import org.springframework.security.oauth2.jwt.JwtDecoder; +import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -29,8 +42,8 @@ import org.springframework.security.web.util.matcher.RequestMatcher; * {@link Configuration} for OAuth 2.0 Authorization Server support. * * @author Joe Grandja - * @see OAuth2AuthorizationServerConfigurer * @since 0.0.1 + * @see OAuth2AuthorizationServerConfigurer */ @Configuration(proxyBeanMethods = false) public class OAuth2AuthorizationServerConfiguration { @@ -48,16 +61,32 @@ public class OAuth2AuthorizationServerConfiguration { new OAuth2AuthorizationServerConfigurer<>(); RequestMatcher endpointsMatcher = authorizationServerConfigurer .getEndpointsMatcher(); + http .requestMatcher(endpointsMatcher) .authorizeRequests(authorizeRequests -> - authorizeRequests.anyRequest().authenticated() - ).csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher)) + authorizeRequests.anyRequest().authenticated() + ) + .csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher)) + .oauth2ResourceServer(OAuth2ResourceServerConfigurer::jwt) .apply(authorizationServerConfigurer); - - if (authorizationServerConfigurer.isOidcClientRegistrationEnabled()) { - http.oauth2ResourceServer(OAuth2ResourceServerConfigurer::jwt); - } } // @formatter:on + + @Bean + public static JwtDecoder jwtDecoder(JWKSource jwkSource) { + Set jwsAlgs = new HashSet<>(); + jwsAlgs.addAll(JWSAlgorithm.Family.RSA); + jwsAlgs.addAll(JWSAlgorithm.Family.EC); + jwsAlgs.addAll(JWSAlgorithm.Family.HMAC_SHA); + ConfigurableJWTProcessor jwtProcessor = new DefaultJWTProcessor<>(); + JWSKeySelector jwsKeySelector = + new JWSVerificationKeySelector<>(jwsAlgs, jwkSource); + jwtProcessor.setJWSKeySelector(jwsKeySelector); + // Override the default Nimbus claims set verifier as NimbusJwtDecoder handles it instead + jwtProcessor.setJWTClaimsSetVerifier((claims, context) -> { + }); + return new NimbusJwtDecoder(jwtProcessor); + } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java index d11f285f..a3b3c584 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java @@ -44,9 +44,9 @@ import org.springframework.security.oauth2.server.authorization.authentication.O import org.springframework.security.oauth2.server.authorization.authentication.OAuth2RefreshTokenAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenIntrospectionAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationProvider; -import org.springframework.security.oauth2.server.authorization.authentication.OidcClientRegistrationAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.oidc.web.OidcClientRegistrationEndpointFilter; import org.springframework.security.oauth2.server.authorization.oidc.web.OidcProviderConfigurationEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.NimbusJwkSetEndpointFilter; @@ -152,17 +152,6 @@ public final class OAuth2AuthorizationServerConfigurer exceptionHandling = builder.getConfigurer(ExceptionHandlingConfigurer.class); if (exceptionHandling != null) { @@ -246,9 +237,6 @@ public final class OAuth2AuthorizationServerConfigurer jwkSource = getJwkSource(builder); NimbusJwkSetEndpointFilter jwkSetEndpointFilter = new NimbusJwkSetEndpointFilter( jwkSource, @@ -268,8 +256,8 @@ public final class OAuth2AuthorizationServerConfigurer getRedirectUris() { - return getClaimAsStringList(OidcClientMetadataClaimNames.REDIRECT_URIS); - } - - /** - * Returns the OAuth 2.0 {@code response_type} values that the client may use. - * - * @return the {@code List} of {@code response_type} - */ - default List getResponseTypes() { - return getClaimAsStringList(OidcClientMetadataClaimNames.RESPONSE_TYPES); - } - - /** - * Returns the authorization {@code grant_types} that the client may use. - * - * @return the {@code List} of authorization {@code grant_types} - */ - default List getGrantTypes() { - return getClaimAsStringList(OidcClientMetadataClaimNames.GRANT_TYPES); - } - - /** - * Returns the {@code client_name}. - * - * @return the {@code client_name} - */ - default String getClientName() { - return getClaimAsString(OidcClientMetadataClaimNames.CLIENT_NAME); - } - - /** - * Returns the scope(s) that the client may use. - * - * @return the scope(s) - */ - default String getScope() { - return getClaimAsString(OidcClientMetadataClaimNames.SCOPE); - } - - /** - * Returns the {@link ClientAuthenticationMethod authentication method} that the client may use. - * - * @return the {@link ClientAuthenticationMethod authentication method} - */ - default String getTokenEndpointAuthenticationMethod() { - return getClaimAsString(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD); - } - - /** - * Returns the {@code client_id}. - * - * @return the {@code client_id} + * @return the Client Identifier */ default String getClientId() { return getClaimAsString(OidcClientMetadataClaimNames.CLIENT_ID); } /** - * Returns the {@code client_id_issued_at} timestamp. + * Returns the time at which the Client Identifier was issued {@code (client_id_issued_at)}. * - * @return the {@code client_id_issued_at} timestamp + * @return the time at which the Client Identifier was issued */ default Instant getClientIdIssuedAt() { return getClaimAsInstant(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT); } /** - * Returns the {@code client_secret}. + * Returns the Client Secret {@code (client_secret)}. * - * @return the {@code client_secret} + * @return the Client Secret */ default String getClientSecret() { return getClaimAsString(OidcClientMetadataClaimNames.CLIENT_SECRET); } /** - * Returns the {@code client_secret_expires_at} timestamp. + * Returns the time at which the {@code client_secret} will expire {@code (client_secret_expires_at)}. * - * @return the {@code client_secret_expires_at} timestamp + * @return the time at which the {@code client_secret} will expire */ default Instant getClientSecretExpiresAt() { return getClaimAsInstant(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT); } + /** + * Returns the name of the Client to be presented to the End-User {@code (client_name)}. + * + * @return the name of the Client to be presented to the End-User + */ + default String getClientName() { + return getClaimAsString(OidcClientMetadataClaimNames.CLIENT_NAME); + } + /** + * Returns the redirection {@code URI} values used by the Client {@code (redirect_uris)}. + * + * @return the redirection {@code URI} values used by the Client + */ + default List getRedirectUris() { + return getClaimAsStringList(OidcClientMetadataClaimNames.REDIRECT_URIS); + } + /** + * Returns the authentication method used by the Client for the Token Endpoint {@code (token_endpoint_auth_method)}. + * + * @return the authentication method used by the Client for the Token Endpoint + */ + default String getTokenEndpointAuthenticationMethod() { + return getClaimAsString(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD); + } + + /** + * Returns the OAuth 2.0 {@code grant_type} values that the Client will restrict itself to using {@code (grant_types)}. + * + * @return the OAuth 2.0 {@code grant_type} values that the Client will restrict itself to using + */ + default List getGrantTypes() { + return getClaimAsStringList(OidcClientMetadataClaimNames.GRANT_TYPES); + } + + /** + * Returns the OAuth 2.0 {@code response_type} values that the Client will restrict itself to using {@code (response_types)}. + * + * @return the OAuth 2.0 {@code response_type} values that the Client will restrict itself to using + */ + default List getResponseTypes() { + return getClaimAsStringList(OidcClientMetadataClaimNames.RESPONSE_TYPES); + } + + /** + * Returns the OAuth 2.0 {@code scope} values that the Client will restrict itself to using {@code (scope)}. + * + * @return the OAuth 2.0 {@code scope} values that the Client will restrict itself to using + */ + default List getScopes() { + return getClaimAsStringList(OidcClientMetadataClaimNames.SCOPE); + } + + /** + * Returns the {@link SignatureAlgorithm JWS} algorithm required for signing the {@link OidcIdToken ID Token} issued to the Client {@code (id_token_signed_response_alg)}. + * + * @return the {@link SignatureAlgorithm JWS} algorithm required for signing the {@link OidcIdToken ID Token} issued to the Client + */ + default String getIdTokenSignedResponseAlgorithm() { + return getClaimAsString(OidcClientMetadataClaimNames.ID_TOKEN_SIGNED_RESPONSE_ALG); + } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientMetadataClaimNames.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientMetadataClaimNames.java index f18915e4..63a5d205 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientMetadataClaimNames.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientMetadataClaimNames.java @@ -15,65 +15,72 @@ */ package org.springframework.security.oauth2.core.oidc; +import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; + /** - * The names of the "claims" defined by OpenID Client Registration 1.0 that can be returned - * in the OpenID Client Registration Response. + * The names of the "claims" defined by OpenID Connect Dynamic Client Registration 1.0 + * that are contained in the OpenID Client Registration Request and Response. * * @author Ovidiu Popa + * @author Joe Grandja * @since 0.1.1 * @see 2. Client Metadata */ public interface OidcClientMetadataClaimNames { - //request /** - * {@code redirect_uris} - the redirect URI(s) that the client may use in redirect-based flows - */ - String REDIRECT_URIS = "redirect_uris"; - - /** - * {@code response_types} - the OAuth 2.0 {@code response_type} values that the client may use - */ - String RESPONSE_TYPES = "response_types"; - - /** - * {@code grant_types} - the OAuth 2.0 authorization {@code grant_types} that the client may use - */ - String GRANT_TYPES = "grant_types"; - - /** - * {@code client_name} - the {@code client_name} - */ - String CLIENT_NAME = "client_name"; - - /** - * {@code scope} - the scope(s) that the client may use - */ - String SCOPE = "scope"; - - /** - * {@code token_endpoint_auth_method} - the {@link org.springframework.security.oauth2.core.ClientAuthenticationMethod authentication method} that the client may use. - */ - String TOKEN_ENDPOINT_AUTH_METHOD = "token_endpoint_auth_method"; - - //response - /** - * {@code client_id} - the {@code client_id} + * {@code client_id} - the Client Identifier */ String CLIENT_ID = "client_id"; /** - * {@code client_secret} - the {@code client_secret} - */ - String CLIENT_SECRET = "client_secret"; - - /** - * {@code client_id_issued_at} - the timestamp when the client id was issued + * {@code client_id_issued_at} - the time at which the Client Identifier was issued */ String CLIENT_ID_ISSUED_AT = "client_id_issued_at"; /** - * {@code client_secret_expires_at} - the timestamp when the client secret expires + * {@code client_secret} - the Client Secret + */ + String CLIENT_SECRET = "client_secret"; + + /** + * {@code client_secret_expires_at} - the time at which the {@code client_secret} will expire or 0 if it will not expire */ String CLIENT_SECRET_EXPIRES_AT = "client_secret_expires_at"; + + /** + * {@code client_name} - the name of the Client to be presented to the End-User + */ + String CLIENT_NAME = "client_name"; + + /** + * {@code redirect_uris} - the redirection {@code URI} values used by the Client + */ + String REDIRECT_URIS = "redirect_uris"; + + /** + * {@code token_endpoint_auth_method} - the authentication method used by the Client for the Token Endpoint + */ + String TOKEN_ENDPOINT_AUTH_METHOD = "token_endpoint_auth_method"; + + /** + * {@code grant_types} - the OAuth 2.0 {@code grant_type} values that the Client will restrict itself to using + */ + String GRANT_TYPES = "grant_types"; + + /** + * {@code response_types} - the OAuth 2.0 {@code response_type} values that the Client will restrict itself to using + */ + String RESPONSE_TYPES = "response_types"; + + /** + * {@code scope} - a space-separated list of OAuth 2.0 {@code scope} values that the Client will restrict itself to using + */ + String SCOPE = "scope"; + + /** + * {@code id_token_signed_response_alg} - the {@link JwsAlgorithm JWS} algorithm required for signing the {@link OidcIdToken ID Token} issued to the Client + */ + String ID_TOKEN_SIGNED_RESPONSE_ALG = "id_token_signed_response_alg"; + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientRegistration.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientRegistration.java index fbdccf02..e098f262 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientRegistration.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientRegistration.java @@ -15,12 +15,6 @@ */ package org.springframework.security.oauth2.core.oidc; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; -import org.springframework.security.oauth2.core.Version; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; -import org.springframework.util.Assert; - import java.io.Serializable; import java.net.URI; import java.net.URL; @@ -32,28 +26,36 @@ import java.util.List; import java.util.Map; import java.util.function.Consumer; +import org.springframework.security.oauth2.core.Version; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.util.Assert; + /** * A representation of an OpenID Client Registration Request and Response, - * which contains a set of claims defined by the - * OpenID Connect Registration 1.0 specification. + * which is sent to and returned from the Client Registration Endpoint, + * and contains a set of claims about the Client's Registration information. + * The claims are defined by the OpenID Connect Dynamic Client Registration 1.0 specification. * * @author Ovidiu Popa + * @author Joe Grandja * @since 0.1.1 * @see OidcClientMetadataClaimAccessor - * @see 3.1. Client Registration Request + * @see 3.1. Client Registration Request + * @see 3.2. Client Registration Response */ public final class OidcClientRegistration implements OidcClientMetadataClaimAccessor, Serializable { private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private final Map claims; private OidcClientRegistration(Map claims) { - this.claims = Collections.unmodifiableMap(claims); + Assert.notEmpty(claims, "claims cannot be empty"); + this.claims = Collections.unmodifiableMap(new LinkedHashMap<>(claims)); } /** - * Returns the OpenID Client Registration metadata. + * Returns the metadata as claims. * - * @return a {@code Map} of the metadata values + * @return a {@code Map} of the metadata as claims */ @Override public Map getClaims() { @@ -61,9 +63,9 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce } /** - * Constructs a new {@link OidcClientRegistration.Builder} with empty claims. + * Constructs a new {@link Builder} with empty claims. * - * @return the {@link OidcClientRegistration.Builder} + * @return the {@link Builder} */ public static Builder builder() { return new Builder(); @@ -80,18 +82,69 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce .claims(c -> c.putAll(claims)); } + /** + * Helps configure an {@link OidcClientRegistration}. + */ public static class Builder { - private final Map claims = new LinkedHashMap<>(); private Builder() { } /** - * Add this Redirect URI to the collection of {@code redirect_uris} in the resulting - * {@link OidcClientRegistration}, REQUIRED. + * Sets the Client Identifier, REQUIRED. * - * @param redirectUri the OAuth 2.0 {@code redirect_uri} value that client supports + * @param clientId the Client Identifier + * @return the {@link Builder} for further configuration + */ + public Builder clientId(String clientId) { + return claim(OidcClientMetadataClaimNames.CLIENT_ID, clientId); + } + + /** + * Sets the time at which the Client Identifier was issued, OPTIONAL. + * + * @param clientIdIssuedAt the time at which the Client Identifier was issued + * @return the {@link Builder} for further configuration + */ + public Builder clientIdIssuedAt(Instant clientIdIssuedAt) { + return claim(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT, clientIdIssuedAt); + } + + /** + * Sets the Client Secret, OPTIONAL. + * + * @param clientSecret the Client Secret + * @return the {@link Builder} for further configuration + */ + public Builder clientSecret(String clientSecret) { + return claim(OidcClientMetadataClaimNames.CLIENT_SECRET, clientSecret); + } + + /** + * Sets the time at which the {@code client_secret} will expire or {@code null} if it will not expire, REQUIRED if {@code client_secret} was issued. + * + * @param clientSecretExpiresAt the time at which the {@code client_secret} will expire or {@code null} if it will not expire + * @return the {@link Builder} for further configuration + */ + public Builder clientSecretExpiresAt(Instant clientSecretExpiresAt) { + return claim(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, clientSecretExpiresAt); + } + + /** + * Sets the name of the Client to be presented to the End-User, OPTIONAL. + * + * @param clientName the name of the Client to be presented to the End-User + * @return the {@link Builder} for further configuration + */ + public Builder clientName(String clientName) { + return claim(OidcClientMetadataClaimNames.CLIENT_NAME, clientName); + } + + /** + * Add the redirection {@code URI} used by the Client, REQUIRED. + * + * @param redirectUri the redirection {@code URI} used by the Client * @return the {@link Builder} for further configuration */ public Builder redirectUri(String redirectUri) { @@ -100,100 +153,31 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce } /** - * A {@code Consumer} of the Redirect URI(s) allowing the ability to add, replace, or remove. + * A {@code Consumer} of the redirection {@code URI} values used by the Client, + * allowing the ability to add, replace, or remove, REQUIRED. * - * @param redirectUriConsumer a {@code Consumer} of the Redirect URI(s) + * @param redirectUrisConsumer a {@code Consumer} of the redirection {@code URI} values used by the Client * @return the {@link Builder} for further configuration */ - public Builder redirectUris(Consumer> redirectUriConsumer) { - acceptClaimValues(OidcClientMetadataClaimNames.REDIRECT_URIS, redirectUriConsumer); + public Builder redirectUris(Consumer> redirectUrisConsumer) { + acceptClaimValues(OidcClientMetadataClaimNames.REDIRECT_URIS, redirectUrisConsumer); return this; } /** - * Add this Response Type to the collection of {@code response_types} in the resulting - * {@link OidcClientRegistration}, OPTIONAL. + * Sets the authentication method used by the Client for the Token Endpoint, OPTIONAL. * - * @param responseType the OAuth 2.0 {@code response_type} value that client supports + * @param tokenEndpointAuthenticationMethod the authentication method used by the Client for the Token Endpoint * @return the {@link Builder} for further configuration */ - public Builder responseType(String responseType) { - addClaimToClaimList(OidcClientMetadataClaimNames.RESPONSE_TYPES, responseType); - return this; + public Builder tokenEndpointAuthenticationMethod(String tokenEndpointAuthenticationMethod) { + return claim(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, tokenEndpointAuthenticationMethod); } /** - * Add {@code Consumer} of {@code response_types} allowing the ability to add, replace, or remove - * {@link OidcClientRegistration}, OPTIONAL. + * Add the OAuth 2.0 {@code grant_type} that the Client will restrict itself to using, OPTIONAL. * - * @param responseType the OAuth 2.0 {@code response_type} value that client supports - * @return the {@link Builder} for further configuration - */ - public Builder responseTypes(Consumer> responseType) { - acceptClaimValues(OidcClientMetadataClaimNames.RESPONSE_TYPES, responseType); - return this; - } - - /** - * Sets {@code client_name} claim in the resulting - * {@link OidcClientRegistration}, OPTIONAL. - * - * @param clientName the OAuth 2.0 {@code client_name} of the registered client - * @return the {@link Builder} for further configuration - */ - public Builder clientName(String clientName) { - return claim(OidcClientMetadataClaimNames.CLIENT_NAME, clientName); - } - - /** - * Sets {@code client_id} claim in the resulting - * {@link OidcClientRegistration}. - * - * @param clientId the OAuth 2.0 {@code client_id} of the registered client - * @return the {@link Builder} for further configuration - */ - public Builder clientId(String clientId) { - return claim(OidcClientMetadataClaimNames.CLIENT_ID, clientId); - } - - /** - * Sets {@code client_id_issued_at} claim in the resulting - * {@link OidcClientRegistration}. - * - * @param clientIssuedAt the timestamp {@code client_id_issued_at} when the client was issued - * @return the {@link Builder} for further configuration - */ - public Builder clientIdIssuedAt(Instant clientIssuedAt) { - return claim(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT, clientIssuedAt); - } - - /** - * Sets {@code client_secret} claim in the resulting - * {@link OidcClientRegistration}. - * - * @param clientSecret the {@code client_secret} of the registered client - * @return the {@link Builder} for further configuration - */ - public Builder clientSecret(String clientSecret) { - return claim(OidcClientMetadataClaimNames.CLIENT_SECRET, clientSecret); - } - - /** - * Sets {@code client_secret_expires_at} claim in the resulting - * {@link OidcClientRegistration}. - * - * @param clientSecretExpiresAt the timestamp {@code client_secret_expires_at} when the client_secret expires - * @return the {@link Builder} for further configuration - */ - public Builder clientSecretExpiresAt(Instant clientSecretExpiresAt) { - return claim(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, clientSecretExpiresAt); - } - - /** - * Add this Grant Type to the collection of {@code grant_types_supported} in the resulting - * {@link OidcClientRegistration}, OPTIONAL. - * - * @param grantType the OAuth 2.0 {@code grant_type} value that client supports + * @param grantType the OAuth 2.0 {@code grant_type} that the Client will restrict itself to using * @return the {@link Builder} for further configuration */ public Builder grantType(String grantType) { @@ -202,9 +186,10 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce } /** - * A {@code Consumer} of the Grant Type(s) allowing the ability to add, replace, or remove. + * A {@code Consumer} of the OAuth 2.0 {@code grant_type} values that the Client will restrict itself to using, + * allowing the ability to add, replace, or remove, OPTIONAL. * - * @param grantTypesConsumer a {@code Consumer} of the Grant Type(s) + * @param grantTypesConsumer a {@code Consumer} of the OAuth 2.0 {@code grant_type} values that the Client will restrict itself to using * @return the {@link Builder} for further configuration */ public Builder grantTypes(Consumer> grantTypesConsumer) { @@ -213,22 +198,44 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce } /** - * Add this Scope to the collection of {@code scopes_supported} in the resulting - * {@link OidcClientRegistration}, RECOMMENDED. + * Add the OAuth 2.0 {@code response_type} that the Client will restrict itself to using, OPTIONAL. * - * @param scope the OAuth 2.0 {@code scope} value that client supports + * @param responseType the OAuth 2.0 {@code response_type} that the Client will restrict itself to using * @return the {@link Builder} for further configuration */ - public Builder scope(String scope) { - claim(OidcClientMetadataClaimNames.SCOPE, scope); + public Builder responseType(String responseType) { + addClaimToClaimList(OidcClientMetadataClaimNames.RESPONSE_TYPES, responseType); return this; } /** - * Add {@code Consumer} of {@code scopes} allowing the ability to add, replace, or remove - * {@link OidcClientRegistration}, RECOMMENDED. + * A {@code Consumer} of the OAuth 2.0 {@code response_type} values that the Client will restrict itself to using, + * allowing the ability to add, replace, or remove, OPTIONAL. * - * @param scopesConsumer the OAuth 2.0 {@code scope} value that client supports + * @param responseTypesConsumer a {@code Consumer} of the OAuth 2.0 {@code response_type} values that the Client will restrict itself to using + * @return the {@link Builder} for further configuration + */ + public Builder responseTypes(Consumer> responseTypesConsumer) { + acceptClaimValues(OidcClientMetadataClaimNames.RESPONSE_TYPES, responseTypesConsumer); + return this; + } + + /** + * Add the OAuth 2.0 {@code scope} that the Client will restrict itself to using, OPTIONAL. + * + * @param scope the OAuth 2.0 {@code scope} that the Client will restrict itself to using + * @return the {@link Builder} for further configuration + */ + public Builder scope(String scope) { + addClaimToClaimList(OidcClientMetadataClaimNames.SCOPE, scope); + return this; + } + + /** + * A {@code Consumer} of the OAuth 2.0 {@code scope} values that the Client will restrict itself to using, + * allowing the ability to add, replace, or remove, OPTIONAL. + * + * @param scopesConsumer a {@code Consumer} of the OAuth 2.0 {@code scope} values that the Client will restrict itself to using * @return the {@link Builder} for further configuration */ public Builder scopes(Consumer> scopesConsumer) { @@ -237,19 +244,17 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce } /** - * Add this Token endpoint authentication method to the collection of {@code token_endpoint_auth_method} in the resulting - * {@link OidcClientRegistration}, OPTIONAL. + * Sets the {@link SignatureAlgorithm JWS} algorithm required for signing the {@link OidcIdToken ID Token} issued to the Client, OPTIONAL. * - * @param tokenEndpointAuthenticationMethod the OAuth 2.0 {@code token_endpoint_auth_method} value that client supports + * @param idTokenSignedResponseAlgorithm the {@link SignatureAlgorithm JWS} algorithm required for signing the {@link OidcIdToken ID Token} issued to the Client * @return the {@link Builder} for further configuration */ - public Builder tokenEndpointAuthenticationMethod(String tokenEndpointAuthenticationMethod) { - claim(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, tokenEndpointAuthenticationMethod); - return this; + public Builder idTokenSignedResponseAlgorithm(String idTokenSignedResponseAlgorithm) { + return claim(OidcClientMetadataClaimNames.ID_TOKEN_SIGNED_RESPONSE_ALG, idTokenSignedResponseAlgorithm); } /** - * Add this claim in the resulting {@link OidcClientRegistration}. + * Sets the claim. * * @param name the claim name * @param value the claim value @@ -263,8 +268,8 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce } /** - * Provides access to every {@link #claim(String, Object)} declared so far with - * the possibility to add, replace, or remove. + * Provides access to every {@link #claim(String, Object)} declared so far + * allowing the ability to add, replace, or remove. * * @param claimsConsumer a {@code Consumer} of the claims * @return the {@link Builder} for further configurations @@ -274,58 +279,48 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce return this; } + /** + * Validate the claims and build the {@link OidcClientRegistration}. + *

+ * The following claims are REQUIRED: + * {@code client_id}, {@code redirect_uris}. + * + * @return the {@link OidcClientRegistration} + */ public OidcClientRegistration build() { - this.claims.computeIfAbsent(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, - k -> ClientAuthenticationMethod.BASIC.getValue()); - // If omitted, the default is that the Client will use only the authorization_code Grant Type. - this.claims.computeIfAbsent(OidcClientMetadataClaimNames.GRANT_TYPES, - k -> Collections.singletonList(AuthorizationGrantType.AUTHORIZATION_CODE.getValue())); - //If omitted, the default is that the Client will use only the code Response Type. - this.claims.computeIfAbsent(OidcClientMetadataClaimNames.RESPONSE_TYPES, - k -> Collections.singletonList(OAuth2AuthorizationResponseType.CODE.getValue())); - validateRedirectUris(); - validateReponseTypesClaim(); - validateGrantTypesClaim(); + validate(); return new OidcClientRegistration(this.claims); } - private void validateRedirectUris() { - // redirect_uris is required + private void validate() { + if (this.claims.get(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT) != null || + this.claims.get(OidcClientMetadataClaimNames.CLIENT_SECRET) != null) { + Assert.notNull(this.claims.get(OidcClientMetadataClaimNames.CLIENT_ID), "client_id cannot be null"); + } + if (this.claims.get(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT) != null) { + Assert.isInstanceOf(Instant.class, this.claims.get(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT), "client_id_issued_at must be of type Instant"); + } + if (this.claims.get(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT) != null) { + Assert.notNull(this.claims.get(OidcClientMetadataClaimNames.CLIENT_SECRET), "client_secret cannot be null"); + Assert.isInstanceOf(Instant.class, this.claims.get(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT), "client_secret_expires_at must be of type Instant"); + } Assert.notNull(this.claims.get(OidcClientMetadataClaimNames.REDIRECT_URIS), "redirect_uris cannot be null"); - Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.REDIRECT_URIS), "redirect_uris must be of type list"); - Assert.notEmpty((List) this.claims.get(OidcClientMetadataClaimNames.REDIRECT_URIS), "redirect_uris must not be empty"); + Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.REDIRECT_URIS), "redirect_uris must be of type List"); + Assert.notEmpty((List) this.claims.get(OidcClientMetadataClaimNames.REDIRECT_URIS), "redirect_uris cannot be empty"); ((List) this.claims.get(OidcClientMetadataClaimNames.REDIRECT_URIS)).forEach( url -> validateURL(url, "redirect_uri must be a valid URL") ); - } - - private void validateGrantTypesClaim() { - Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.GRANT_TYPES), "grant_types must be of type List"); - List grantTypes = (List) this.claims.get(OidcClientMetadataClaimNames.GRANT_TYPES); - // If empty, the default is that the Client will use only the authorization_code Grant Type. - if (grantTypes.isEmpty()) { - this.claims.put(OidcClientMetadataClaimNames.GRANT_TYPES, - Collections.singletonList(AuthorizationGrantType.AUTHORIZATION_CODE.getValue())); + if (this.claims.get(OidcClientMetadataClaimNames.GRANT_TYPES) != null) { + Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.GRANT_TYPES), "grant_types must be of type List"); + Assert.notEmpty((List) this.claims.get(OidcClientMetadataClaimNames.GRANT_TYPES), "grant_types cannot be empty"); } - } - - private void validateReponseTypesClaim() { - Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.RESPONSE_TYPES), "response_types must be of type List"); - List responseTypes = (List) this.claims.get(OidcClientMetadataClaimNames.RESPONSE_TYPES); - //If empty, the default is that the Client will use only the code Response Type. - if (responseTypes.isEmpty()) { - this.claims.put(OidcClientMetadataClaimNames.RESPONSE_TYPES, Collections.singletonList(OAuth2AuthorizationResponseType.CODE.getValue())); + if (this.claims.get(OidcClientMetadataClaimNames.RESPONSE_TYPES) != null) { + Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.RESPONSE_TYPES), "response_types must be of type List"); + Assert.notEmpty((List) this.claims.get(OidcClientMetadataClaimNames.RESPONSE_TYPES), "response_types cannot be empty"); } - } - - private static void validateURL(Object url, String errorMessage) { - if (URL.class.isAssignableFrom(url.getClass())) { - return; - } - try { - new URI(url.toString()).toURL(); - } catch (Exception ex) { - throw new IllegalArgumentException(errorMessage, ex); + if (this.claims.get(OidcClientMetadataClaimNames.SCOPE) != null) { + Assert.isInstanceOf(List.class, this.claims.get(OidcClientMetadataClaimNames.SCOPE), "scope must be of type List"); + Assert.notEmpty((List) this.claims.get(OidcClientMetadataClaimNames.SCOPE), "scope cannot be empty"); } } @@ -345,5 +340,16 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce List values = (List) this.claims.get(name); valuesConsumer.accept(values); } + + private static void validateURL(Object url, String errorMessage) { + if (URL.class.isAssignableFrom(url.getClass())) { + return; + } + try { + new URI(url.toString()).toURL(); + } catch (Exception ex) { + throw new IllegalArgumentException(errorMessage, ex); + } + } } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverter.java index 5fe633e0..835fecc8 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverter.java @@ -15,6 +15,15 @@ */ package org.springframework.security.oauth2.core.oidc.http.converter; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.converter.Converter; @@ -31,31 +40,28 @@ import org.springframework.security.oauth2.core.converter.ClaimTypeConverter; import org.springframework.security.oauth2.core.oidc.OidcClientMetadataClaimNames; import org.springframework.security.oauth2.core.oidc.OidcClientRegistration; import org.springframework.util.Assert; - -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; /** - * A {@link HttpMessageConverter} for an {@link OidcClientRegistration OpenID Client Registration Response}. + * A {@link HttpMessageConverter} for an {@link OidcClientRegistration OpenID Client Registration Request and Response}. * * @author Ovidiu Popa + * @author Joe Grandja + * @since 0.1.1 * @see AbstractHttpMessageConverter * @see OidcClientRegistration - * @since 0.1.1 */ public class OidcClientRegistrationHttpMessageConverter extends AbstractHttpMessageConverter { - private static final ParameterizedTypeReference> STRING_OBJECT_MAP = - new ParameterizedTypeReference>() { - }; + private static final ParameterizedTypeReference> STRING_OBJECT_MAP = new ParameterizedTypeReference>() { + }; - private Converter, OidcClientRegistration> clientRegistrationConverter = - new OidcClientRegistrationConverter(); - - private Converter> clientRegistrationParametersConverter = OidcClientRegistration::getClaims; private final GenericHttpMessageConverter jsonMessageConverter = HttpMessageConverters.getJsonMessageConverter(); + private Converter, OidcClientRegistration> clientRegistrationConverter = new MapOidcClientRegistrationConverter(); + private Converter> clientRegistrationParametersConverter = new OidcClientRegistrationMapConverter(); + public OidcClientRegistrationHttpMessageConverter() { super(MediaType.APPLICATION_JSON, new MediaType("application", "*+json")); } @@ -70,52 +76,46 @@ public class OidcClientRegistrationHttpMessageConverter extends AbstractHttpMess protected OidcClientRegistration readInternal(Class clazz, HttpInputMessage inputMessage) throws HttpMessageNotReadableException { try { - Map clientRegistrationParameters = - (Map) this.jsonMessageConverter.read(STRING_OBJECT_MAP.getType(), null, inputMessage); + Map clientRegistrationParameters = (Map) this.jsonMessageConverter + .read(STRING_OBJECT_MAP.getType(), null, inputMessage); return this.clientRegistrationConverter.convert(clientRegistrationParameters); } catch (Exception ex) { throw new HttpMessageNotReadableException( - "An error occurred reading the OpenID Client Registration Request: " + ex.getMessage(), ex, inputMessage); + "An error occurred reading the OpenID Client Registration: " + ex.getMessage(), ex, inputMessage); } } @Override - protected void writeInternal(OidcClientRegistration oidcClientRegistration, HttpOutputMessage outputMessage) + protected void writeInternal(OidcClientRegistration clientRegistration, HttpOutputMessage outputMessage) throws HttpMessageNotWritableException { - try { - Map claims = clientRegistrationParametersConverter.convert(oidcClientRegistration); - this.jsonMessageConverter.write( - claims, - STRING_OBJECT_MAP.getType(), - MediaType.APPLICATION_JSON, - outputMessage - ); + Map clientRegistrationParameters = this.clientRegistrationParametersConverter + .convert(clientRegistration); + this.jsonMessageConverter.write(clientRegistrationParameters, STRING_OBJECT_MAP.getType(), + MediaType.APPLICATION_JSON, outputMessage); } catch (Exception ex) { throw new HttpMessageNotWritableException( - "An error occurred writing the OpenID Client Registration response: " + ex.getMessage(), ex); + "An error occurred writing the OpenID Client Registration: " + ex.getMessage(), ex); } - } /** - * Sets the {@link Converter} used for converting the OpenID Client Registration parameters - * to an {@link OidcClientRegistration}. + * Sets the {@link Converter} used for converting the OpenID Client Registration parameters to an {@link OidcClientRegistration}. * - * @param clientRegistrationConverter the {@link Converter} used for converting to an - * {@link OidcClientRegistration} + * @param clientRegistrationConverter the {@link Converter} used for converting to an {@link OidcClientRegistration} */ - public void setClientRegistrationConverter(Converter, OidcClientRegistration> clientRegistrationConverter) { + public final void setClientRegistrationConverter( + Converter, OidcClientRegistration> clientRegistrationConverter) { Assert.notNull(clientRegistrationConverter, "clientRegistrationConverter cannot be null"); this.clientRegistrationConverter = clientRegistrationConverter; } /** - * Sets the {@link Converter} used for converting the {@link OidcClientRegistration} to a - * {@code Map} representation of the OpenID Client Registration Response. + * Sets the {@link Converter} used for converting the {@link OidcClientRegistration} + * to a {@code Map} representation of the OpenID Client Registration parameters. * * @param clientRegistrationParametersConverter the {@link Converter} used for converting to a - * {@code Map} representation of the OpenID Client Registration Response + * {@code Map} representation of the OpenID Client Registration parameters */ public final void setClientRegistrationParametersConverter( Converter> clientRegistrationParametersConverter) { @@ -123,35 +123,87 @@ public class OidcClientRegistrationHttpMessageConverter extends AbstractHttpMess this.clientRegistrationParametersConverter = clientRegistrationParametersConverter; } - private static final class OidcClientRegistrationConverter implements Converter, OidcClientRegistration> { + private static final class MapOidcClientRegistrationConverter + implements Converter, OidcClientRegistration> { + private static final ClaimConversionService CLAIM_CONVERSION_SERVICE = ClaimConversionService.getSharedInstance(); private static final TypeDescriptor OBJECT_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(Object.class); private static final TypeDescriptor STRING_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(String.class); + private static final TypeDescriptor INSTANT_TYPE_DESCRIPTOR = TypeDescriptor.valueOf(Instant.class); + private static final Converter INSTANT_CONVERTER = getConverter(INSTANT_TYPE_DESCRIPTOR); private final ClaimTypeConverter claimTypeConverter; - private OidcClientRegistrationConverter() { + private MapOidcClientRegistrationConverter() { + Converter stringConverter = getConverter(STRING_TYPE_DESCRIPTOR); Converter collectionStringConverter = getConverter( TypeDescriptor.collection(Collection.class, STRING_TYPE_DESCRIPTOR)); - Converter stringConverter = getConverter(STRING_TYPE_DESCRIPTOR); Map> claimConverters = new HashMap<>(); - claimConverters.put(OidcClientMetadataClaimNames.REDIRECT_URIS, collectionStringConverter); - claimConverters.put(OidcClientMetadataClaimNames.RESPONSE_TYPES, collectionStringConverter); - claimConverters.put(OidcClientMetadataClaimNames.GRANT_TYPES, collectionStringConverter); + claimConverters.put(OidcClientMetadataClaimNames.CLIENT_ID, stringConverter); + claimConverters.put(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT, INSTANT_CONVERTER); + claimConverters.put(OidcClientMetadataClaimNames.CLIENT_SECRET, stringConverter); + claimConverters.put(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, MapOidcClientRegistrationConverter::convertClientSecretExpiresAt); claimConverters.put(OidcClientMetadataClaimNames.CLIENT_NAME, stringConverter); - claimConverters.put(OidcClientMetadataClaimNames.SCOPE, stringConverter); + claimConverters.put(OidcClientMetadataClaimNames.REDIRECT_URIS, collectionStringConverter); claimConverters.put(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, stringConverter); + claimConverters.put(OidcClientMetadataClaimNames.GRANT_TYPES, collectionStringConverter); + claimConverters.put(OidcClientMetadataClaimNames.RESPONSE_TYPES, collectionStringConverter); + claimConverters.put(OidcClientMetadataClaimNames.SCOPE, MapOidcClientRegistrationConverter::convertScope); + claimConverters.put(OidcClientMetadataClaimNames.ID_TOKEN_SIGNED_RESPONSE_ALG, stringConverter); this.claimTypeConverter = new ClaimTypeConverter(claimConverters); } @Override public OidcClientRegistration convert(Map source) { Map parsedClaims = this.claimTypeConverter.convert(source); + Object clientSecretExpiresAt = parsedClaims.get(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT); + if (clientSecretExpiresAt instanceof Number && clientSecretExpiresAt.equals(0)) { + parsedClaims.remove(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT); + } return OidcClientRegistration.withClaims(parsedClaims).build(); } private static Converter getConverter(TypeDescriptor targetDescriptor) { return source -> CLAIM_CONVERSION_SERVICE.convert(source, OBJECT_TYPE_DESCRIPTOR, targetDescriptor); } + + private static Instant convertClientSecretExpiresAt(Object clientSecretExpiresAt) { + if (clientSecretExpiresAt != null && String.valueOf(clientSecretExpiresAt).equals("0")) { + // 0 indicates that client_secret_expires_at does not expire + return null; + } + return (Instant) INSTANT_CONVERTER.convert(clientSecretExpiresAt); + } + + private static List convertScope(Object scope) { + if (scope == null) { + return Collections.emptyList(); + } + return Arrays.asList(StringUtils.delimitedListToStringArray(scope.toString(), " ")); + } } + + private static final class OidcClientRegistrationMapConverter + implements Converter> { + + @Override + public Map convert(OidcClientRegistration source) { + Map responseClaims = new LinkedHashMap<>(source.getClaims()); + if (source.getClientIdIssuedAt() != null) { + responseClaims.put(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT, source.getClientIdIssuedAt().getEpochSecond()); + } + if (source.getClientSecret() != null) { + long clientSecretExpiresAt = 0; + if (source.getClientSecretExpiresAt() != null) { + clientSecretExpiresAt = source.getClientSecretExpiresAt().getEpochSecond(); + } + responseClaims.put(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, clientSecretExpiresAt); + } + if (!CollectionUtils.isEmpty(source.getScopes())) { + responseClaims.put(OidcClientMetadataClaimNames.SCOPE, StringUtils.collectionToDelimitedString(source.getScopes(), " ")); + } + return responseClaims; + } + } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OidcClientRegistrationAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OidcClientRegistrationAuthenticationProvider.java deleted file mode 100644 index d73d1d02..00000000 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OidcClientRegistrationAuthenticationProvider.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright 2020-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.server.authorization.authentication; - -import org.springframework.security.authentication.AuthenticationProvider; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.OAuth2TokenType; -import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; -import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; -import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; -import org.springframework.util.Assert; - -/** - * An {@link AuthenticationProvider} implementation for OpenID Client Registration Endpoint. - * - * @author Ovidiu Popa - * @since 0.1.1 - * @see JwtAuthenticationToken - * @see OAuth2AuthorizationService - */ -public class OidcClientRegistrationAuthenticationProvider implements AuthenticationProvider { - - private static final String CLIENT_CREATE_SCOPE = "client.create"; - private final OAuth2AuthorizationService authorizationService; - - /** - * Constructs an {@code OidcClientRegistrationAuthenticationProvider} using the provided parameters. - * - * @param authorizationService the authorization service - */ - public OidcClientRegistrationAuthenticationProvider(OAuth2AuthorizationService authorizationService) { - Assert.notNull(authorizationService, "authorizationService cannot be null"); - this.authorizationService = authorizationService; - } - - @Override - public Authentication authenticate(Authentication authentication) throws AuthenticationException { - JwtAuthenticationToken jwtAuthenticationToken = - (JwtAuthenticationToken) authentication; - - String tokenValue = jwtAuthenticationToken.getToken().getTokenValue(); - OAuth2Authorization authorization = this.authorizationService.findByToken(tokenValue, OAuth2TokenType.ACCESS_TOKEN); - - if (authorization == null) { - throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); - } - - OAuth2Authorization.Token authorizationAccessToken = - authorization.getAccessToken(); - if (authorizationAccessToken.isInvalidated()) { - throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); - } - OAuth2AccessToken accessToken = authorizationAccessToken.getToken(); - if (!accessToken.getScopes().contains(CLIENT_CREATE_SCOPE)) { - throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT)); - } - - authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, accessToken); - this.authorizationService.save(authorization); - - return jwtAuthenticationToken; - } - - @Override - public boolean supports(Class authentication) { - return JwtAuthenticationToken.class.isAssignableFrom(authentication); - } -} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepository.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepository.java index 4a840487..38edf7df 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepository.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepository.java @@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentHashMap; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * A {@link RegisteredClientRepository} that stores {@link RegisteredClient}(s) in-memory. @@ -30,6 +31,8 @@ import org.springframework.util.Assert; * NOTE: This implementation is recommended ONLY to be used during development/testing. * * @author Anoop Garlapati + * @author Ovidiu Popa + * @author Joe Grandja * @see RegisteredClientRepository * @see RegisteredClient * @since 0.0.1 @@ -58,23 +61,22 @@ public final class InMemoryRegisteredClientRepository implements RegisteredClien ConcurrentHashMap clientIdRegistrationMapResult = new ConcurrentHashMap<>(); for (RegisteredClient registration : registrations) { Assert.notNull(registration, "registration cannot be null"); - String id = registration.getId(); - if (idRegistrationMapResult.containsKey(id)) { - throw new IllegalArgumentException("Registered client must be unique. " + - "Found duplicate identifier: " + id); - } - String clientId = registration.getClientId(); - if (clientIdRegistrationMapResult.containsKey(clientId)) { - throw new IllegalArgumentException("Registered client must be unique. " + - "Found duplicate client identifier: " + clientId); - } - idRegistrationMapResult.put(id, registration); - clientIdRegistrationMapResult.put(clientId, registration); + assertUniqueIdentifiers(registration, idRegistrationMapResult); + idRegistrationMapResult.put(registration.getId(), registration); + clientIdRegistrationMapResult.put(registration.getClientId(), registration); } this.idRegistrationMap = idRegistrationMapResult; this.clientIdRegistrationMap = clientIdRegistrationMapResult; } + @Override + public void save(RegisteredClient registeredClient) { + Assert.notNull(registeredClient, "registeredClient cannot be null"); + assertUniqueIdentifiers(registeredClient, this.idRegistrationMap); + this.idRegistrationMap.put(registeredClient.getId(), registeredClient); + this.clientIdRegistrationMap.put(registeredClient.getClientId(), registeredClient); + } + @Nullable @Override public RegisteredClient findById(String id) { @@ -89,20 +91,22 @@ public final class InMemoryRegisteredClientRepository implements RegisteredClien return this.clientIdRegistrationMap.get(clientId); } - @Override - public void saveClient(RegisteredClient registeredClient) { - Assert.notNull(registeredClient, "registeredClient cannot be null"); - String id = registeredClient.getId(); - if (idRegistrationMap.containsKey(id)) { - throw new IllegalArgumentException("Registered client must be unique. " + - "Found duplicate identifier: " + id); - } - String clientId = registeredClient.getClientId(); - if (clientIdRegistrationMap.containsKey(clientId)) { - throw new IllegalArgumentException("Registered client must be unique. " + - "Found duplicate client identifier: " + clientId); - } - this.idRegistrationMap.put(registeredClient.getId(), registeredClient); - this.clientIdRegistrationMap.put(registeredClient.getClientId(), registeredClient); + private void assertUniqueIdentifiers(RegisteredClient registeredClient, Map registrations) { + registrations.values().forEach(registration -> { + if (registeredClient.getId().equals(registration.getId())) { + throw new IllegalArgumentException("Registered client must be unique. " + + "Found duplicate identifier: " + registeredClient.getId()); + } + if (registeredClient.getClientId().equals(registration.getClientId())) { + throw new IllegalArgumentException("Registered client must be unique. " + + "Found duplicate client identifier: " + registeredClient.getClientId()); + } + if (StringUtils.hasText(registeredClient.getClientSecret()) && + registeredClient.getClientSecret().equals(registration.getClientSecret())) { + throw new IllegalArgumentException("Registered client must be unique. " + + "Found duplicate client secret for identifier: " + registeredClient.getId()); + } + }); } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java index a787ea59..1823b543 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java @@ -18,12 +18,14 @@ package org.springframework.security.oauth2.server.authorization.client; import java.io.Serializable; import java.net.URI; import java.net.URISyntaxException; +import java.time.Instant; import java.util.Collections; import java.util.HashSet; import java.util.Objects; import java.util.Set; import java.util.function.Consumer; +import org.springframework.lang.Nullable; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.Version; @@ -31,6 +33,7 @@ import org.springframework.security.oauth2.server.authorization.config.ClientSet import org.springframework.security.oauth2.server.authorization.config.TokenSettings; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; /** * A representation of a client registration with an OAuth 2.0 Authorization Server. @@ -44,7 +47,10 @@ public class RegisteredClient implements Serializable { private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private String id; private String clientId; + private Instant clientIdIssuedAt; private String clientSecret; + private Instant clientSecretExpiresAt; + private String clientName; private Set clientAuthenticationMethods; private Set authorizationGrantTypes; private Set redirectUris; @@ -73,6 +79,16 @@ public class RegisteredClient implements Serializable { return this.clientId; } + /** + * Returns the time at which the client identifier was issued. + * + * @return the time at which the client identifier was issued + */ + @Nullable + public Instant getClientIdIssuedAt() { + return this.clientIdIssuedAt; + } + /** * Returns the client secret. * @@ -82,6 +98,25 @@ public class RegisteredClient implements Serializable { return this.clientSecret; } + /** + * Returns the time at which the client secret expires or {@code null} if it does not expire. + * + * @return the time at which the client secret expires or {@code null} if it does not expire + */ + @Nullable + public Instant getClientSecretExpiresAt() { + return this.clientSecretExpiresAt; + } + + /** + * Returns the client name. + * + * @return the client name + */ + public String getClientName() { + return this.clientName; + } + /** * Returns the {@link ClientAuthenticationMethod authentication method(s)} that the client may use. * @@ -147,7 +182,10 @@ public class RegisteredClient implements Serializable { RegisteredClient that = (RegisteredClient) obj; return Objects.equals(this.id, that.id) && Objects.equals(this.clientId, that.clientId) && + Objects.equals(this.clientIdIssuedAt, that.clientIdIssuedAt) && Objects.equals(this.clientSecret, that.clientSecret) && + Objects.equals(this.clientSecretExpiresAt, that.clientSecretExpiresAt) && + Objects.equals(this.clientName, that.clientName) && Objects.equals(this.clientAuthenticationMethods, that.clientAuthenticationMethods) && Objects.equals(this.authorizationGrantTypes, that.authorizationGrantTypes) && Objects.equals(this.redirectUris, that.redirectUris) && @@ -158,8 +196,8 @@ public class RegisteredClient implements Serializable { @Override public int hashCode() { - return Objects.hash(this.id, this.clientId, this.clientSecret, - this.clientAuthenticationMethods, this.authorizationGrantTypes, this.redirectUris, + return Objects.hash(this.id, this.clientId, this.clientIdIssuedAt, this.clientSecret, this.clientSecretExpiresAt, + this.clientName, this.clientAuthenticationMethods, this.authorizationGrantTypes, this.redirectUris, this.scopes, this.clientSettings.settings(), this.tokenSettings.settings()); } @@ -168,6 +206,7 @@ public class RegisteredClient implements Serializable { return "RegisteredClient {" + "id='" + this.id + '\'' + ", clientId='" + this.clientId + '\'' + + ", clientName='" + this.clientName + '\'' + ", clientAuthenticationMethods=" + this.clientAuthenticationMethods + ", authorizationGrantTypes=" + this.authorizationGrantTypes + ", redirectUris=" + this.redirectUris + @@ -206,7 +245,10 @@ public class RegisteredClient implements Serializable { private static final long serialVersionUID = Version.SERIAL_VERSION_UID; private String id; private String clientId; + private Instant clientIdIssuedAt; private String clientSecret; + private Instant clientSecretExpiresAt; + private String clientName; private Set clientAuthenticationMethods = new HashSet<>(); private Set authorizationGrantTypes = new HashSet<>(); private Set redirectUris = new HashSet<>(); @@ -221,7 +263,10 @@ public class RegisteredClient implements Serializable { protected Builder(RegisteredClient registeredClient) { this.id = registeredClient.id; this.clientId = registeredClient.clientId; + this.clientIdIssuedAt = registeredClient.clientIdIssuedAt; this.clientSecret = registeredClient.clientSecret; + this.clientSecretExpiresAt = registeredClient.clientSecretExpiresAt; + this.clientName = registeredClient.clientName; if (!CollectionUtils.isEmpty(registeredClient.clientAuthenticationMethods)) { this.clientAuthenticationMethods.addAll(registeredClient.clientAuthenticationMethods); } @@ -260,6 +305,17 @@ public class RegisteredClient implements Serializable { return this; } + /** + * Sets the time at which the client identifier was issued. + * + * @param clientIdIssuedAt the time at which the client identifier was issued + * @return the {@link Builder} + */ + public Builder clientIdIssuedAt(Instant clientIdIssuedAt) { + this.clientIdIssuedAt = clientIdIssuedAt; + return this; + } + /** * Sets the client secret. * @@ -271,6 +327,28 @@ public class RegisteredClient implements Serializable { return this; } + /** + * Sets the time at which the client secret expires or {@code null} if it does not expire. + * + * @param clientSecretExpiresAt the time at which the client secret expires or {@code null} if it does not expire + * @return the {@link Builder} + */ + public Builder clientSecretExpiresAt(Instant clientSecretExpiresAt) { + this.clientSecretExpiresAt = clientSecretExpiresAt; + return this; + } + + /** + * Sets the client name. + * + * @param clientName the client name + * @return the {@link Builder} + */ + public Builder clientName(String clientName) { + this.clientName = clientName; + return this; + } + /** * Adds an {@link ClientAuthenticationMethod authentication method} * the client may use when authenticating with the authorization server. @@ -400,6 +478,9 @@ public class RegisteredClient implements Serializable { if (this.authorizationGrantTypes.contains(AuthorizationGrantType.AUTHORIZATION_CODE)) { Assert.notEmpty(this.redirectUris, "redirectUris cannot be empty"); } + if (!StringUtils.hasText(this.clientName)) { + this.clientName = this.id; + } if (CollectionUtils.isEmpty(this.clientAuthenticationMethods)) { this.clientAuthenticationMethods.add(ClientAuthenticationMethod.BASIC); } @@ -413,7 +494,10 @@ public class RegisteredClient implements Serializable { registeredClient.id = this.id; registeredClient.clientId = this.clientId; + registeredClient.clientIdIssuedAt = this.clientIdIssuedAt; registeredClient.clientSecret = this.clientSecret; + registeredClient.clientSecretExpiresAt = this.clientSecretExpiresAt; + registeredClient.clientName = this.clientName; registeredClient.clientAuthenticationMethods = Collections.unmodifiableSet( new HashSet<>(this.clientAuthenticationMethods)); registeredClient.authorizationGrantTypes = Collections.unmodifiableSet( diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientRepository.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientRepository.java index 60182f8f..f1905445 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientRepository.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientRepository.java @@ -22,11 +22,19 @@ import org.springframework.lang.Nullable; * * @author Joe Grandja * @author Anoop Garlapati + * @author Ovidiu Popa * @see RegisteredClient * @since 0.0.1 */ public interface RegisteredClientRepository { + /** + * Saves the registered client. + * + * @param registeredClient the {@link RegisteredClient} + */ + void save(RegisteredClient registeredClient); + /** * Returns the registered client identified by the provided {@code id}, * or {@code null} if not found. @@ -47,11 +55,4 @@ public interface RegisteredClientRepository { @Nullable RegisteredClient findByClientId(String clientId); - /** - * Saves a new registered client - * - * @param registeredClient the {@link RegisteredClient} to be saved - */ - void saveClient(RegisteredClient registeredClient); - } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettings.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettings.java index 2dbcba43..03d17ec1 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettings.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettings.java @@ -34,7 +34,6 @@ public class ProviderSettings extends Settings { public static final String TOKEN_REVOCATION_ENDPOINT = PROVIDER_SETTING_BASE.concat("token-revocation-endpoint"); public static final String TOKEN_INTROSPECTION_ENDPOINT = PROVIDER_SETTING_BASE.concat("token-introspection-endpoint"); public static final String OIDC_CLIENT_REGISTRATION_ENDPOINT = PROVIDER_SETTING_BASE.concat("oidc-client-registration-endpoint"); - public static final String ENABLE_OIDC_CLIENT_REGISTRATION_ENDPOINT = PROVIDER_SETTING_BASE.concat("enable-oidc-client-registration-endpoint"); /** * Constructs a {@code ProviderSettings}. @@ -167,45 +166,24 @@ public class ProviderSettings extends Settings { } /** - * Returns the Provider's OAuth 2.0 OIDC Client Registration endpoint. The default is {@code /connect/register}. + * Returns the Provider's OpenID Connect 1.0 Client Registration endpoint. The default is {@code /connect/register}. * - * @return the OIDC Client Registration endpoint + * @return the OpenID Connect 1.0 Client Registration endpoint */ public String oidcClientRegistrationEndpoint() { return setting(OIDC_CLIENT_REGISTRATION_ENDPOINT); } /** - * Sets the Provider's OAuth 2.0 OIDC Client Registration endpoint. + * Sets the Provider's OpenID Connect 1.0 Client Registration endpoint. * - * @param oidcClientRegistrationEndpoint the Token Revocation endpoint + * @param oidcClientRegistrationEndpoint the OpenID Connect 1.0 Client Registration endpoint * @return the {@link ProviderSettings} for further configuration */ public ProviderSettings oidcClientRegistrationEndpoint(String oidcClientRegistrationEndpoint) { return setting(OIDC_CLIENT_REGISTRATION_ENDPOINT, oidcClientRegistrationEndpoint); } - /** - * Returns {@code true} if the OIDC Client Registration endpoint is enabled. - * The default is {@code false}. - * - * @return {@code true} if the OIDC Client Registration endpoint is enabled, {@code false} otherwise - */ - public boolean isOidClientRegistrationEndpointEnabled() { - return setting(ENABLE_OIDC_CLIENT_REGISTRATION_ENDPOINT); - } - - /** - * Set to {@code true} if the OIDC Client Registration Endpoint should be enabled. - * - * @param oidClientRegistrationEndpointEnabled {@code true} if the OIDC Client Registration endpoint should enabled - * @return the {@link ProviderSettings} - */ - public ProviderSettings isOidClientRegistrationEndpointEnabled(boolean oidClientRegistrationEndpointEnabled) { - setting(ENABLE_OIDC_CLIENT_REGISTRATION_ENDPOINT, oidClientRegistrationEndpointEnabled); - return this; - } - protected static Map defaultSettings() { Map settings = new HashMap<>(); settings.put(AUTHORIZATION_ENDPOINT, "/oauth2/authorize"); @@ -214,7 +192,6 @@ public class ProviderSettings extends Settings { settings.put(TOKEN_REVOCATION_ENDPOINT, "/oauth2/revoke"); settings.put(TOKEN_INTROSPECTION_ENDPOINT, "/oauth2/introspect"); settings.put(OIDC_CLIENT_REGISTRATION_ENDPOINT, "/connect/register"); - settings.put(ENABLE_OIDC_CLIENT_REGISTRATION_ENDPOINT, false); return settings; } } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/TokenSettings.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/TokenSettings.java index cab922e2..f0f27345 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/TokenSettings.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/TokenSettings.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,12 +15,14 @@ */ package org.springframework.security.oauth2.server.authorization.config; -import org.springframework.util.Assert; - import java.time.Duration; import java.util.HashMap; import java.util.Map; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.util.Assert; + /** * A facility for token configuration settings. * @@ -33,6 +35,7 @@ public class TokenSettings extends Settings { public static final String ACCESS_TOKEN_TIME_TO_LIVE = TOKEN_SETTING_BASE.concat("access-token-time-to-live"); public static final String REUSE_REFRESH_TOKENS = TOKEN_SETTING_BASE.concat("reuse-refresh-tokens"); public static final String REFRESH_TOKEN_TIME_TO_LIVE = TOKEN_SETTING_BASE.concat("refresh-token-time-to-live"); + public static final String ID_TOKEN_SIGNATURE_ALGORITHM = TOKEN_SETTING_BASE.concat("id-token-signature-algorithm"); /** * Constructs a {@code TokenSettings}. @@ -114,11 +117,35 @@ public class TokenSettings extends Settings { return this; } + /** + * Returns the {@link SignatureAlgorithm JWS} algorithm for signing the {@link OidcIdToken ID Token}. + * The default is {@link SignatureAlgorithm#RS256 RS256}. + * + * @return the {@link SignatureAlgorithm JWS} algorithm for signing the {@link OidcIdToken ID Token} + */ + public SignatureAlgorithm idTokenSignatureAlgorithm() { + return setting(ID_TOKEN_SIGNATURE_ALGORITHM); + } + + /** + * Sets the {@link SignatureAlgorithm JWS} algorithm for signing the {@link OidcIdToken ID Token}. + * + * @param idTokenSignatureAlgorithm the {@link SignatureAlgorithm JWS} algorithm for signing the {@link OidcIdToken ID Token} + * @return the {@link TokenSettings} + */ + public TokenSettings idTokenSignatureAlgorithm(SignatureAlgorithm idTokenSignatureAlgorithm) { + Assert.notNull(idTokenSignatureAlgorithm, "idTokenSignatureAlgorithm cannot be null"); + setting(ID_TOKEN_SIGNATURE_ALGORITHM, idTokenSignatureAlgorithm); + return this; + } + protected static Map defaultSettings() { Map settings = new HashMap<>(); settings.put(ACCESS_TOKEN_TIME_TO_LIVE, Duration.ofMinutes(5)); settings.put(REUSE_REFRESH_TOKENS, true); settings.put(REFRESH_TOKEN_TIME_TO_LIVE, Duration.ofMinutes(60)); + settings.put(ID_TOKEN_SIGNATURE_ALGORITHM, SignatureAlgorithm.RS256); return settings; } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcAuthenticationProviderUtils.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcAuthenticationProviderUtils.java new file mode 100644 index 00000000..0301faf7 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcAuthenticationProviderUtils.java @@ -0,0 +1,63 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.oauth2.core.AbstractOAuth2Token; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode; + +/** + * Utility methods for the OpenID Connect 1.0 {@link AuthenticationProvider}'s. + * + * @author Joe Grandja + * @since 0.1.1 + */ +final class OidcAuthenticationProviderUtils { + + private OidcAuthenticationProviderUtils() { + } + + static OAuth2Authorization invalidate( + OAuth2Authorization authorization, T token) { + + // @formatter:off + OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization) + .token(token, + (metadata) -> + metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true)); + + if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) { + authorizationBuilder.token( + authorization.getAccessToken().getToken(), + (metadata) -> + metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true)); + + OAuth2Authorization.Token authorizationCode = + authorization.getToken(OAuth2AuthorizationCode.class); + if (authorizationCode != null && !authorizationCode.isInvalidated()) { + authorizationBuilder.token( + authorizationCode.getToken(), + (metadata) -> + metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true)); + } + } + // @formatter:on + + return authorizationBuilder.build(); + } +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java new file mode 100644 index 00000000..0d23d5b4 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java @@ -0,0 +1,218 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import java.time.Instant; +import java.util.Base64; +import java.util.Collection; +import java.util.UUID; + +import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; +import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; +import org.springframework.security.crypto.keygen.StringKeyGenerator; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2TokenType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.OidcClientRegistration; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.resource.authentication.AbstractOAuth2TokenAuthenticationToken; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; + +/** + * An {@link AuthenticationProvider} implementation for OpenID Connect Dynamic Client Registration 1.0. + * + * @author Ovidiu Popa + * @author Joe Grandja + * @since 0.1.1 + * @see RegisteredClientRepository + * @see OAuth2AuthorizationService + * @see 3. Client Registration Endpoint + */ +public class OidcClientRegistrationAuthenticationProvider implements AuthenticationProvider { + private static final StringKeyGenerator CLIENT_ID_GENERATOR = new Base64StringKeyGenerator( + Base64.getUrlEncoder().withoutPadding(), 32); + private static final StringKeyGenerator CLIENT_SECRET_GENERATOR = new Base64StringKeyGenerator( + Base64.getUrlEncoder().withoutPadding(), 48); + private static final String DEFAULT_AUTHORIZED_SCOPE = "client.create"; + private final RegisteredClientRepository registeredClientRepository; + private final OAuth2AuthorizationService authorizationService; + + /** + * Constructs an {@code OidcClientRegistrationAuthenticationProvider} using the provided parameters. + * + * @param registeredClientRepository the repository of registered clients + * @param authorizationService the authorization service + */ + public OidcClientRegistrationAuthenticationProvider(RegisteredClientRepository registeredClientRepository, + OAuth2AuthorizationService authorizationService) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + Assert.notNull(authorizationService, "authorizationService cannot be null"); + this.registeredClientRepository = registeredClientRepository; + this.authorizationService = authorizationService; + } + + @Override + public Authentication authenticate(Authentication authentication) throws AuthenticationException { + OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication = + (OidcClientRegistrationAuthenticationToken) authentication; + + // Validate the "initial" access token + AbstractOAuth2TokenAuthenticationToken accessTokenAuthentication = null; + if (AbstractOAuth2TokenAuthenticationToken.class.isAssignableFrom(clientRegistrationAuthentication.getPrincipal().getClass())) { + accessTokenAuthentication = (AbstractOAuth2TokenAuthenticationToken) clientRegistrationAuthentication.getPrincipal(); + } + if (accessTokenAuthentication == null || !accessTokenAuthentication.isAuthenticated()) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN)); + } + + String accessTokenValue = accessTokenAuthentication.getToken().getTokenValue(); + + OAuth2Authorization authorization = this.authorizationService.findByToken( + accessTokenValue, OAuth2TokenType.ACCESS_TOKEN); + if (authorization == null) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN)); + } + + OAuth2Authorization.Token authorizedAccessToken = authorization.getAccessToken(); + if (!authorizedAccessToken.isActive()) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN)); + } + + if (!isAuthorized(authorizedAccessToken)) { + throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INSUFFICIENT_SCOPE)); + } + + RegisteredClient registeredClient = create(clientRegistrationAuthentication.getClientRegistration()); + this.registeredClientRepository.save(registeredClient); + + // Invalidate the "initial" access token as it can only be used once + authorization = OidcAuthenticationProviderUtils.invalidate(authorization, authorizedAccessToken.getToken()); + if (authorization.getRefreshToken() != null) { + authorization = OidcAuthenticationProviderUtils.invalidate(authorization, authorization.getRefreshToken().getToken()); + } + this.authorizationService.save(authorization); + + return new OidcClientRegistrationAuthenticationToken( + accessTokenAuthentication, convert(registeredClient)); + } + + @Override + public boolean supports(Class authentication) { + return OidcClientRegistrationAuthenticationToken.class.isAssignableFrom(authentication); + } + + @SuppressWarnings("unchecked") + private static boolean isAuthorized(OAuth2Authorization.Token authorizedAccessToken) { + Object scope = authorizedAccessToken.getClaims().get(OAuth2ParameterNames.SCOPE); + return scope != null && ((Collection) scope).contains(DEFAULT_AUTHORIZED_SCOPE); + } + + private static RegisteredClient create(OidcClientRegistration clientRegistration) { + // @formatter:off + RegisteredClient.Builder builder = RegisteredClient.withId(UUID.randomUUID().toString()) + .clientId(CLIENT_ID_GENERATOR.generateKey()) + .clientIdIssuedAt(Instant.now()) + .clientSecret(CLIENT_SECRET_GENERATOR.generateKey()) + .clientName(clientRegistration.getClientName()); + + if ("client_secret_post".equals(clientRegistration.getTokenEndpointAuthenticationMethod())) { + // TODO: Use ClientAuthenticationMethod.CLIENT_SECRET_POST in Spring Security 5.5.0 + builder.clientAuthenticationMethod(ClientAuthenticationMethod.POST); + } else { + // TODO: Use ClientAuthenticationMethod.CLIENT_SECRET_BASIC in Spring Security 5.5.0 + builder.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC); + } + + // TODO Validate redirect_uris and throw OAuth2ErrorCodes2.INVALID_REDIRECT_URI on error + builder.redirectUris(redirectUris -> + redirectUris.addAll(clientRegistration.getRedirectUris())); + + if (!CollectionUtils.isEmpty(clientRegistration.getGrantTypes())) { + builder.authorizationGrantTypes(authorizationGrantTypes -> + clientRegistration.getGrantTypes().forEach(grantType -> + authorizationGrantTypes.add(new AuthorizationGrantType(grantType)))); + } else { + builder.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE); + } + if (CollectionUtils.isEmpty(clientRegistration.getResponseTypes()) || + clientRegistration.getResponseTypes().contains(OAuth2AuthorizationResponseType.CODE.getValue())) { + builder.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE); + } + + if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) { + builder.scopes(scopes -> + scopes.addAll(clientRegistration.getScopes())); + } + + builder + .clientSettings(clientSettings -> + clientSettings + .requireProofKey(true) + .requireUserConsent(true)) + .tokenSettings(tokenSettings -> + tokenSettings + .idTokenSignatureAlgorithm(SignatureAlgorithm.RS256)); + + return builder.build(); + // @formatter:on + } + + private static OidcClientRegistration convert(RegisteredClient registeredClient) { + // @formatter:off + OidcClientRegistration.Builder builder = OidcClientRegistration.builder() + .clientId(registeredClient.getClientId()) + .clientIdIssuedAt(registeredClient.getClientIdIssuedAt()) + .clientSecret(registeredClient.getClientSecret()) + .clientName(registeredClient.getClientName()); + + builder.redirectUris(redirectUris -> + redirectUris.addAll(registeredClient.getRedirectUris())); + + builder.grantTypes(grantTypes -> + registeredClient.getAuthorizationGrantTypes().forEach(authorizationGrantType -> + grantTypes.add(authorizationGrantType.getValue()))); + + if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.AUTHORIZATION_CODE)) { + builder.responseType(OAuth2AuthorizationResponseType.CODE.getValue()); + } + + if (!CollectionUtils.isEmpty(registeredClient.getScopes())) { + builder.scopes(scopes -> + scopes.addAll(registeredClient.getScopes())); + } + + builder + .tokenEndpointAuthenticationMethod(registeredClient.getClientAuthenticationMethods().iterator().next().getValue()) + .idTokenSignedResponseAlgorithm(registeredClient.getTokenSettings().idTokenSignatureAlgorithm().getName()); + + return builder.build(); + // @formatter:on + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationToken.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationToken.java new file mode 100644 index 00000000..892a47f6 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationToken.java @@ -0,0 +1,74 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import java.util.Collections; + +import org.springframework.security.authentication.AbstractAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.Version; +import org.springframework.security.oauth2.core.oidc.OidcClientRegistration; +import org.springframework.util.Assert; + +/** + * An {@link Authentication} implementation used for OpenID Connect Dynamic Client Registration 1.0. + * + * @author Joe Grandja + * @since 0.1.1 + * @see AbstractAuthenticationToken + * @see OidcClientRegistration + * @see OidcClientRegistrationAuthenticationProvider + */ +public class OidcClientRegistrationAuthenticationToken extends AbstractAuthenticationToken { + private static final long serialVersionUID = Version.SERIAL_VERSION_UID; + private final Authentication principal; + private final OidcClientRegistration clientRegistration; + + /** + * Constructs an {@code OidcClientRegistrationAuthenticationToken} using the provided parameters. + * + * @param principal the authenticated principal + * @param clientRegistration the client registration + */ + public OidcClientRegistrationAuthenticationToken(Authentication principal, OidcClientRegistration clientRegistration) { + super(Collections.emptyList()); + Assert.notNull(principal, "principal cannot be null"); + Assert.notNull(clientRegistration, "clientRegistration cannot be null"); + this.principal = principal; + this.clientRegistration = clientRegistration; + setAuthenticated(principal.isAuthenticated()); + } + + @Override + public Object getPrincipal() { + return this.principal; + } + + @Override + public Object getCredentials() { + return ""; + } + + /** + * Returns the client registration. + * + * @return the client registration + */ + public OidcClientRegistration getClientRegistration() { + return this.clientRegistration; + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java index 40d13071..67251dde 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java @@ -15,180 +15,130 @@ */ package org.springframework.security.oauth2.server.authorization.oidc.web; -import org.springframework.core.convert.converter.Converter; +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; import org.springframework.security.oauth2.core.oidc.OidcClientRegistration; import org.springframework.security.oauth2.core.oidc.http.converter.OidcClientRegistrationHttpMessageConverter; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import java.io.IOException; -import java.time.Instant; -import java.util.Arrays; -import java.util.List; -import java.util.UUID; -import java.util.stream.Collectors; - /** - * A {@code Filter} that processes OpenID Client Registration Requests. + * A {@code Filter} that processes OpenID Connect Dynamic Client Registration 1.0 Requests. + * * @author Ovidiu Popa + * @author Joe Grandja * @since 0.1.1 * @see OidcClientRegistration - * @see 3.1. Client Registration Request + * @see 3. Client Registration Endpoint */ public class OidcClientRegistrationEndpointFilter extends OncePerRequestFilter { /** * The default endpoint {@code URI} for OpenID Client Registration requests. */ public static final String DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI = "/connect/register"; - private static final String SCOPE_CLAIM_DELIMITER = " "; - private final OidcClientRegistrationHttpMessageConverter clientRegistrationHttpMessageConverter = + private final AuthenticationManager authenticationManager; + private final RequestMatcher clientRegistrationEndpointMatcher; + private final HttpMessageConverter clientRegistrationHttpMessageConverter = new OidcClientRegistrationHttpMessageConverter(); - private final RegisteredClientRepository registeredClientRepository; - private final OidcClientRegistrationToRegisteredClientConverter oidcClientToRegisteredClientConverter = - new OidcClientRegistrationToRegisteredClientConverter(); - private final RegisteredClientToOidcClientRegistrationConverter registeredClientToOidcClientConverter = - new RegisteredClientToOidcClientRegistrationConverter(); private final HttpMessageConverter errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter(); - private final RequestMatcher requestMatcher; - private final AuthenticationManager authenticationManager; /** * Constructs an {@code OidcClientRegistrationEndpointFilter} using the provided parameters. * - * @param registeredClientRepository the repository of registered clients * @param authenticationManager the authentication manager */ - public OidcClientRegistrationEndpointFilter(RegisteredClientRepository registeredClientRepository, - AuthenticationManager authenticationManager) { - this(registeredClientRepository, authenticationManager, DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI); + public OidcClientRegistrationEndpointFilter(AuthenticationManager authenticationManager) { + this(authenticationManager, DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI); } /** * Constructs an {@code OidcClientRegistrationEndpointFilter} using the provided parameters. * - * @param registeredClientRepository the repository of registered clients * @param authenticationManager the authentication manager - * @param oidcClientRegistrationUri the endpoint {@code URI} for OIDC Client Registration requests + * @param clientRegistrationEndpointUri the endpoint {@code URI} for OpenID Client Registration requests */ - public OidcClientRegistrationEndpointFilter(RegisteredClientRepository registeredClientRepository, - AuthenticationManager authenticationManager, String oidcClientRegistrationUri) { - Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + public OidcClientRegistrationEndpointFilter(AuthenticationManager authenticationManager, + String clientRegistrationEndpointUri) { Assert.notNull(authenticationManager, "authenticationManager cannot be null"); - Assert.hasText(oidcClientRegistrationUri, "oidcClientRegistrationUri cannot be empty"); - this.registeredClientRepository = registeredClientRepository; + Assert.hasText(clientRegistrationEndpointUri, "clientRegistrationEndpointUri cannot be empty"); this.authenticationManager = authenticationManager; - this.requestMatcher = new AntPathRequestMatcher( - oidcClientRegistrationUri, - HttpMethod.POST.name() - ); + this.clientRegistrationEndpointMatcher = new AntPathRequestMatcher( + clientRegistrationEndpointUri, HttpMethod.POST.name()); } @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - if (!this.requestMatcher.matches(request)) { + if (!this.clientRegistrationEndpointMatcher.matches(request)) { filterChain.doFilter(request, response); return; } try { - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - authenticationManager.authenticate(authentication); - OidcClientRegistration clientRegistrationRequest = - this.clientRegistrationHttpMessageConverter.read(OidcClientRegistration.class, new ServletServerHttpRequest(request)); + Authentication principal = SecurityContextHolder.getContext().getAuthentication(); + OidcClientRegistration clientRegistration = this.clientRegistrationHttpMessageConverter.read( + OidcClientRegistration.class, new ServletServerHttpRequest(request)); - RegisteredClient registeredClient = this.oidcClientToRegisteredClientConverter - .convert(clientRegistrationRequest); - this.registeredClientRepository.saveClient(registeredClient); + OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication = + new OidcClientRegistrationAuthenticationToken(principal, clientRegistration); - OidcClientRegistration convert = this.registeredClientToOidcClientConverter - .convert(registeredClient); + OidcClientRegistrationAuthenticationToken clientRegistrationAuthenticationResult = + (OidcClientRegistrationAuthenticationToken) this.authenticationManager.authenticate(clientRegistrationAuthentication); + + sendClientRegistrationResponse(response, clientRegistrationAuthenticationResult.getClientRegistration()); - final ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - httpResponse.setStatusCode(HttpStatus.CREATED); - this.clientRegistrationHttpMessageConverter.write( - convert, MediaType.APPLICATION_JSON, httpResponse); } catch (OAuth2AuthenticationException ex) { - SecurityContextHolder.clearContext(); sendErrorResponse(response, ex.getError()); + } catch (Exception ex) { + OAuth2Error error = new OAuth2Error( + OAuth2ErrorCodes.INVALID_REQUEST, + "OpenID Client Registration Error: " + ex.getMessage(), + "https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationError"); + sendErrorResponse(response, error); + } finally { + SecurityContextHolder.clearContext(); } } + private void sendClientRegistrationResponse(HttpServletResponse response, OidcClientRegistration clientRegistration) throws IOException { + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + httpResponse.setStatusCode(HttpStatus.CREATED); + this.clientRegistrationHttpMessageConverter.write(clientRegistration, null, httpResponse); + } + private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException { + HttpStatus httpStatus = HttpStatus.BAD_REQUEST; + if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_TOKEN)) { + httpStatus = HttpStatus.UNAUTHORIZED; + } else if (error.getErrorCode().equals(OAuth2ErrorCodes.INSUFFICIENT_SCOPE)) { + httpStatus = HttpStatus.FORBIDDEN; + } ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - httpResponse.setStatusCode(HttpStatus.BAD_REQUEST); + httpResponse.setStatusCode(httpStatus); this.errorHttpResponseConverter.write(error, null, httpResponse); } - private static class OidcClientRegistrationToRegisteredClientConverter implements Converter { - - @Override - public RegisteredClient convert(OidcClientRegistration clientRegistration) { - return RegisteredClient.withId(UUID.randomUUID().toString()) - .clientId(UUID.randomUUID().toString()) - .clientSecret(UUID.randomUUID().toString()) - .redirectUris(redirectUris -> - redirectUris.addAll(clientRegistration.getRedirectUris())) - .clientAuthenticationMethod(new ClientAuthenticationMethod(clientRegistration.getTokenEndpointAuthenticationMethod())) - .authorizationGrantTypes(grantTypes -> - grantTypes.addAll(this.grantTypes(clientRegistration))) - .scopes(scopes -> - scopes.addAll(Arrays.asList(clientRegistration.getScope().split(SCOPE_CLAIM_DELIMITER)))) - .clientSettings(clientSettings -> clientSettings.requireUserConsent(true)) - .build(); - } - - private List grantTypes(OidcClientRegistration clientRegistration) { - return clientRegistration.getGrantTypes().stream() - .map(AuthorizationGrantType::new) - .collect(Collectors.toList()); - } - } - - private static class RegisteredClientToOidcClientRegistrationConverter implements Converter { - - @Override - public OidcClientRegistration convert(RegisteredClient source) { - return OidcClientRegistration.builder() - .clientId(source.getClientId()) - .redirectUris(uris -> uris.addAll(source.getRedirectUris())) - .clientIdIssuedAt(Instant.now()) - .clientSecret(source.getClientSecret()) - .clientSecretExpiresAt(Instant.EPOCH) - .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) - .grantTypes(grantTypes -> - grantTypes.addAll(source.getAuthorizationGrantTypes().stream().map(AuthorizationGrantType::getValue) - .collect(Collectors.toList())) - ) - .scope(String.join(SCOPE_CLAIM_DELIMITER, source.getScopes())) - .tokenEndpointAuthenticationMethod(source.getClientAuthenticationMethods().iterator().next().getValue()) - .build(); - } - } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/JwkSetTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/JwkSetTests.java index 6d3d66a3..49b887c5 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/JwkSetTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/JwkSetTests.java @@ -22,6 +22,7 @@ import org.junit.Before; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java index cd7adadb..3becb364 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java @@ -53,11 +53,10 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter; import org.springframework.security.oauth2.jose.TestJwks; -import org.springframework.security.oauth2.jose.TestKeys; import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtEncoder; import org.springframework.security.oauth2.jwt.NimbusJwsEncoder; -import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode; @@ -111,7 +110,6 @@ public class OAuth2AuthorizationCodeGrantTests { private static OAuth2AuthorizationService authorizationService; private static JWKSource jwkSource; private static NimbusJwsEncoder jwtEncoder; - private static NimbusJwtDecoder jwtDecoder; private static ProviderSettings providerSettings; private static HttpMessageConverter accessTokenHttpResponseConverter = new OAuth2AccessTokenResponseHttpMessageConverter(); @@ -122,6 +120,9 @@ public class OAuth2AuthorizationCodeGrantTests { @Autowired private MockMvc mvc; + @Autowired + private JwtDecoder jwtDecoder; + @BeforeClass public static void init() { registeredClientRepository = mock(RegisteredClientRepository.class); @@ -129,7 +130,6 @@ public class OAuth2AuthorizationCodeGrantTests { JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK); jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet); jwtEncoder = new NimbusJwsEncoder(jwkSource); - jwtDecoder = NimbusJwtDecoder.withPublicKey(TestKeys.DEFAULT_PUBLIC_KEY).build(); providerSettings = new ProviderSettings() .authorizationEndpoint("/test/authorize") .tokenEndpoint("/test/token"); @@ -206,7 +206,7 @@ public class OAuth2AuthorizationCodeGrantTests { registeredClient, authorization, OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI); // Assert user authorities was propagated as claim in JWT - Jwt jwt = jwtDecoder.decode(accessTokenResponse.getAccessToken().getTokenValue()); + Jwt jwt = this.jwtDecoder.decode(accessTokenResponse.getAccessToken().getTokenValue()); List authoritiesClaim = jwt.getClaim(AUTHORITIES_CLAIM); Authentication principal = authorization.getAttribute(Principal.class.getName()); Set userAuthorities = principal.getAuthorities().stream() diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientCredentialsGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientCredentialsGrantTests.java index 6324a20a..e478bea1 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientCredentialsGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientCredentialsGrantTests.java @@ -15,6 +15,10 @@ */ package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.SecurityContext; @@ -22,6 +26,7 @@ import org.junit.Before; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; @@ -42,10 +47,6 @@ import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenE import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; -import java.net.URLEncoder; -import java.nio.charset.StandardCharsets; -import java.util.Base64; - import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java index fd8bde81..872841ed 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java @@ -15,6 +15,14 @@ */ package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.security.Principal; +import java.util.Base64; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.SecurityContext; @@ -22,6 +30,7 @@ import org.junit.Before; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; @@ -58,14 +67,6 @@ import org.springframework.test.web.servlet.MvcResult; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; -import java.net.URLEncoder; -import java.nio.charset.StandardCharsets; -import java.security.Principal; -import java.util.Base64; -import java.util.List; -import java.util.Set; -import java.util.stream.Collectors; - import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; import static org.mockito.ArgumentMatchers.any; diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java index 7ede265a..b72b5897 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java @@ -15,6 +15,10 @@ */ package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.SecurityContext; @@ -23,6 +27,7 @@ import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import org.mockito.ArgumentCaptor; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; @@ -48,10 +53,6 @@ import org.springframework.test.web.servlet.MockMvc; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; -import java.net.URLEncoder; -import java.nio.charset.StandardCharsets; -import java.util.Base64; - import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isNull; diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcClientRegistrationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcClientRegistrationTests.java index b1127799..97f20e3c 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcClientRegistrationTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcClientRegistrationTests.java @@ -15,8 +15,10 @@ */ package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.SecurityContext; @@ -25,6 +27,7 @@ import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; import org.mockito.ArgumentCaptor; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; @@ -32,6 +35,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.mock.http.MockHttpOutputMessage; import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; @@ -48,25 +52,19 @@ import org.springframework.security.oauth2.core.http.converter.OAuth2AccessToken import org.springframework.security.oauth2.core.oidc.OidcClientRegistration; import org.springframework.security.oauth2.core.oidc.http.converter.OidcClientRegistrationHttpMessageConverter; import org.springframework.security.oauth2.jose.TestJwks; -import org.springframework.security.oauth2.jose.TestKeys; -import org.springframework.security.oauth2.jwt.JwtDecoder; -import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; -import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; +import org.springframework.security.oauth2.server.authorization.oidc.web.OidcClientRegistrationEndpointFilter; import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; -import java.net.URLEncoder; -import java.nio.charset.StandardCharsets; -import java.util.Base64; -import java.util.Map; - import static org.assertj.core.api.Assertions.assertThat; +import static org.hamcrest.CoreMatchers.containsString; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doNothing; @@ -75,35 +73,24 @@ import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; /** - * Integration tests for OpenID Connect 1.0 Client Registration Endpoint. + * Integration tests for OpenID Connect Dynamic Client Registration 1.0. * * @author Ovidiu Popa - * @since 0.1.1 + * @author Joe Grandja */ public class OidcClientRegistrationTests { - private static final OidcClientRegistration.Builder OIDC_CLIENT_REGISTRATION = OidcClientRegistration.builder() - .redirectUri("https://localhost:8080/client") - .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) - .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) - .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.BASIC.getValue()) - .scope("test"); - private static final HttpMessageConverter accessTokenHttpResponseConverter = new OAuth2AccessTokenResponseHttpMessageConverter(); - - private static final OidcClientRegistrationHttpMessageConverter clientRegistrationHttpMessageConverter = + private static final HttpMessageConverter clientRegistrationHttpMessageConverter = new OidcClientRegistrationHttpMessageConverter(); - - private static final OAuth2TokenType ACCESS_TOKEN_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.ACCESS_TOKEN); - private static RegisteredClientRepository registeredClientRepository; private static OAuth2AuthorizationService authorizationService; private static JWKSource jwkSource; - private static NimbusJwtDecoder jwtDecoder; @Rule public final SpringTestRule spring = new SpringTestRule(); @@ -117,7 +104,6 @@ public class OidcClientRegistrationTests { authorizationService = mock(OAuth2AuthorizationService.class); JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK); jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet); - jwtDecoder = NimbusJwtDecoder.withPublicKey(TestKeys.DEFAULT_PUBLIC_KEY).build(); } @Before @@ -127,63 +113,83 @@ public class OidcClientRegistrationTests { } @Test - public void requestWhenAuthenticatedThenResponseIncludesRegisteredClientDetails() throws Exception { - this.spring.register(AuthorizationServerConfigurationEnabledClientRegistration.class).autowire(); + public void requestWhenClientRegistrationRequestAuthorizedThenClientRegistrationResponse() throws Exception { + this.spring.register(AuthorizationServerConfiguration.class).autowire(); + + // ***** (1) Obtain the "initial" access token used for registering the client + + String clientRegistrationScope = "client.create"; RegisteredClient registeredClient = TestRegisteredClients.registeredClient2() - .scope("client.create").build(); + .scope(clientRegistrationScope) + .build(); when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId()))) .thenReturn(registeredClient); - // get access token + MvcResult mvcResult = this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI) .param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) - .param(OAuth2ParameterNames.SCOPE, "client.create") + .param(OAuth2ParameterNames.SCOPE, clientRegistrationScope) .header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth( registeredClient.getClientId(), registeredClient.getClientSecret()))) .andExpect(status().isOk()) .andExpect(jsonPath("$.access_token").isNotEmpty()) - .andExpect(jsonPath("$.scope").value("client.create")) + .andExpect(jsonPath("$.scope").value(clientRegistrationScope)) .andReturn(); - //assert get access token + OAuth2AccessToken accessToken = readAccessTokenResponse(mvcResult.getResponse()).getAccessToken(); + verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId())); ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); verify(authorizationService).save(authorizationCaptor.capture()); OAuth2Authorization authorization = authorizationCaptor.getValue(); - MockHttpServletResponse servletResponse = mvcResult.getResponse(); - MockClientHttpResponse httpResponse = new MockClientHttpResponse( - servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus())); - OAuth2AccessTokenResponse accessTokenResponse = accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse); - String tokenValue = accessTokenResponse.getAccessToken().getTokenValue(); - // prepare register client request - when(authorizationService.findByToken( - eq(authorization.getToken(OAuth2AccessToken.class).getToken().getTokenValue()), - eq(ACCESS_TOKEN_TOKEN_TYPE))) + // ***** (2) Register the client + + when(authorizationService.findByToken(eq(accessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN))) .thenReturn(authorization); - doNothing().when(registeredClientRepository).saveClient(any(RegisteredClient.class)); - mvcResult = this.mvc.perform(post("/connect/register") - .header(HttpHeaders.AUTHORIZATION, "Bearer " + tokenValue) + doNothing().when(registeredClientRepository).save(any(RegisteredClient.class)); + + // @formatter:off + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + + HttpHeaders httpHeaders = new HttpHeaders(); + httpHeaders.setBearerAuth(accessToken.getTokenValue()); + + // Register the client + mvcResult = this.mvc.perform(post(OidcClientRegistrationEndpointFilter.DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI) + .headers(httpHeaders) .contentType(MediaType.APPLICATION_JSON) - .content(convertToByteArray(OIDC_CLIENT_REGISTRATION.build()))) - .andExpect(status().isCreated()).andReturn(); + .content(getClientRegistrationRequestContent(clientRegistration))) + .andExpect(status().isCreated()) + .andExpect(header().string(HttpHeaders.CACHE_CONTROL, containsString("no-store"))) + .andExpect(header().string(HttpHeaders.PRAGMA, containsString("no-cache"))) + .andReturn(); - servletResponse = mvcResult.getResponse(); - httpResponse = new MockClientHttpResponse( - servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus())); - - OidcClientRegistration result = clientRegistrationHttpMessageConverter.read(OidcClientRegistration.class, httpResponse); - - - assertThat(result).isNotNull(); - assertThat(result.getClaimAsString("client_id")).isNotEmpty(); - assertThat(result.getClaimAsString("client_id_issued_at")).isNotEmpty(); - assertThat(result.getClaimAsString("client_secret")).isNotEmpty(); - assertThat(result.getClaimAsString("client_secret_expires_at")).isNotNull().isEqualTo("0.0"); - assertThat(result.getRedirectUris()).isNotEmpty().containsExactly("https://localhost:8080/client"); - assertThat(result.getResponseTypes()).isNotEmpty().containsExactly(OAuth2AuthorizationResponseType.CODE.getValue()); - assertThat(result.getGrantTypes()).isNotEmpty().containsExactly(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); - assertThat(result.getTokenEndpointAuthenticationMethod()).isNotEmpty().isEqualTo(ClientAuthenticationMethod.BASIC.getValue()); - assertThat(result.getScope()).isNotEmpty().isEqualTo("test"); + OidcClientRegistration clientRegistrationResponse = readClientRegistrationResponse(mvcResult.getResponse()); + assertThat(clientRegistrationResponse.getClientId()).isNotNull(); + assertThat(clientRegistrationResponse.getClientIdIssuedAt()).isNotNull(); + assertThat(clientRegistrationResponse.getClientSecret()).isNotNull(); + assertThat(clientRegistrationResponse.getClientSecretExpiresAt()).isNull(); + assertThat(clientRegistrationResponse.getClientName()).isEqualTo(clientRegistration.getClientName()); + assertThat(clientRegistrationResponse.getRedirectUris()) + .containsExactlyInAnyOrderElementsOf(clientRegistration.getRedirectUris()); + assertThat(clientRegistrationResponse.getGrantTypes()) + .containsExactlyInAnyOrderElementsOf(clientRegistration.getGrantTypes()); + assertThat(clientRegistrationResponse.getResponseTypes()) + .containsExactly(OAuth2AuthorizationResponseType.CODE.getValue()); + assertThat(clientRegistrationResponse.getScopes()) + .containsExactlyInAnyOrderElementsOf(clientRegistration.getScopes()); + assertThat(clientRegistrationResponse.getTokenEndpointAuthenticationMethod()) + .isEqualTo(ClientAuthenticationMethod.BASIC.getValue()); + assertThat(clientRegistrationResponse.getIdTokenSignedResponseAlgorithm()) + .isEqualTo(SignatureAlgorithm.RS256.getName()); } private static String encodeBasicAuth(String clientId, String secret) throws Exception { @@ -194,12 +200,22 @@ public class OidcClientRegistrationTests { return new String(encodedBytes, StandardCharsets.UTF_8); } - private static byte[] convertToByteArray(OidcClientRegistration clientRegistration) throws JsonProcessingException { - ObjectMapper objectMapper = new ObjectMapper(); + private static OAuth2AccessTokenResponse readAccessTokenResponse(MockHttpServletResponse response) throws Exception { + MockClientHttpResponse httpResponse = new MockClientHttpResponse( + response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus())); + return accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse); + } - return objectMapper - .writerFor(Map.class) - .writeValueAsBytes(clientRegistration.getClaims()); + private static byte[] getClientRegistrationRequestContent(OidcClientRegistration clientRegistration) throws Exception { + MockHttpOutputMessage httpRequest = new MockHttpOutputMessage(); + clientRegistrationHttpMessageConverter.write(clientRegistration, null, httpRequest); + return httpRequest.getBodyAsBytes(); + } + + private static OidcClientRegistration readClientRegistrationResponse(MockHttpServletResponse response) throws Exception { + MockClientHttpResponse httpResponse = new MockClientHttpResponse( + response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus())); + return clientRegistrationHttpMessageConverter.read(OidcClientRegistration.class, httpResponse); } @EnableWebSecurity @@ -221,21 +237,5 @@ public class OidcClientRegistrationTests { return jwkSource; } - - } - - @EnableWebSecurity - @Import(OAuth2AuthorizationServerConfiguration.class) - static class AuthorizationServerConfigurationEnabledClientRegistration extends AuthorizationServerConfiguration{ - - @Bean - JwtDecoder jwtDecoder() { - return jwtDecoder; - } - - @Bean - ProviderSettings providerSettings() { - return new ProviderSettings().isOidClientRegistrationEndpointEnabled(true); - } } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java index 14aae367..48941f19 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java @@ -54,10 +54,8 @@ import org.springframework.security.oauth2.core.http.converter.OAuth2AccessToken import org.springframework.security.oauth2.core.oidc.OidcScopes; import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.jose.TestJwks; -import org.springframework.security.oauth2.jose.TestKeys; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtDecoder; -import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode; @@ -105,7 +103,6 @@ public class OidcTests { private static RegisteredClientRepository registeredClientRepository; private static OAuth2AuthorizationService authorizationService; private static JWKSource jwkSource; - private static NimbusJwtDecoder jwtDecoder; private static HttpMessageConverter accessTokenHttpResponseConverter = new OAuth2AccessTokenResponseHttpMessageConverter(); @@ -115,13 +112,15 @@ public class OidcTests { @Autowired private MockMvc mvc; + @Autowired + private JwtDecoder jwtDecoder; + @BeforeClass public static void init() { registeredClientRepository = mock(RegisteredClientRepository.class); authorizationService = mock(OAuth2AuthorizationService.class); JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK); jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet); - jwtDecoder = NimbusJwtDecoder.withPublicKey(TestKeys.DEFAULT_PUBLIC_KEY).build(); } @Before @@ -206,7 +205,7 @@ public class OidcTests { OAuth2AccessTokenResponse accessTokenResponse = accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse); // Assert user authorities was propagated as claim in ID Token - Jwt idToken = jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN)); + Jwt idToken = this.jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN)); List authoritiesClaim = idToken.getClaim(AUTHORITIES_CLAIM); Authentication principal = authorization.getAttribute(Principal.class.getName()); Set userAuthorities = principal.getAuthorities().stream() @@ -275,10 +274,6 @@ public class OidcTests { }; } - @Bean - JwtDecoder jwtDecoder(){ - return jwtDecoder; - } } @EnableWebSecurity diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/OidcClientRegistrationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/OidcClientRegistrationTests.java index d6996d04..60438224 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/OidcClientRegistrationTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/OidcClientRegistrationTests.java @@ -15,317 +15,384 @@ */ package org.springframework.security.oauth2.core.oidc; -import org.junit.Test; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; - -import java.net.URL; +import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; -import java.util.List; -import java.util.Map; + +import org.junit.Test; + +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** - * Tests for {@link OidcClientRegistration} + * Tests for {@link OidcClientRegistration}. * * @author Ovidiu Popa - * @since 0.1.1 + * @author Joe Grandja */ public class OidcClientRegistrationTests { - - private final OidcClientRegistration.Builder clientRegistrationBuilder = - OidcClientRegistration.builder(); + // @formatter:off + private final OidcClientRegistration.Builder minimalBuilder = + OidcClientRegistration.builder() + .redirectUri("https://client.example.com"); + // @formatter:on @Test - public void buildWhenAllRequiredClaimsAndAdditionalClaimsThenCreated() { + public void buildWhenAllClaimsProvidedThenCreated() { + // @formatter:off + Instant clientIdIssuedAt = Instant.now(); + Instant clientSecretExpiresAt = clientIdIssuedAt.plus(30, ChronoUnit.DAYS); OidcClientRegistration clientRegistration = OidcClientRegistration.builder() - .redirectUri("http://client.example.com") + .clientId("client-id") + .clientIdIssuedAt(clientIdIssuedAt) + .clientSecret("client-secret") + .clientSecretExpiresAt(clientSecretExpiresAt) + .clientName("client-name") + .redirectUri("https://client.example.com") + .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.BASIC.getValue()) .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) - .scope("test read") - .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.BASIC.getValue()) + .scope("scope1") + .scope("scope2") + .idTokenSignedResponseAlgorithm(SignatureAlgorithm.RS256.getName()) + .claim("a-claim", "a-value") .build(); + // @formatter:on - assertThat(clientRegistration.getRedirectUris()) - .containsOnly("http://client.example.com"); - assertThat(clientRegistration.getGrantTypes()) - .contains( - AuthorizationGrantType.AUTHORIZATION_CODE.getValue(), - AuthorizationGrantType.CLIENT_CREDENTIALS.getValue() - ); - assertThat(clientRegistration.getResponseTypes()) - .contains(OAuth2AuthorizationResponseType.CODE.getValue()); - assertThat(clientRegistration.getScope()) - .isEqualTo("test read"); - assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()) - .isEqualTo(ClientAuthenticationMethod.BASIC.getValue()); - + assertThat(clientRegistration.getClientId()).isEqualTo("client-id"); + assertThat(clientRegistration.getClientIdIssuedAt()).isEqualTo(clientIdIssuedAt); + assertThat(clientRegistration.getClientSecret()).isEqualTo("client-secret"); + assertThat(clientRegistration.getClientSecretExpiresAt()).isEqualTo(clientSecretExpiresAt); + assertThat(clientRegistration.getClientName()).isEqualTo("client-name"); + assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com"); + assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()).isEqualTo("basic"); + assertThat(clientRegistration.getGrantTypes()).containsExactlyInAnyOrder("authorization_code", "client_credentials"); + assertThat(clientRegistration.getResponseTypes()).containsOnly("code"); + assertThat(clientRegistration.getScopes()).containsExactlyInAnyOrder("scope1", "scope2"); + assertThat(clientRegistration.getIdTokenSignedResponseAlgorithm()).isEqualTo("RS256"); + assertThat(clientRegistration.getClaimAsString("a-claim")).isEqualTo("a-value"); } @Test - public void buildWhenAllRequiredClaimsThenCreated() { - OidcClientRegistration clientRegistration = OidcClientRegistration.builder() - .redirectUri("http://client.example.com") - .build(); - - assertThat(clientRegistration.getRedirectUris()) - .containsOnly("http://client.example.com"); - assertThat(clientRegistration.getGrantTypes()) - .containsOnly(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); - assertThat(clientRegistration.getResponseTypes()) - .containsOnly(OAuth2AuthorizationResponseType.CODE.getValue()); - assertThat(clientRegistration.getScope()) - .isNull(); - assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()) - .isEqualTo(ClientAuthenticationMethod.BASIC.getValue()); + public void buildWhenOnlyRequiredClaimsProvidedThenCreated() { + OidcClientRegistration clientRegistration = this.minimalBuilder.build(); + assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com"); } @Test - public void buildWhenAllRequiredClaimsAndAuthorizationGrantTypeButMissingResponseTypeThenCreated() { - OidcClientRegistration clientRegistration = OidcClientRegistration.builder() - .redirectUri("http://client.example.com") - .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) - .build(); - - assertThat(clientRegistration.getRedirectUris()) - .containsOnly("http://client.example.com"); - assertThat(clientRegistration.getGrantTypes()) - .containsOnly(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); - assertThat(clientRegistration.getResponseTypes()) - .containsOnly(OAuth2AuthorizationResponseType.CODE.getValue()); - } - - @Test - public void buildWhenAllRequiredClaimsAndEmptyGrantTypeListButMissingResponseTypeThenCreated() { - OidcClientRegistration clientRegistration = OidcClientRegistration.builder() - .redirectUri("http://client.example.com") - .grantTypes(List::clear) - .build(); - - assertThat(clientRegistration.getRedirectUris()) - .containsOnly("http://client.example.com"); - assertThat(clientRegistration.getGrantTypes()) - .containsOnly(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); - assertThat(clientRegistration.getResponseTypes()) - .containsOnly(OAuth2AuthorizationResponseType.CODE.getValue()); - } - - @Test - public void buildWhenAllRequiredClaimsAndResponseTypeButMissingAuthorizationGrantTypeThenCreated() { - OidcClientRegistration clientRegistration = OidcClientRegistration.builder() - .redirectUri("http://client.example.com") - .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) - .build(); - - assertThat(clientRegistration.getRedirectUris()) - .containsOnly("http://client.example.com"); - assertThat(clientRegistration.getGrantTypes()) - .containsOnly(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); - assertThat(clientRegistration.getResponseTypes()) - .containsOnly(OAuth2AuthorizationResponseType.CODE.getValue()); - } - - @Test - public void buildWhenAllRequiredClaimsAndEmptyResponseTypeListButMissingAuthorizationGrantTypeThenCreated() { - OidcClientRegistration clientRegistration = OidcClientRegistration.builder() - .redirectUri("http://client.example.com") - .responseTypes(List::clear) - .build(); - - assertThat(clientRegistration.getRedirectUris()) - .containsOnly("http://client.example.com"); - assertThat(clientRegistration.getGrantTypes()) - .containsOnly(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()); - assertThat(clientRegistration.getResponseTypes()) - .containsOnly(OAuth2AuthorizationResponseType.CODE.getValue()); - } - - @Test - public void buildWhenAllRequiredClaimsAndEmptyScopeThenCreated() { - OidcClientRegistration clientRegistration = OidcClientRegistration.builder() - .redirectUri("http://client.example.com") - .build(); - - assertThat(clientRegistration.getRedirectUris()) - .containsOnly("http://client.example.com"); - assertThat(clientRegistration.getScope()) - .isNull(); - } - - @Test - public void buildWhenAllRequiredClaimsAndEmptyTokenEndpointAuthMethodThenCreated() { - OidcClientRegistration clientRegistration = OidcClientRegistration.builder() - .redirectUri("http://client.example.com") - .build(); - - assertThat(clientRegistration.getRedirectUris()) - .containsOnly("http://client.example.com"); - assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()) - .isEqualTo(ClientAuthenticationMethod.BASIC.getValue()); - } - - @Test - public void buildWhenClaimsProvidedThenCreated() { - Map claims = new HashMap<>(); - claims.put(OidcClientMetadataClaimNames.REDIRECT_URIS, Collections.singletonList("http://client.example.com")); - claims.put(OidcClientMetadataClaimNames.GRANT_TYPES, Arrays.asList( - AuthorizationGrantType.AUTHORIZATION_CODE.getValue(), - AuthorizationGrantType.CLIENT_CREDENTIALS.getValue() - )); - claims.put(OidcClientMetadataClaimNames.RESPONSE_TYPES, - Collections.singletonList(OAuth2AuthorizationResponseType.CODE.getValue())); - claims.put(OidcClientMetadataClaimNames.SCOPE, "test read"); + public void withClaimsWhenClaimsProvidedThenCreated() { + Instant clientIdIssuedAt = Instant.now(); + Instant clientSecretExpiresAt = clientIdIssuedAt.plus(30, ChronoUnit.DAYS); + HashMap claims = new HashMap<>(); + claims.put(OidcClientMetadataClaimNames.CLIENT_ID, "client-id"); + claims.put(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT, clientIdIssuedAt); + claims.put(OidcClientMetadataClaimNames.CLIENT_SECRET, "client-secret"); + claims.put(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, clientSecretExpiresAt); + claims.put(OidcClientMetadataClaimNames.CLIENT_NAME, "client-name"); + claims.put(OidcClientMetadataClaimNames.REDIRECT_URIS, Collections.singletonList("https://client.example.com")); claims.put(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, ClientAuthenticationMethod.BASIC.getValue()); + claims.put(OidcClientMetadataClaimNames.GRANT_TYPES, Arrays.asList( + AuthorizationGrantType.AUTHORIZATION_CODE.getValue(), AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())); + claims.put(OidcClientMetadataClaimNames.RESPONSE_TYPES, Collections.singletonList("code")); + claims.put(OidcClientMetadataClaimNames.SCOPE, Arrays.asList("scope1", "scope2")); + claims.put(OidcClientMetadataClaimNames.ID_TOKEN_SIGNED_RESPONSE_ALG, SignatureAlgorithm.RS256.getName()); + claims.put("a-claim", "a-value"); OidcClientRegistration clientRegistration = OidcClientRegistration.withClaims(claims).build(); - assertThat(clientRegistration.getRedirectUris()) - .containsOnly("http://client.example.com"); - assertThat(clientRegistration.getGrantTypes()) - .contains( - AuthorizationGrantType.AUTHORIZATION_CODE.getValue(), - AuthorizationGrantType.CLIENT_CREDENTIALS.getValue() - ); - assertThat(clientRegistration.getResponseTypes()) - .contains(OAuth2AuthorizationResponseType.CODE.getValue()); - assertThat(clientRegistration.getScope()) - .isEqualTo("test read"); - assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()) - .isEqualTo(ClientAuthenticationMethod.BASIC.getValue()); + assertThat(clientRegistration.getClientId()).isEqualTo("client-id"); + assertThat(clientRegistration.getClientIdIssuedAt()).isEqualTo(clientIdIssuedAt); + assertThat(clientRegistration.getClientSecret()).isEqualTo("client-secret"); + assertThat(clientRegistration.getClientSecretExpiresAt()).isEqualTo(clientSecretExpiresAt); + assertThat(clientRegistration.getClientName()).isEqualTo("client-name"); + assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com"); + assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()).isEqualTo("basic"); + assertThat(clientRegistration.getGrantTypes()).containsExactlyInAnyOrder("authorization_code", "client_credentials"); + assertThat(clientRegistration.getResponseTypes()).containsOnly("code"); + assertThat(clientRegistration.getScopes()).containsExactlyInAnyOrder("scope1", "scope2"); + assertThat(clientRegistration.getIdTokenSignedResponseAlgorithm()).isEqualTo("RS256"); + assertThat(clientRegistration.getClaimAsString("a-claim")).isEqualTo("a-value"); } @Test - public void buildWhenRedirectUriProvidedWithUrlThenCreated() { - Map claims = new HashMap<>(); - claims.put(OidcClientMetadataClaimNames.REDIRECT_URIS, Arrays.asList( - url("http://client.example.com"), - url("http://client.example.com/authorized") - ) - ); - claims.put(OidcClientMetadataClaimNames.GRANT_TYPES, Arrays.asList( - AuthorizationGrantType.AUTHORIZATION_CODE.getValue(), - AuthorizationGrantType.CLIENT_CREDENTIALS.getValue() - )); - claims.put(OidcClientMetadataClaimNames.RESPONSE_TYPES, - Collections.singletonList(OAuth2AuthorizationResponseType.CODE.getValue())); - claims.put(OidcClientMetadataClaimNames.SCOPE, "test read"); - claims.put(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, ClientAuthenticationMethod.BASIC.getValue()); - - OidcClientRegistration clientRegistration = OidcClientRegistration.withClaims(claims).build(); - - assertThat(clientRegistration.getRedirectUris()) - .contains("http://client.example.com", "http://client.example.com/authorized"); - assertThat(clientRegistration.getGrantTypes()) - .contains( - AuthorizationGrantType.AUTHORIZATION_CODE.getValue(), - AuthorizationGrantType.CLIENT_CREDENTIALS.getValue() - ); - assertThat(clientRegistration.getResponseTypes()) - .contains(OAuth2AuthorizationResponseType.CODE.getValue()); - assertThat(clientRegistration.getScope()) - .isEqualTo("test read"); - assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()) - .isEqualTo(ClientAuthenticationMethod.BASIC.getValue()); + public void withClaimsWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> OidcClientRegistration.withClaims(null)) + .withMessage("claims cannot be empty"); } @Test - public void withClaimsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OidcClientRegistration.withClaims(null)) - .isInstanceOf(IllegalArgumentException.class); + public void withClaimsWhenEmptyThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> OidcClientRegistration.withClaims(Collections.emptyMap())) + .withMessage("claims cannot be empty"); } @Test - public void withClaimsEmptyThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> OidcClientRegistration.withClaims(Collections.emptyMap())) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("claims cannot be empty"); + public void buildWhenMissingClientIdThenThrowIllegalArgumentException() { + OidcClientRegistration.Builder builder = this.minimalBuilder + .clientIdIssuedAt(Instant.now()); + + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessage("client_id cannot be null"); } @Test - public void buildWhenNullRedirectUriThenThrowIllegalArgumentException() { - OidcClientRegistration.Builder builder = this.clientRegistrationBuilder - .redirectUris((claims) -> claims.remove(OidcClientMetadataClaimNames.REDIRECT_URIS)); + public void buildWhenClientSecretAndMissingClientIdThenThrowIllegalArgumentException() { + OidcClientRegistration.Builder builder = this.minimalBuilder + .clientSecret("client-secret"); - assertThatThrownBy(builder::build) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("redirect_uris must not be empty"); + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessage("client_id cannot be null"); } @Test - public void buildWhenNullRedirectUriClaimThenThrowIllegalArgumentException() { - Map claims = new HashMap<>(); - claims.put(OidcClientMetadataClaimNames.REDIRECT_URIS, null); - OidcClientRegistration.Builder builder = OidcClientRegistration.withClaims(claims); + public void buildWhenClientIdIssuedAtNotInstantThenThrowIllegalArgumentException() { + // @formatter:off + OidcClientRegistration.Builder builder = this.minimalBuilder + .clientId("client-id") + .claim(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT, "clientIdIssuedAt"); + // @formatter:on - assertThatThrownBy(builder::build) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("redirect_uris cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessageStartingWith("client_id_issued_at must be of type Instant"); } @Test - public void buildWhenEmptyRedirectUriListThenThrowIllegalArgumentException() { - OidcClientRegistration.Builder builder = this.clientRegistrationBuilder - .redirectUris(List::clear); + public void buildWhenMissingClientSecretThenThrowIllegalArgumentException() { + // @formatter:off + OidcClientRegistration.Builder builder = this.minimalBuilder + .clientId("client-id") + .clientIdIssuedAt(Instant.now()) + .clientSecretExpiresAt(Instant.now().plus(30, ChronoUnit.DAYS)); + // @formatter:on - assertThatThrownBy(builder::build) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("redirect_uris must not be empty"); + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessage("client_secret cannot be null"); } @Test - public void buildWhenRedirectUriNotOfTypeListThenThrowIllegalArgumentException() { - OidcClientRegistration.Builder builder = this.clientRegistrationBuilder - .claims(claims -> claims.put(OidcClientMetadataClaimNames.REDIRECT_URIS, "http://client.example.com")); + public void buildWhenClientSecretExpiresAtNotInstantThenThrowIllegalArgumentException() { + // @formatter:off + OidcClientRegistration.Builder builder = this.minimalBuilder + .clientId("client-id") + .clientIdIssuedAt(Instant.now()) + .clientSecret("client-secret") + .claim(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, "clientSecretExpiresAt"); + // @formatter:on - assertThatThrownBy(builder::build) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("redirect_uris must be of type list"); + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessageStartingWith("client_secret_expires_at must be of type Instant"); } @Test - public void buildWhenRedirectUriNotUrlThenThrowIllegalArgumentException() { - OidcClientRegistration.Builder builder = this.clientRegistrationBuilder - .redirectUri("not url"); + public void buildWhenMissingRedirectUrisThenThrowIllegalArgumentException() { + OidcClientRegistration.Builder builder = OidcClientRegistration.builder() + .clientName("client-name"); - assertThatThrownBy(builder::build) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("redirect_uri must be a valid URL"); + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessage("redirect_uris cannot be null"); } @Test - public void buildWhenResponseTypesNotOfTypeListThenThrowIllegalArgumentException() { - OidcClientRegistration.Builder builder = this.clientRegistrationBuilder - .redirectUri("http://client.example.com") - .claims(claims -> claims.put(OidcClientMetadataClaimNames.RESPONSE_TYPES, OAuth2AuthorizationResponseType.CODE.getValue())); + public void buildWhenRedirectUrisNotListThenThrowIllegalArgumentException() { + OidcClientRegistration.Builder builder = OidcClientRegistration.builder() + .claim(OidcClientMetadataClaimNames.REDIRECT_URIS, "redirectUris"); - assertThatThrownBy(builder::build) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("response_types must be of type List"); + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessageStartingWith("redirect_uris must be of type List"); } @Test - public void buildWhenGrantTypesNotOfTypeListThenThrowIllegalArgumentException() { - OidcClientRegistration.Builder builder = this.clientRegistrationBuilder - .redirectUri("http://client.example.com") - .claims(claims -> claims.put(OidcClientMetadataClaimNames.GRANT_TYPES, AuthorizationGrantType.AUTHORIZATION_CODE.getValue())); + public void buildWhenRedirectUrisEmptyListThenThrowIllegalArgumentException() { + OidcClientRegistration.Builder builder = OidcClientRegistration.builder() + .claim(OidcClientMetadataClaimNames.REDIRECT_URIS, Collections.emptyList()); - assertThatThrownBy(builder::build) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("grant_types must be of type List"); + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessage("redirect_uris cannot be empty"); } - private static URL url(String urlString) { - try { - return new URL(urlString); - } catch (Exception ex) { - throw new IllegalArgumentException("urlString must be a valid URL and valid URI"); - } + @Test + public void buildWhenInvalidRedirectUriThenThrowIllegalArgumentException() { + OidcClientRegistration.Builder builder = OidcClientRegistration.builder() + .redirectUri("invalid-uri"); + + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessage("redirect_uri must be a valid URL"); + } + + @Test + public void buildWhenRedirectUrisAddingOrRemovingThenCorrectValues() { + // @formatter:off + OidcClientRegistration clientRegistration = this.minimalBuilder + .redirectUri("https://client1.example.com") + .redirectUris(redirectUris -> { + redirectUris.clear(); + redirectUris.add("https://client2.example.com"); + }) + .build(); + // @formatter:on + + assertThat(clientRegistration.getRedirectUris()).containsExactly("https://client2.example.com"); + } + + @Test + public void buildWhenGrantTypesNotListThenThrowIllegalArgumentException() { + OidcClientRegistration.Builder builder = this.minimalBuilder + .claim(OidcClientMetadataClaimNames.GRANT_TYPES, "grantTypes"); + + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessageStartingWith("grant_types must be of type List"); + } + + @Test + public void buildWhenGrantTypesEmptyListThenThrowIllegalArgumentException() { + OidcClientRegistration.Builder builder = this.minimalBuilder + .claim(OidcClientMetadataClaimNames.GRANT_TYPES, Collections.emptyList()); + + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessage("grant_types cannot be empty"); + } + + @Test + public void buildWhenGrantTypesAddingOrRemovingThenCorrectValues() { + // @formatter:off + OidcClientRegistration clientRegistration = this.minimalBuilder + .grantType("authorization_code") + .grantTypes(grantTypes -> { + grantTypes.clear(); + grantTypes.add("client_credentials"); + }) + .build(); + // @formatter:on + + assertThat(clientRegistration.getGrantTypes()).containsExactly("client_credentials"); + } + + @Test + public void buildWhenResponseTypesNotListThenThrowIllegalArgumentException() { + OidcClientRegistration.Builder builder = this.minimalBuilder + .claim(OidcClientMetadataClaimNames.RESPONSE_TYPES, "responseTypes"); + + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessageStartingWith("response_types must be of type List"); + } + + @Test + public void buildWhenResponseTypesEmptyListThenThrowIllegalArgumentException() { + OidcClientRegistration.Builder builder = this.minimalBuilder + .claim(OidcClientMetadataClaimNames.RESPONSE_TYPES, Collections.emptyList()); + + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessage("response_types cannot be empty"); + } + + @Test + public void buildWhenResponseTypesAddingOrRemovingThenCorrectValues() { + // @formatter:off + OidcClientRegistration clientRegistration = this.minimalBuilder + .responseType("token") + .responseTypes(responseTypes -> { + responseTypes.clear(); + responseTypes.add("code"); + }) + .build(); + // @formatter:on + + assertThat(clientRegistration.getResponseTypes()).containsExactly("code"); + } + + @Test + public void buildWhenScopesNotListThenThrowIllegalArgumentException() { + OidcClientRegistration.Builder builder = this.minimalBuilder + .claim(OidcClientMetadataClaimNames.SCOPE, "scopes"); + + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessageStartingWith("scope must be of type List"); + } + + @Test + public void buildWhenScopesEmptyListThenThrowIllegalArgumentException() { + OidcClientRegistration.Builder builder = this.minimalBuilder + .claim(OidcClientMetadataClaimNames.SCOPE, Collections.emptyList()); + + assertThatIllegalArgumentException() + .isThrownBy(builder::build) + .withMessage("scope cannot be empty"); + } + + @Test + public void buildWhenScopesAddingOrRemovingThenCorrectValues() { + // @formatter:off + OidcClientRegistration clientRegistration = this.minimalBuilder + .scope("should-be-removed") + .scopes(scopes -> { + scopes.clear(); + scopes.add("scope1"); + }) + .build(); + // @formatter:on + + assertThat(clientRegistration.getScopes()).containsExactly("scope1"); + } + + @Test + public void claimWhenNameNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> OidcClientRegistration.builder().claim(null, "claim-value")) + .withMessage("name cannot be empty"); + } + + @Test + public void claimWhenValueNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> OidcClientRegistration.builder().claim("claim-name", null)) + .withMessage("value cannot be null"); + } + + @Test + public void claimsWhenRemovingClaimThenNotPresent() { + // @formatter:off + OidcClientRegistration clientRegistration = this.minimalBuilder + .claim("claim-name", "claim-value") + .claims((claims) -> claims.remove("claim-name")) + .build(); + // @formatter:on + + assertThat(clientRegistration.containsClaim("claim-name")).isFalse(); + } + + @Test + public void claimsWhenAddingClaimThenPresent() { + // @formatter:off + OidcClientRegistration clientRegistration = this.minimalBuilder + .claim("claim-name", "claim-value") + .build(); + // @formatter:on + + assertThat(clientRegistration.containsClaim("claim-name")).isTrue(); } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverterTest.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverterTest.java deleted file mode 100644 index de15fa3a..00000000 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverterTest.java +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Copyright 2020-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.core.oidc.http.converter; - -import org.junit.Test; -import org.springframework.core.convert.converter.Converter; -import org.springframework.http.HttpStatus; -import org.springframework.http.converter.HttpMessageNotReadableException; -import org.springframework.http.converter.HttpMessageNotWritableException; -import org.springframework.mock.http.MockHttpOutputMessage; -import org.springframework.mock.http.client.MockClientHttpResponse; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; -import org.springframework.security.oauth2.core.oidc.OidcClientRegistration; - -import java.util.Map; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - -/** - * @author Ovidiu Popa - * @since 0.1.1 - */ -public class OidcClientRegistrationHttpMessageConverterTest { - private final OidcClientRegistrationHttpMessageConverter messageConverter = - new OidcClientRegistrationHttpMessageConverter(); - - @Test - public void supportsWhenOidcClientRegistrationThenTrue() { - assertThat(this.messageConverter.supports(OidcClientRegistration.class)).isTrue(); - } - - @Test - public void setClientRegistrationReadConverterWhenNullThenThrowIllegalArgumentException() { - assertThatIllegalArgumentException() - .isThrownBy(() -> this.messageConverter.setClientRegistrationConverter(null)) - .withMessageContaining("clientRegistrationConverter cannot be null"); - } - - @Test - public void setClientRegistrationWriteConverterWhenNullThenThrowIllegalArgumentException() { - assertThatIllegalArgumentException() - .isThrownBy(() -> this.messageConverter.setClientRegistrationParametersConverter(null)) - .withMessageContaining("clientRegistrationParametersConverter cannot be null"); - } - - @Test - public void readInternalWhenRequiredParametersThenSuccess() { - // @formatter:off - String clientRegistrationResponse = "{\n" - + " \"redirect_uris\": [\n" - + " \"https://client.example.org/callback\"\n" - + " ]\n" - + "}\n"; - // @formatter:on - - MockClientHttpResponse response = new MockClientHttpResponse(clientRegistrationResponse.getBytes(), HttpStatus.OK); - OidcClientRegistration clientRegistration = this.messageConverter - .readInternal(OidcClientRegistration.class, response); - - assertThat(clientRegistration.getRedirectUris()) - .containsOnly("https://client.example.org/callback"); - assertThat(clientRegistration.getGrantTypes()) - .containsOnly( - AuthorizationGrantType.AUTHORIZATION_CODE.getValue() - ); - assertThat(clientRegistration.getResponseTypes()) - .contains(OAuth2AuthorizationResponseType.CODE.getValue()); - assertThat(clientRegistration.getScope()) - .isNull(); - assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()) - .isEqualTo(ClientAuthenticationMethod.BASIC.getValue()); - } - - @Test - public void readInternalWhenValidParametersThenSuccess() { - // @formatter:off - String clientRegistrationResponse = "{\n" - +" \"redirect_uris\": [\n" - + " \"https://client.example.org/callback\"\n" - + " ],\n" - +" \"grant_types\": [\n" - +" \"client_credentials\",\n" - +" \"authorization_code\"\n" - +" ],\n" - +" \"response_types\":[\n" - +" \"code\"\n" - +" ],\n" - +" \"client_name\": \"My Example\",\n" - +" \"scope\": \"read write\",\n" - +" \"token_endpoint_auth_method\": \"basic\"\n" - +"}\n"; - // @formatter:on - MockClientHttpResponse response = new MockClientHttpResponse(clientRegistrationResponse.getBytes(), HttpStatus.OK); - - OidcClientRegistration clientRegistration = this.messageConverter - .readInternal(OidcClientRegistration.class, response); - assertThat(clientRegistration.getRedirectUris()) - .containsOnly("https://client.example.org/callback"); - assertThat(clientRegistration.getGrantTypes()) - .contains( - AuthorizationGrantType.AUTHORIZATION_CODE.getValue(), - AuthorizationGrantType.CLIENT_CREDENTIALS.getValue() - ); - assertThat(clientRegistration.getResponseTypes()) - .contains(OAuth2AuthorizationResponseType.CODE.getValue()); - assertThat(clientRegistration.getScope()) - .isEqualTo("read write"); - assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()) - .isEqualTo(ClientAuthenticationMethod.BASIC.getValue()); - } - - @Test - public void readInternalWhenFailingConverterThenThrowException() { - String errorMessage = "this is not a valid converter"; - this.messageConverter.setClientRegistrationConverter(source -> { - throw new RuntimeException(errorMessage); - }); - MockClientHttpResponse response = new MockClientHttpResponse("{}".getBytes(), HttpStatus.OK); - - assertThatExceptionOfType(HttpMessageNotReadableException.class) - .isThrownBy(() -> this.messageConverter.readInternal(OidcClientRegistration.class, response)) - .withMessageContaining("An error occurred reading the OpenID Client Registration Request") - .withMessageContaining(errorMessage); - } - - @Test - public void readInternalWhenInvalidClientRegistrationThenThrowException() { - String clientRegistrationResponse = "{ \"redirect_uris\": null }"; - MockClientHttpResponse response = new MockClientHttpResponse(clientRegistrationResponse.getBytes(), HttpStatus.OK); - - assertThatExceptionOfType(HttpMessageNotReadableException.class) - .isThrownBy(() -> this.messageConverter.readInternal(OidcClientRegistration.class, response)) - .withMessageContaining("An error occurred reading the OpenID Client Registration Request") - .withMessageContaining("redirect_uris cannot be null"); - } - - @Test - public void writeInternalWhenClientRegistrationThenSuccess() { - OidcClientRegistration clientRegistration = OidcClientRegistration.builder() - .redirectUri("http://client.example.com/callback") - .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) - .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) - .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) - .scope("test read") - .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.BASIC.getValue()) - .build(); - MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); - - this.messageConverter.writeInternal(clientRegistration, outputMessage); - String clientRegistrationResponse = outputMessage.getBodyAsString(); - assertThat(clientRegistrationResponse).contains("\"redirect_uris\":[\"http://client.example.com/callback\"]"); - assertThat(clientRegistrationResponse).contains("\"grant_types\":[\"authorization_code\",\"client_credentials\"]"); - assertThat(clientRegistrationResponse).contains("\"response_types\":[\"code\"]"); - assertThat(clientRegistrationResponse).contains("\"scope\":\"test read\""); - assertThat(clientRegistrationResponse).contains("\"token_endpoint_auth_method\":\"basic\""); - } - - @Test - public void writeInternalWhenWriteFailsThenThrowsException() { - String errorMessage = "this is not a valid converter"; - Converter> failingConverter = - source -> { - throw new RuntimeException(errorMessage); - }; - this.messageConverter.setClientRegistrationParametersConverter(failingConverter); - - OidcClientRegistration clientRegistration = - OidcClientRegistration.builder() - .redirectUri("http://client.example.com") - .build(); - - MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); - - assertThatThrownBy(() -> this.messageConverter.writeInternal(clientRegistration, outputMessage)) - .isInstanceOf(HttpMessageNotWritableException.class) - .hasMessageContaining("An error occurred writing the OpenID Client Registration response") - .hasMessageContaining(errorMessage); - } -} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverterTests.java new file mode 100644 index 00000000..16ce23d6 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverterTests.java @@ -0,0 +1,250 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.core.oidc.http.converter; + +import java.time.Instant; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.core.convert.converter.Converter; +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageNotReadableException; +import org.springframework.http.converter.HttpMessageNotWritableException; +import org.springframework.mock.http.MockHttpOutputMessage; +import org.springframework.mock.http.client.MockClientHttpResponse; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.core.oidc.OidcClientRegistration; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests for {@link OidcClientRegistrationHttpMessageConverter} + + * @author Ovidiu Popa + * @author Joe Grandja + * @since 0.1.1 + */ +public class OidcClientRegistrationHttpMessageConverterTests { + private final OidcClientRegistrationHttpMessageConverter messageConverter = new OidcClientRegistrationHttpMessageConverter(); + + @Test + public void supportsWhenOidcClientRegistrationThenTrue() { + assertThat(this.messageConverter.supports(OidcClientRegistration.class)).isTrue(); + } + + @Test + public void setClientRegistrationConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.messageConverter.setClientRegistrationConverter(null)) + .withMessageContaining("clientRegistrationConverter cannot be null"); + } + + @Test + public void setClientRegistrationParametersConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.messageConverter.setClientRegistrationParametersConverter(null)) + .withMessageContaining("clientRegistrationParametersConverter cannot be null"); + } + + @Test + public void readInternalWhenRequiredParametersThenSuccess() { + // @formatter:off + String clientRegistrationRequest = "{\n" + + " \"redirect_uris\": [\n" + + " \"https://client.example.com\"\n" + + " ]\n" + + "}\n"; + // @formatter:on + + MockClientHttpResponse response = new MockClientHttpResponse( + clientRegistrationRequest.getBytes(), HttpStatus.OK); + OidcClientRegistration clientRegistration = this.messageConverter + .readInternal(OidcClientRegistration.class, response); + + assertThat(clientRegistration.getClaims()).hasSize(1); + assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com"); + } + + @Test + public void readInternalWhenValidParametersThenSuccess() { + // @formatter:off + String clientRegistrationRequest = "{\n" + +" \"client_id\": \"client-id\",\n" + +" \"client_id_issued_at\": 1607633867,\n" + +" \"client_secret\": \"client-secret\",\n" + +" \"client_secret_expires_at\": 1607637467,\n" + +" \"client_name\": \"client-name\",\n" + +" \"redirect_uris\": [\n" + + " \"https://client.example.com\"\n" + + " ],\n" + +" \"token_endpoint_auth_method\": \"basic\",\n" + +" \"grant_types\": [\n" + +" \"authorization_code\",\n" + +" \"client_credentials\"\n" + +" ],\n" + +" \"response_types\":[\n" + +" \"code\"\n" + +" ],\n" + +" \"scope\": \"scope1 scope2\",\n" + +" \"id_token_signed_response_alg\": \"RS256\",\n" + +" \"a-claim\": \"a-value\"\n" + +"}\n"; + // @formatter:on + MockClientHttpResponse response = new MockClientHttpResponse( + clientRegistrationRequest.getBytes(), HttpStatus.OK); + OidcClientRegistration clientRegistration = this.messageConverter + .readInternal(OidcClientRegistration.class, response); + + assertThat(clientRegistration.getClientId()).isEqualTo("client-id"); + assertThat(clientRegistration.getClientIdIssuedAt()).isEqualTo(Instant.ofEpochSecond(1607633867L)); + assertThat(clientRegistration.getClientSecret()).isEqualTo("client-secret"); + assertThat(clientRegistration.getClientSecretExpiresAt()).isEqualTo(Instant.ofEpochSecond(1607637467L)); + assertThat(clientRegistration.getClientName()).isEqualTo("client-name"); + assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com"); + assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()).isEqualTo("basic"); + assertThat(clientRegistration.getGrantTypes()).containsExactlyInAnyOrder("authorization_code", "client_credentials"); + assertThat(clientRegistration.getResponseTypes()).containsOnly("code"); + assertThat(clientRegistration.getScopes()).containsExactlyInAnyOrder("scope1", "scope2"); + assertThat(clientRegistration.getIdTokenSignedResponseAlgorithm()).isEqualTo("RS256"); + assertThat(clientRegistration.getClaimAsString("a-claim")).isEqualTo("a-value"); + } + + @Test + public void readInternalWhenClientSecretNoExpiryThenSuccess() { + // @formatter:off + String clientRegistrationRequest = "{\n" + +" \"client_id\": \"client-id\",\n" + +" \"client_secret\": \"client-secret\",\n" + +" \"client_secret_expires_at\": 0,\n" + +" \"redirect_uris\": [\n" + + " \"https://client.example.com\"\n" + + " ]\n" + +"}\n"; + // @formatter:on + MockClientHttpResponse response = new MockClientHttpResponse( + clientRegistrationRequest.getBytes(), HttpStatus.OK); + OidcClientRegistration clientRegistration = this.messageConverter + .readInternal(OidcClientRegistration.class, response); + + assertThat(clientRegistration.getClaims()).hasSize(3); + assertThat(clientRegistration.getClientId()).isEqualTo("client-id"); + assertThat(clientRegistration.getClientSecret()).isEqualTo("client-secret"); + assertThat(clientRegistration.getClientSecretExpiresAt()).isNull(); + assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com"); + } + + @Test + public void readInternalWhenFailingConverterThenThrowException() { + String errorMessage = "this is not a valid converter"; + this.messageConverter.setClientRegistrationConverter(source -> { + throw new RuntimeException(errorMessage); + }); + MockClientHttpResponse response = new MockClientHttpResponse("{}".getBytes(), HttpStatus.OK); + + assertThatExceptionOfType(HttpMessageNotReadableException.class) + .isThrownBy(() -> this.messageConverter.readInternal(OidcClientRegistration.class, response)) + .withMessageContaining("An error occurred reading the OpenID Client Registration") + .withMessageContaining(errorMessage); + } + + @Test + public void writeInternalWhenClientRegistrationThenSuccess() { + // @formatter:off + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .clientId("client-id") + .clientIdIssuedAt(Instant.ofEpochSecond(1607633867)) + .clientSecret("client-secret") + .clientSecretExpiresAt(Instant.ofEpochSecond(1607637467)) + .clientName("client-name") + .redirectUri("https://client.example.com") + .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.BASIC.getValue()) + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) + .scope("scope1") + .scope("scope2") + .idTokenSignedResponseAlgorithm(SignatureAlgorithm.RS256.getName()) + .claim("a-claim", "a-value") + .build(); + // @formatter:on + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + this.messageConverter.writeInternal(clientRegistration, outputMessage); + + String clientRegistrationResponse = outputMessage.getBodyAsString(); + assertThat(clientRegistrationResponse).contains("\"client_id\":\"client-id\""); + assertThat(clientRegistrationResponse).contains("\"client_id_issued_at\":1607633867"); + assertThat(clientRegistrationResponse).contains("\"client_secret\":\"client-secret\""); + assertThat(clientRegistrationResponse).contains("\"client_secret_expires_at\":1607637467"); + assertThat(clientRegistrationResponse).contains("\"client_name\":\"client-name\""); + assertThat(clientRegistrationResponse).contains("\"redirect_uris\":[\"https://client.example.com\"]"); + assertThat(clientRegistrationResponse).contains("\"token_endpoint_auth_method\":\"basic\""); + assertThat(clientRegistrationResponse).contains("\"grant_types\":[\"authorization_code\",\"client_credentials\"]"); + assertThat(clientRegistrationResponse).contains("\"response_types\":[\"code\"]"); + assertThat(clientRegistrationResponse).contains("\"scope\":\"scope1 scope2\""); + assertThat(clientRegistrationResponse).contains("\"id_token_signed_response_alg\":\"RS256\""); + assertThat(clientRegistrationResponse).contains("\"a-claim\":\"a-value\""); + } + + @Test + public void writeInternalWhenClientSecretNoExpiryThenSuccess() { + // @formatter:off + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .clientId("client-id") + .clientSecret("client-secret") + .redirectUri("https://client.example.com") + .build(); + // @formatter:on + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + this.messageConverter.writeInternal(clientRegistration, outputMessage); + + String clientRegistrationResponse = outputMessage.getBodyAsString(); + assertThat(clientRegistrationResponse).contains("\"client_id\":\"client-id\""); + assertThat(clientRegistrationResponse).contains("\"client_secret\":\"client-secret\""); + assertThat(clientRegistrationResponse).contains("\"client_secret_expires_at\":0"); + assertThat(clientRegistrationResponse).contains("\"redirect_uris\":[\"https://client.example.com\"]"); + } + + @Test + public void writeInternalWhenWriteFailsThenThrowException() { + String errorMessage = "this is not a valid converter"; + Converter> failingConverter = source -> { + throw new RuntimeException(errorMessage); + }; + this.messageConverter.setClientRegistrationParametersConverter(failingConverter); + + // @formatter:off + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .redirectUri("https://client.example.com") + .build(); + // @formatter:off + + MockHttpOutputMessage outputMessage = new MockHttpOutputMessage(); + + assertThatThrownBy(() -> this.messageConverter.writeInternal(clientRegistration, outputMessage)) + .isInstanceOf(HttpMessageNotWritableException.class) + .hasMessageContaining("An error occurred writing the OpenID Client Registration") + .hasMessageContaining(errorMessage); + } +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OidcClientRegistrationAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OidcClientRegistrationAuthenticationProviderTests.java deleted file mode 100644 index dab9eb1d..00000000 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OidcClientRegistrationAuthenticationProviderTests.java +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Copyright 2020-2021 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.oauth2.server.authorization.authentication; - -import org.junit.Before; -import org.junit.Test; -import org.mockito.ArgumentCaptor; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.AuthorityUtils; -import org.springframework.security.oauth2.core.OAuth2AccessToken; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.OAuth2TokenType; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; -import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; -import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; -import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; - -import java.time.Instant; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -/** - * @author Ovidiu Popa - * @since 0.1.1 - */ -public class OidcClientRegistrationAuthenticationProviderTests { - - private OAuth2AuthorizationService authorizationService; - private OidcClientRegistrationAuthenticationProvider authenticationProvider; - - @Before - public void setUp() { - this.authorizationService = mock(OAuth2AuthorizationService.class); - this.authenticationProvider = new OidcClientRegistrationAuthenticationProvider(this.authorizationService); - } - - @Test - public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OidcClientRegistrationAuthenticationProvider(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authorizationService cannot be null"); - } - - @Test - public void supportsWhenTypeJwtAuthenticationTokenThenReturnTrue() { - assertThat(this.authenticationProvider.supports(JwtAuthenticationToken.class)).isTrue(); - } - - @Test - public void authenticateWhenAccessTokenNotFoundThenThrowOAuth2AuthenticationException() { - JwtAuthenticationToken authentication = buildJwtAuthenticationToken("client-registration-token", "SCOPE_client.create"); - - when(authorizationService.findByToken( - eq("client-registration-token"), eq(OAuth2TokenType.ACCESS_TOKEN))) - .thenReturn(null); - - - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); - - } - - @Test - public void authenticateWhenAccessTokenInvalidatedThenThrowOAuth2AuthenticationException() { - - JwtAuthenticationToken authentication = buildJwtAuthenticationToken("client-registration-token", "SCOPE_client.create"); - - OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "client-registration-token", Instant.now().minusSeconds(120), Instant.now().plusSeconds(1000)); - - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization() - .token(accessToken, (metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true)) - .build(); - - when(authorizationService.findByToken( - eq("client-registration-token"), eq(OAuth2TokenType.ACCESS_TOKEN))) - .thenReturn(authorization); - - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); - } - - @Test - public void authenticateWhenAccessTokenWithoutClientCreateScopeThenThrowOAuth2AuthenticationException() { - - JwtAuthenticationToken authentication = buildJwtAuthenticationToken("client-registration-token", "SCOPE_scope1"); - - OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "client-registration-token", Instant.now().minusSeconds(120), Instant.now().plusSeconds(1000), - new HashSet<>(Collections.singletonList("scope1"))); - - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization() - .token(accessToken) - .build(); - - when(authorizationService.findByToken( - eq("client-registration-token"), eq(OAuth2TokenType.ACCESS_TOKEN))) - .thenReturn(authorization); - - assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) - .isInstanceOf(OAuth2AuthenticationException.class) - .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()) - .extracting("errorCode") - .isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); - } - - @Test - public void authenticateWhenValidAccessTokenThenInvalidated() { - JwtAuthenticationToken authentication = buildJwtAuthenticationToken("client-registration-token", "SCOPE_client.create"); - - OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, - "client-registration-token", Instant.now().minusSeconds(120), Instant.now().plusSeconds(1000), - new HashSet<>(Collections.singletonList("client.create"))); - - OAuth2Authorization authorization = TestOAuth2Authorizations.authorization() - .token(accessToken) - .build(); - - when(authorizationService.findByToken( - eq("client-registration-token"), eq(OAuth2TokenType.ACCESS_TOKEN))) - .thenReturn(authorization); - - authenticationProvider.authenticate(authentication); - - ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); - verify(authorizationService).save(authorizationCaptor.capture()); - - OAuth2Authorization capturedAuthorization = authorizationCaptor.getValue(); - - assertThat(capturedAuthorization.getAccessToken()).isNotNull(); - assertThat(capturedAuthorization.getAccessToken().isInvalidated()).isTrue(); - } - - private static JwtAuthenticationToken buildJwtAuthenticationToken(String tokenValue, String... authorities) { - Jwt jwt = Jwt.withTokenValue(tokenValue) - .header("alg", "none") - .claim("sub", "client") - .build(); - List grantedAuthorities = AuthorityUtils.createAuthorityList(authorities); - JwtAuthenticationToken jwtAuthenticationToken = new JwtAuthenticationToken(jwt, grantedAuthorities); - jwtAuthenticationToken.setAuthenticated(true); - return jwtAuthenticationToken; - } -} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepositoryTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepositoryTests.java index 3d95966e..e9c3b45c 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepositoryTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepositoryTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 the original author or authors. + * Copyright 2020-2021 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,21 +15,24 @@ */ package org.springframework.security.oauth2.server.authorization.client; -import org.junit.Test; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.ClientAuthenticationMethod; - import java.util.Arrays; import java.util.Collections; import java.util.List; +import org.junit.Test; + +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; + import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; /** * Tests for {@link InMemoryRegisteredClientRepository}. * * @author Anoop Garlapati + * @author Ovidiu Popa + * @author Joe Grandja */ public class InMemoryRegisteredClientRepositoryTests { private RegisteredClient registration = TestRegisteredClients.registeredClient().build(); @@ -38,47 +41,70 @@ public class InMemoryRegisteredClientRepositoryTests { @Test public void constructorVarargsRegisteredClientWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> { - RegisteredClient registration = null; - new InMemoryRegisteredClientRepository(registration); - }).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> { + RegisteredClient registration = null; + new InMemoryRegisteredClientRepository(registration); + }) + .withMessageContaining("registration cannot be null"); } @Test public void constructorListRegisteredClientWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> { - List registrations = null; - new InMemoryRegisteredClientRepository(registrations); - }).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> { + List registrations = null; + new InMemoryRegisteredClientRepository(registrations); + }) + .withMessageContaining("registrations cannot be empty"); } @Test public void constructorListRegisteredClientWhenEmptyThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> { - List registrations = Collections.emptyList(); - new InMemoryRegisteredClientRepository(registrations); - }).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> { + List registrations = Collections.emptyList(); + new InMemoryRegisteredClientRepository(registrations); + }) + .withMessageContaining("registrations cannot be empty"); } @Test public void constructorListRegisteredClientWhenDuplicateIdThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> { - RegisteredClient anotherRegistrationWithSameId = TestRegisteredClients.registeredClient2() - .id(this.registration.getId()).build(); - List registrations = Arrays.asList(this.registration, anotherRegistrationWithSameId); - new InMemoryRegisteredClientRepository(registrations); - }).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> { + RegisteredClient anotherRegistrationWithSameId = TestRegisteredClients.registeredClient2() + .id(this.registration.getId()).build(); + List registrations = Arrays.asList(this.registration, anotherRegistrationWithSameId); + new InMemoryRegisteredClientRepository(registrations); + }) + .withMessageStartingWith("Registered client must be unique. Found duplicate identifier:"); } @Test public void constructorListRegisteredClientWhenDuplicateClientIdThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> { - RegisteredClient anotherRegistrationWithSameClientId = TestRegisteredClients.registeredClient2() - .clientId(this.registration.getClientId()).build(); - List registrations = Arrays.asList(this.registration, - anotherRegistrationWithSameClientId); - new InMemoryRegisteredClientRepository(registrations); - }).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> { + RegisteredClient anotherRegistrationWithSameClientId = TestRegisteredClients.registeredClient2() + .clientId(this.registration.getClientId()).build(); + List registrations = Arrays.asList(this.registration, + anotherRegistrationWithSameClientId); + new InMemoryRegisteredClientRepository(registrations); + }) + .withMessageStartingWith("Registered client must be unique. Found duplicate client identifier:"); + } + + @Test + public void constructorListRegisteredClientWhenDuplicateClientSecretThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> { + RegisteredClient anotherRegistrationWithSameClientSecret = TestRegisteredClients.registeredClient2() + .clientSecret(this.registration.getClientSecret()).build(); + List registrations = Arrays.asList(this.registration, + anotherRegistrationWithSameClientSecret); + new InMemoryRegisteredClientRepository(registrations); + }) + .withMessageStartingWith("Registered client must be unique. Found duplicate client secret for identifier:"); } @Test @@ -95,7 +121,9 @@ public class InMemoryRegisteredClientRepositoryTests { @Test public void findByIdWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.clients.findById(null)).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.clients.findById(null)) + .withMessageContaining("id cannot be empty"); } @Test @@ -112,79 +140,76 @@ public class InMemoryRegisteredClientRepositoryTests { @Test public void findByClientIdWhenNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.clients.findByClientId(null)).isInstanceOf(IllegalArgumentException.class); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.clients.findByClientId(null)) + .withMessageContaining("clientId cannot be empty"); } @Test - public void saveNullRegisteredClientThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> this.clients.saveClient(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("registeredClient cannot be null"); + public void saveWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.clients.save(null)) + .withMessageContaining("registeredClient cannot be null"); } @Test - public void saveRegisteredClientThenReturnsSavedRegisteredClientWhenSearchedById() { - RegisteredClient registeredClient = RegisteredClient.withId("new-client") - .clientId("new-client") - .clientSecret("secret") + public void saveWhenExistingIdThenThrowIllegalArgumentException() { + RegisteredClient registeredClient = createRegisteredClient( + this.registration.getId(), "client-id-2", "client-secret-2"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.clients.save(registeredClient)) + .withMessage("Registered client must be unique. Found duplicate identifier: " + registeredClient.getId()); + } + + @Test + public void saveWhenExistingClientIdThenThrowIllegalArgumentException() { + RegisteredClient registeredClient = createRegisteredClient( + "client-2", this.registration.getClientId(), "client-secret-2"); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.clients.save(registeredClient)) + .withMessage("Registered client must be unique. Found duplicate client identifier: " + registeredClient.getClientId()); + } + + @Test + public void saveWhenExistingClientSecretThenThrowIllegalArgumentException() { + RegisteredClient registeredClient = createRegisteredClient( + "client-2", "client-id-2", this.registration.getClientSecret()); + assertThatIllegalArgumentException() + .isThrownBy(() -> this.clients.save(registeredClient)) + .withMessage("Registered client must be unique. Found duplicate client secret for identifier: " + registeredClient.getId()); + } + + @Test + public void saveWhenSavedAndFindByIdThenFound() { + RegisteredClient registeredClient = createRegisteredClient(); + this.clients.save(registeredClient); + RegisteredClient savedClient = this.clients.findById(registeredClient.getId()); + assertThat(savedClient).isEqualTo(registeredClient); + } + + @Test + public void saveWhenSavedAndFindByClientIdThenFound() { + RegisteredClient registeredClient = createRegisteredClient(); + this.clients.save(registeredClient); + RegisteredClient savedClient = this.clients.findByClientId(registeredClient.getClientId()); + assertThat(savedClient).isEqualTo(registeredClient); + } + + private static RegisteredClient createRegisteredClient() { + return createRegisteredClient("client-2", "client-id-2", "client-secret-2"); + } + + private static RegisteredClient createRegisteredClient(String id, String clientId, String clientSecret) { + // @formatter:off + return RegisteredClient.withId(id) + .clientId(clientId) + .clientSecret(clientSecret) .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .redirectUri("https://newclient.com") - .scope("scope1").build(); - - this.clients.saveClient(registeredClient); - - RegisteredClient savedClient = this.clients.findById("new-client"); - - assertThat(savedClient).isNotNull().isEqualTo(registeredClient); + .redirectUri("https://client.example.com") + .scope("scope1") + .build(); + // @formatter:on } - @Test - public void saveRegisteredClientThenReturnsSavedRegisteredClientWhenSearchedByClientId() { - RegisteredClient registeredClient = RegisteredClient.withId("id1") - .clientId("new-client-id") - .clientSecret("secret") - .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .redirectUri("https://newclient.com") - .scope("scope1").build(); - - this.clients.saveClient(registeredClient); - - RegisteredClient savedClient = this.clients.findByClientId("new-client-id"); - - assertThat(savedClient).isNotNull().isEqualTo(registeredClient); - } - - @Test - public void saveRegisteredClientWithExistingIdThrowIllegalArgumentException() { - assertThatThrownBy(() -> { - RegisteredClient registeredClient = RegisteredClient.withId("registration-1") - .clientId("new-client") - .clientSecret("secret") - .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .redirectUri("https://newclient.com") - .scope("scope1").build(); - - this.clients.saveClient(registeredClient); - }).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Registered client must be unique. Found duplicate identifier"); - } - - @Test - public void saveRegisteredClientWithExistingClientIdThrowIllegalArgumentException() { - assertThatThrownBy(() -> { - RegisteredClient registeredClient = RegisteredClient.withId("new-client") - .clientId("client-1") - .clientSecret("secret") - .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS) - .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) - .redirectUri("https://newclient.com") - .scope("scope1").build(); - - this.clients.saveClient(registeredClient); - }).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("Registered client must be unique. Found duplicate client identifier"); - } } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientTests.java index d9a66a12..d913b085 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientTests.java @@ -15,6 +15,8 @@ */ package org.springframework.security.oauth2.server.authorization.client; +import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.Collections; import java.util.Set; import java.util.stream.Collectors; @@ -58,9 +60,14 @@ public class RegisteredClientTests { @Test public void buildWhenAllAttributesProvidedThenAllAttributesAreSet() { + Instant clientIdIssuedAt = Instant.now(); + Instant clientSecretExpiresAt = clientIdIssuedAt.plus(30, ChronoUnit.DAYS); RegisteredClient registration = RegisteredClient.withId(ID) .clientId(CLIENT_ID) + .clientIdIssuedAt(clientIdIssuedAt) .clientSecret(CLIENT_SECRET) + .clientSecretExpiresAt(clientSecretExpiresAt) + .clientName("client-name") .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .redirectUris(redirectUris -> redirectUris.addAll(REDIRECT_URIS)) @@ -69,7 +76,10 @@ public class RegisteredClientTests { assertThat(registration.getId()).isEqualTo(ID); assertThat(registration.getClientId()).isEqualTo(CLIENT_ID); + assertThat(registration.getClientIdIssuedAt()).isEqualTo(clientIdIssuedAt); assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET); + assertThat(registration.getClientSecretExpiresAt()).isEqualTo(clientSecretExpiresAt); + assertThat(registration.getClientName()).isEqualTo("client-name"); assertThat(registration.getAuthorizationGrantTypes()) .isEqualTo(Collections.singleton(AuthorizationGrantType.AUTHORIZATION_CODE)); assertThat(registration.getClientAuthenticationMethods()).isEqualTo(CLIENT_AUTHENTICATION_METHODS); @@ -325,7 +335,10 @@ public class RegisteredClientTests { assertThat(registration.getId()).isEqualTo(updated.getId()); assertThat(registration.getClientId()).isEqualTo(updated.getClientId()); + assertThat(registration.getClientIdIssuedAt()).isEqualTo(updated.getClientIdIssuedAt()); assertThat(registration.getClientSecret()).isEqualTo(updated.getClientSecret()); + assertThat(registration.getClientSecretExpiresAt()).isEqualTo(updated.getClientSecretExpiresAt()); + assertThat(registration.getClientName()).isEqualTo(updated.getClientName()); assertThat(registration.getClientAuthenticationMethods()).isEqualTo(updated.getClientAuthenticationMethods()); assertThat(registration.getClientAuthenticationMethods()).isNotSameAs(updated.getClientAuthenticationMethods()); assertThat(registration.getAuthorizationGrantTypes()).isEqualTo(updated.getAuthorizationGrantTypes()); @@ -343,10 +356,12 @@ public class RegisteredClientTests { @Test public void buildWhenRegisteredClientValuesOverriddenThenPropagated() { RegisteredClient registration = TestRegisteredClients.registeredClient().build(); + String newName = "client-name"; String newSecret = "new-secret"; String newScope = "new-scope"; String newRedirectUri = "https://another-redirect-uri.com"; RegisteredClient updated = RegisteredClient.from(registration) + .clientName(newName) .clientSecret(newSecret) .scopes(scopes -> { scopes.clear(); @@ -358,6 +373,8 @@ public class RegisteredClientTests { }) .build(); + assertThat(registration.getClientName()).isNotEqualTo(newName); + assertThat(updated.getClientName()).isEqualTo(newName); assertThat(registration.getClientSecret()).isNotEqualTo(newSecret); assertThat(updated.getClientSecret()).isEqualTo(newSecret); assertThat(registration.getScopes()).doesNotContain(newScope); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettingsTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettingsTests.java index 56b81c5e..e67e793b 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettingsTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/ProviderSettingsTests.java @@ -38,7 +38,6 @@ public class ProviderSettingsTests { assertThat(providerSettings.tokenRevocationEndpoint()).isEqualTo("/oauth2/revoke"); assertThat(providerSettings.tokenIntrospectionEndpoint()).isEqualTo("/oauth2/introspect"); assertThat(providerSettings.oidcClientRegistrationEndpoint()).isEqualTo("/connect/register"); - assertThat(providerSettings.isOidClientRegistrationEndpointEnabled()).isFalse(); } @Test @@ -48,8 +47,8 @@ public class ProviderSettingsTests { String jwkSetEndpoint = "/oauth2/v1/jwks"; String tokenRevocationEndpoint = "/oauth2/v1/revoke"; String tokenIntrospectionEndpoint = "/oauth2/v1/introspect"; - String issuer = "https://example.com:9000"; String oidcClientRegistrationEndpoint = "/connect/v1/register"; + String issuer = "https://example.com:9000"; ProviderSettings providerSettings = new ProviderSettings() .issuer(issuer) @@ -59,7 +58,6 @@ public class ProviderSettingsTests { .tokenRevocationEndpoint(tokenRevocationEndpoint) .tokenIntrospectionEndpoint(tokenIntrospectionEndpoint) .tokenRevocationEndpoint(tokenRevocationEndpoint) - .isOidClientRegistrationEndpointEnabled(true) .oidcClientRegistrationEndpoint(oidcClientRegistrationEndpoint); assertThat(providerSettings.issuer()).isEqualTo(issuer); @@ -69,7 +67,6 @@ public class ProviderSettingsTests { assertThat(providerSettings.tokenRevocationEndpoint()).isEqualTo(tokenRevocationEndpoint); assertThat(providerSettings.tokenIntrospectionEndpoint()).isEqualTo(tokenIntrospectionEndpoint); assertThat(providerSettings.oidcClientRegistrationEndpoint()).isEqualTo(oidcClientRegistrationEndpoint); - assertThat(providerSettings.isOidClientRegistrationEndpointEnabled()).isTrue(); } @Test @@ -78,7 +75,7 @@ public class ProviderSettingsTests { .setting("name1", "value1") .settings(settings -> settings.put("name2", "value2")); - assertThat(providerSettings.settings()).hasSize(9); + assertThat(providerSettings.settings()).hasSize(8); assertThat(providerSettings.setting("name1")).isEqualTo("value1"); assertThat(providerSettings.setting("name2")).isEqualTo("value2"); } @@ -126,12 +123,11 @@ public class ProviderSettingsTests { @Test public void oidcClientRegistrationEndpointWhenNullThenThrowIllegalArgumentException() { ProviderSettings settings = new ProviderSettings(); - assertThatThrownBy(() -> settings.oidcClientRegistrationEndpoint(null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("value cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> settings.oidcClientRegistrationEndpoint(null)) + .withMessage("value cannot be null"); } - @Test public void jwksEndpointWhenNullThenThrowIllegalArgumentException() { ProviderSettings settings = new ProviderSettings(); @@ -139,4 +135,5 @@ public class ProviderSettingsTests { .isThrownBy(() -> settings.jwkSetEndpoint(null)) .withMessage("value cannot be null"); } + } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/TokenSettingsTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/TokenSettingsTests.java index 3f6e2331..46ab3fa7 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/TokenSettingsTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/TokenSettingsTests.java @@ -15,9 +15,11 @@ */ package org.springframework.security.oauth2.server.authorization.config; +import java.time.Duration; + import org.junit.Test; -import java.time.Duration; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -32,10 +34,11 @@ public class TokenSettingsTests { @Test public void constructorWhenDefaultThenDefaultsAreSet() { TokenSettings tokenSettings = new TokenSettings(); - assertThat(tokenSettings.settings()).hasSize(3); + assertThat(tokenSettings.settings()).hasSize(4); assertThat(tokenSettings.accessTokenTimeToLive()).isEqualTo(Duration.ofMinutes(5)); assertThat(tokenSettings.reuseRefreshTokens()).isTrue(); assertThat(tokenSettings.refreshTokenTimeToLive()).isEqualTo(Duration.ofMinutes(60)); + assertThat(tokenSettings.idTokenSignatureAlgorithm()).isEqualTo(SignatureAlgorithm.RS256); } @Test @@ -101,17 +104,25 @@ public class TokenSettingsTests { .isEqualTo("refreshTokenTimeToLive must be greater than Duration.ZERO"); } + @Test + public void idTokenSignatureAlgorithmWhenProvidedThenSet() { + SignatureAlgorithm idTokenSignatureAlgorithm = SignatureAlgorithm.RS512; + TokenSettings tokenSettings = new TokenSettings().idTokenSignatureAlgorithm(idTokenSignatureAlgorithm); + assertThat(tokenSettings.idTokenSignatureAlgorithm()).isEqualTo(idTokenSignatureAlgorithm); + } + @Test public void settingWhenCalledThenReturnTokenSettings() { Duration accessTokenTimeToLive = Duration.ofMinutes(10); TokenSettings tokenSettings = new TokenSettings() .setting("name1", "value1") .accessTokenTimeToLive(accessTokenTimeToLive) - .settings(settings -> settings.put("name2", "value2")); - assertThat(tokenSettings.settings()).hasSize(5); + .settings(settings -> settings.put("name2", "value2")); + assertThat(tokenSettings.settings()).hasSize(6); assertThat(tokenSettings.accessTokenTimeToLive()).isEqualTo(accessTokenTimeToLive); assertThat(tokenSettings.reuseRefreshTokens()).isTrue(); assertThat(tokenSettings.refreshTokenTimeToLive()).isEqualTo(Duration.ofMinutes(60)); + assertThat(tokenSettings.idTokenSignatureAlgorithm()).isEqualTo(SignatureAlgorithm.RS256); assertThat(tokenSettings.setting("name1")).isEqualTo("value1"); assertThat(tokenSettings.setting("name2")).isEqualTo("value2"); } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java new file mode 100644 index 00000000..70d2ca59 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java @@ -0,0 +1,314 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.OAuth2TokenType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.OidcClientRegistration; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JoseHeader; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.TestJoseHeaders; +import org.springframework.security.oauth2.jwt.TestJwtClaimsSets; +import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; +import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link OidcClientRegistrationAuthenticationProvider}. + * + * @author Ovidiu Popa + * @author Joe Grandja + */ +public class OidcClientRegistrationAuthenticationProviderTests { + private RegisteredClientRepository registeredClientRepository; + private OAuth2AuthorizationService authorizationService; + private OidcClientRegistrationAuthenticationProvider authenticationProvider; + + @Before + public void setUp() { + this.registeredClientRepository = mock(RegisteredClientRepository.class); + this.authorizationService = mock(OAuth2AuthorizationService.class); + this.authenticationProvider = new OidcClientRegistrationAuthenticationProvider( + this.registeredClientRepository, this.authorizationService); + } + + @Test + public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcClientRegistrationAuthenticationProvider(null, this.authorizationService)) + .withMessage("registeredClientRepository cannot be null"); + } + + @Test + public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcClientRegistrationAuthenticationProvider(this.registeredClientRepository, null)) + .withMessage("authorizationService cannot be null"); + } + + @Test + public void supportsWhenTypeOidcClientRegistrationAuthenticationTokenThenReturnTrue() { + assertThat(this.authenticationProvider.supports(OidcClientRegistrationAuthenticationToken.class)).isTrue(); + } + + @Test + public void authenticateWhenPrincipalNotOAuth2TokenAuthenticationTokenThenThrowOAuth2AuthenticationException() { + TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials"); + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .redirectUri("https://client.example.com") + .build(); + + OidcClientRegistrationAuthenticationToken authentication = new OidcClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()).extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + } + + @Test + public void authenticateWhenPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() { + JwtAuthenticationToken principal = new JwtAuthenticationToken(createJwt()); + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .redirectUri("https://client.example.com") + .build(); + + OidcClientRegistrationAuthenticationToken authentication = new OidcClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()).extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + } + + @Test + public void authenticateWhenAccessTokenNotFoundThenThrowOAuth2AuthenticationException() { + Jwt jwt = createJwt(); + JwtAuthenticationToken principal = new JwtAuthenticationToken( + jwt, AuthorityUtils.createAuthorityList("SCOPE_client.create")); + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .redirectUri("https://client.example.com") + .build(); + + OidcClientRegistrationAuthenticationToken authentication = new OidcClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()).extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + verify(this.authorizationService).findByToken( + eq(jwt.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + @Test + public void authenticateWhenAccessTokenNotActiveThenThrowOAuth2AuthenticationException() { + Jwt jwt = createJwt(); + OAuth2AccessToken jwtAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), jwt.getIssuedAt(), + jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization( + registeredClient, jwtAccessToken, jwt.getClaims()).build(); + authorization = OidcAuthenticationProviderUtils.invalidate(authorization, jwtAccessToken); + when(this.authorizationService.findByToken( + eq(jwtAccessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN))) + .thenReturn(authorization); + + JwtAuthenticationToken principal = new JwtAuthenticationToken( + jwt, AuthorityUtils.createAuthorityList("SCOPE_client.create")); + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .redirectUri("https://client.example.com") + .build(); + + OidcClientRegistrationAuthenticationToken authentication = new OidcClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()).extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + verify(this.authorizationService).findByToken( + eq(jwtAccessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + @Test + public void authenticateWhenAccessTokenNotAuthorizedThenThrowOAuth2AuthenticationException() { + Jwt jwt = createJwt(Collections.singleton("unauthorized.scope")); + OAuth2AccessToken jwtAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), jwt.getIssuedAt(), + jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization( + registeredClient, jwtAccessToken, jwt.getClaims()).build(); + when(this.authorizationService.findByToken( + eq(jwtAccessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN))) + .thenReturn(authorization); + + JwtAuthenticationToken principal = new JwtAuthenticationToken( + jwt, AuthorityUtils.createAuthorityList("SCOPE_unauthorized.scope")); + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .redirectUri("https://client.example.com") + .build(); + + OidcClientRegistrationAuthenticationToken authentication = new OidcClientRegistrationAuthenticationToken( + principal, clientRegistration); + + assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication)) + .isInstanceOf(OAuth2AuthenticationException.class) + .extracting(ex -> ((OAuth2AuthenticationException) ex).getError()).extracting("errorCode") + .isEqualTo(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); + verify(this.authorizationService).findByToken( + eq(jwtAccessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN)); + } + + @Test + public void authenticateWhenValidAccessTokenThenReturnClientRegistration() { + Jwt jwt = createJwt(); + OAuth2AccessToken jwtAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), jwt.getIssuedAt(), + jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE)); + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization( + registeredClient, jwtAccessToken, jwt.getClaims()).build(); + when(this.authorizationService.findByToken( + eq(jwtAccessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN))) + .thenReturn(authorization); + + JwtAuthenticationToken principal = new JwtAuthenticationToken( + jwt, AuthorityUtils.createAuthorityList("SCOPE_client.create")); + // @formatter:off + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + + OidcClientRegistrationAuthenticationToken authentication = new OidcClientRegistrationAuthenticationToken( + principal, clientRegistration); + OidcClientRegistrationAuthenticationToken authenticationResult = + (OidcClientRegistrationAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + ArgumentCaptor registeredClientCaptor = ArgumentCaptor.forClass(RegisteredClient.class); + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + + verify(this.authorizationService).findByToken( + eq(jwtAccessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN)); + verify(this.registeredClientRepository).save(registeredClientCaptor.capture()); + verify(this.authorizationService).save(authorizationCaptor.capture()); + + OAuth2Authorization authorizationResult = authorizationCaptor.getValue(); + assertThat(authorizationResult.getAccessToken().isInvalidated()).isTrue(); + if (authorizationResult.getRefreshToken() != null) { + assertThat(authorizationResult.getRefreshToken().isInvalidated()).isTrue(); + } + + RegisteredClient registeredClientResult = registeredClientCaptor.getValue(); + assertThat(registeredClientResult.getId()).isNotNull(); + assertThat(registeredClientResult.getClientId()).isNotNull(); + assertThat(registeredClientResult.getClientIdIssuedAt()).isNotNull(); + assertThat(registeredClientResult.getClientSecret()).isNotNull(); + assertThat(registeredClientResult.getClientName()).isEqualTo(clientRegistration.getClientName()); + assertThat(registeredClientResult.getClientAuthenticationMethods()).containsExactly(ClientAuthenticationMethod.BASIC); + assertThat(registeredClientResult.getRedirectUris()).containsExactly("https://client.example.com"); + assertThat(registeredClientResult.getAuthorizationGrantTypes()) + .containsExactlyInAnyOrder(AuthorizationGrantType.AUTHORIZATION_CODE, AuthorizationGrantType.CLIENT_CREDENTIALS); + assertThat(registeredClientResult.getScopes()).containsExactlyInAnyOrder("scope1", "scope2"); + assertThat(registeredClientResult.getClientSettings().requireProofKey()).isTrue(); + assertThat(registeredClientResult.getClientSettings().requireUserConsent()).isTrue(); + assertThat(registeredClientResult.getTokenSettings().idTokenSignatureAlgorithm()).isEqualTo(SignatureAlgorithm.RS256); + + OidcClientRegistration clientRegistrationResult = authenticationResult.getClientRegistration(); + assertThat(clientRegistrationResult.getClientId()).isEqualTo(registeredClientResult.getClientId()); + assertThat(clientRegistrationResult.getClientIdIssuedAt()).isEqualTo(registeredClientResult.getClientIdIssuedAt()); + assertThat(clientRegistrationResult.getClientSecret()).isEqualTo(registeredClientResult.getClientSecret()); + assertThat(clientRegistrationResult.getClientSecretExpiresAt()).isEqualTo(registeredClientResult.getClientSecretExpiresAt()); + assertThat(clientRegistrationResult.getClientName()).isEqualTo(registeredClientResult.getClientName()); + assertThat(clientRegistrationResult.getRedirectUris()) + .containsExactlyInAnyOrderElementsOf(registeredClientResult.getRedirectUris()); + + List grantTypes = new ArrayList<>(); + registeredClientResult.getAuthorizationGrantTypes().forEach(authorizationGrantType -> + grantTypes.add(authorizationGrantType.getValue())); + assertThat(clientRegistrationResult.getGrantTypes()).containsExactlyInAnyOrderElementsOf(grantTypes); + + assertThat(clientRegistrationResult.getResponseTypes()) + .containsExactly(OAuth2AuthorizationResponseType.CODE.getValue()); + assertThat(clientRegistrationResult.getScopes()) + .containsExactlyInAnyOrderElementsOf(registeredClientResult.getScopes()); + assertThat(clientRegistrationResult.getTokenEndpointAuthenticationMethod()) + .isEqualTo(registeredClientResult.getClientAuthenticationMethods().iterator().next().getValue()); + assertThat(clientRegistrationResult.getIdTokenSignedResponseAlgorithm()) + .isEqualTo(registeredClientResult.getTokenSettings().idTokenSignatureAlgorithm().getName()); + } + + private static Jwt createJwt() { + return createJwt(Collections.singleton("client.create")); + } + + private static Jwt createJwt(Set scopes) { + // @formatter:off + JoseHeader joseHeader = TestJoseHeaders.joseHeader() + .build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet() + .claim(OAuth2ParameterNames.SCOPE, scopes) + .build(); + Jwt jwt = Jwt.withTokenValue("jwt-access-token") + .headers(headers -> headers.putAll(joseHeader.getHeaders())) + .claims(claims -> claims.putAll(jwtClaimsSet.getClaims())) + .build(); + // @formatter:on + return jwt; + } + +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationTokenTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationTokenTests.java new file mode 100644 index 00000000..9c647a96 --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationTokenTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization.oidc.authentication; + +import org.junit.Test; + +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.oauth2.core.oidc.OidcClientRegistration; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +/** + * Tests for {@link OidcClientRegistrationAuthenticationToken}. + * + * @author Joe Grandja + */ +public class OidcClientRegistrationAuthenticationTokenTests { + private TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "credentials"); + private OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .redirectUri("https://client.example.com").build(); + + @Test + public void constructorWhenPrincipalNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcClientRegistrationAuthenticationToken(null, this.clientRegistration)) + .withMessage("principal cannot be null"); + } + + @Test + public void constructorWhenClientRegistrationNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcClientRegistrationAuthenticationToken(this.principal, null)) + .withMessage("clientRegistration cannot be null"); + } + + @Test + public void constructorWhenAllValuesProvidedThenCreated() { + OidcClientRegistrationAuthenticationToken authentication = new OidcClientRegistrationAuthenticationToken( + this.principal, this.clientRegistration); + + assertThat(authentication.getPrincipal()).isEqualTo(this.principal); + assertThat(authentication.getCredentials().toString()).isEmpty(); + assertThat(authentication.getClientRegistration()).isEqualTo(this.clientRegistration); + assertThat(authentication.isAuthenticated()).isEqualTo(this.principal.isAuthenticated()); + } + +} diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java index d8ae1d0e..1a8ff5f9 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java @@ -15,267 +15,259 @@ */ package org.springframework.security.oauth2.server.authorization.oidc.web; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import org.junit.After; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; -import org.mockito.AdditionalAnswers; -import org.mockito.ArgumentCaptor; -import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; -import org.springframework.http.converter.HttpMessageConverter; -import org.springframework.mock.http.client.MockClientHttpResponse; -import org.springframework.mock.web.MockHttpServletRequest; -import org.springframework.mock.web.MockHttpServletResponse; -import org.springframework.security.authentication.AuthenticationManager; -import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.AuthorityUtils; -import org.springframework.security.core.context.SecurityContext; -import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.oauth2.core.AuthorizationGrantType; -import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import org.springframework.security.oauth2.core.OAuth2Error; -import org.springframework.security.oauth2.core.OAuth2ErrorCodes; -import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; -import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; -import org.springframework.security.oauth2.core.oidc.OidcClientMetadataClaimNames; -import org.springframework.security.oauth2.core.oidc.OidcClientRegistration; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; -import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; -import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; +import java.time.Instant; +import java.util.Collections; import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.springframework.http.HttpStatus; +import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.mock.http.client.MockClientHttpRequest; +import org.springframework.mock.http.client.MockClientHttpResponse; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.ClientAuthenticationMethod; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.core.OAuth2ErrorCodes; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; +import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; +import org.springframework.security.oauth2.core.oidc.OidcClientRegistration; +import org.springframework.security.oauth2.core.oidc.http.converter.OidcClientRegistrationHttpMessageConverter; +import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; +import org.springframework.security.oauth2.jwt.JoseHeader; +import org.springframework.security.oauth2.jwt.Jwt; +import org.springframework.security.oauth2.jwt.JwtClaimsSet; +import org.springframework.security.oauth2.jwt.TestJoseHeaders; +import org.springframework.security.oauth2.jwt.TestJwtClaimsSets; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken; +import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; /** - * Tests for {@link OidcClientRegistrationEndpointFilter} + * Tests for {@link OidcClientRegistrationEndpointFilter}. * * @author Ovidiu Popa - * @since 0.1.1 + * @author Joe Grandja */ public class OidcClientRegistrationEndpointFilterTests { - - private static final OidcClientRegistration.Builder OIDC_CLIENT_REGISTRATION = OidcClientRegistration.builder() - .redirectUri("https://localhost:8080/client") - .responseType("code") - .grantType("authorization_code") - .tokenEndpointAuthenticationMethod("basic") - .scope("test"); + private AuthenticationManager authenticationManager; + private OidcClientRegistrationEndpointFilter filter; + private final HttpMessageConverter clientRegistrationHttpMessageConverter = + new OidcClientRegistrationHttpMessageConverter(); private final HttpMessageConverter errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter(); - private static RegisteredClientRepository registeredClientRepository; - private static AuthenticationManager authenticationManager; - - @BeforeClass - public static void init() { - registeredClientRepository = mock(RegisteredClientRepository.class); - authenticationManager = mock(AuthenticationManager.class); - } @Before public void setup() { - reset(registeredClientRepository); - reset(authenticationManager); + this.authenticationManager = mock(AuthenticationManager.class); + this.filter = new OidcClientRegistrationEndpointFilter(this.authenticationManager); } @After - public void tearDown() { + public void cleanup() { SecurityContextHolder.clearContext(); } - @Test - public void constructorWhenRegisteredClientRepositoryNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OidcClientRegistrationEndpointFilter(null, - authenticationManager)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("registeredClientRepository cannot be null"); - } - @Test public void constructorWhenAuthenticationManagerNullThenThrowIllegalArgumentException() { - - assertThatThrownBy(() -> new OidcClientRegistrationEndpointFilter(registeredClientRepository, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("authenticationManager cannot be null"); + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcClientRegistrationEndpointFilter(null)) + .withMessage("authenticationManager cannot be null"); } @Test - public void constructorWhenOidcClientRegistrationUriNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OidcClientRegistrationEndpointFilter(registeredClientRepository, authenticationManager, null)) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("oidcClientRegistrationUri cannot be empty"); - } - - @Test - public void constructorWhenOidcClientRegistrationUriEmptyThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OidcClientRegistrationEndpointFilter(registeredClientRepository, authenticationManager, "")) - .isInstanceOf(IllegalArgumentException.class) - .hasMessage("oidcClientRegistrationUri cannot be empty"); + public void constructorWhenClientRegistrationEndpointUriNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> new OidcClientRegistrationEndpointFilter(this.authenticationManager, null)) + .withMessage("clientRegistrationEndpointUri cannot be empty"); } @Test public void doFilterWhenNotClientRegistrationRequestThenNotProcessed() throws Exception { - OidcClientRegistrationEndpointFilter filter = - new OidcClientRegistrationEndpointFilter(registeredClientRepository, authenticationManager); - String requestUri = "/path"; MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(request, response, filterChain); verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test public void doFilterWhenClientRegistrationRequestGetThenNotProcessed() throws Exception { - - OidcClientRegistrationEndpointFilter filter = - new OidcClientRegistrationEndpointFilter(registeredClientRepository, authenticationManager); - String requestUri = OidcClientRegistrationEndpointFilter.DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(request, response, filterChain); verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } @Test - public void doFilterWhenAuthenticationManagerThrowsOAuth2AuthenticationExceptionThenBadRequest() throws Exception { - - setSecurityContext("client-registration-token", true, "SCOPE_client.create"); - - when(authenticationManager.authenticate(any(JwtAuthenticationToken.class))) - .thenThrow(new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT))); - - OidcClientRegistrationEndpointFilter filter = - new OidcClientRegistrationEndpointFilter(registeredClientRepository, authenticationManager); - + public void doFilterWhenClientRegistrationRequestInvalidThenInvalidRequestError() throws Exception { String requestUri = OidcClientRegistrationEndpointFilter.DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI; MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); request.setServletPath(requestUri); - - request.setContent(convertToByteArray(OIDC_CLIENT_REGISTRATION.build())); - + request.setContent("invalid content".getBytes()); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(request, response, filterChain); verifyNoInteractions(filterChain); assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value()); OAuth2Error error = readError(response); - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_GRANT); + assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST); + assertThat(error.getDescription()).startsWith("OpenID Client Registration Error: "); } @Test - @SuppressWarnings("unchecked") - public void doFilterWhenClientRegistrationRequestThenClientRegistrationResponse() throws Exception { + public void doFilterWhenClientRegistrationRequestInvalidTokenThenUnauthorizedError() throws Exception { + doFilterWhenClientRegistrationRequestInvalidThenError( + OAuth2ErrorCodes.INVALID_TOKEN, HttpStatus.UNAUTHORIZED); + } - doNothing().when(registeredClientRepository).saveClient(any(RegisteredClient.class)); - when(authenticationManager.authenticate(any(JwtAuthenticationToken.class))).then(AdditionalAnswers.returnsFirstArg()); - setSecurityContext("client-registration-token", true, "SCOPE_client.create"); + @Test + public void doFilterWhenClientRegistrationRequestInsufficientTokenScopeThenForbiddenError() throws Exception { + doFilterWhenClientRegistrationRequestInvalidThenError( + OAuth2ErrorCodes.INSUFFICIENT_SCOPE, HttpStatus.FORBIDDEN); + } - OidcClientRegistrationEndpointFilter filter = - new OidcClientRegistrationEndpointFilter(registeredClientRepository, authenticationManager); + private void doFilterWhenClientRegistrationRequestInvalidThenError( + String errorCode, HttpStatus status) throws Exception { + Jwt jwt = createJwt(); + JwtAuthenticationToken principal = new JwtAuthenticationToken( + jwt, AuthorityUtils.createAuthorityList("SCOPE_client.create")); + + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(principal); + SecurityContextHolder.setContext(securityContext); + + when(this.authenticationManager.authenticate(any())) + .thenThrow(new OAuth2AuthenticationException(new OAuth2Error(errorCode))); + + // @formatter:off + OidcClientRegistration clientRegistrationRequest = OidcClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on String requestUri = OidcClientRegistrationEndpointFilter.DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI; MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); request.setServletPath(requestUri); + writeClientRegistrationRequest(request, clientRegistrationRequest); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); - request.setContent(convertToByteArray(OIDC_CLIENT_REGISTRATION.build())); + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + + assertThat(response.getStatus()).isEqualTo(status.value()); + OAuth2Error error = readError(response); + assertThat(error.getErrorCode()).isEqualTo(errorCode); + } + + @Test + public void doFilterWhenClientRegistrationRequestValidThenSuccessResponse() throws Exception { + // @formatter:off + OidcClientRegistration.Builder clientRegistrationBuilder = OidcClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2"); + + OidcClientRegistration clientRegistrationRequest = clientRegistrationBuilder.build(); + + OidcClientRegistration expectedClientRegistrationResponse = clientRegistrationBuilder + .clientId("client-id") + .clientIdIssuedAt(Instant.now()) + .clientSecret("client-secret") + .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.BASIC.getValue()) + .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) + .idTokenSignedResponseAlgorithm(SignatureAlgorithm.RS256.getName()) + .build(); + // @formatter:on + + Jwt jwt = createJwt(); + JwtAuthenticationToken principal = new JwtAuthenticationToken( + jwt, AuthorityUtils.createAuthorityList("SCOPE_client.create")); + + OidcClientRegistrationAuthenticationToken clientRegistrationAuthenticationResult = + new OidcClientRegistrationAuthenticationToken(principal, expectedClientRegistrationResponse); + + when(this.authenticationManager.authenticate(any())).thenReturn(clientRegistrationAuthenticationResult); + + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(principal); + SecurityContextHolder.setContext(securityContext); + + String requestUri = OidcClientRegistrationEndpointFilter.DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); + request.setServletPath(requestUri); + writeClientRegistrationRequest(request, clientRegistrationRequest); MockHttpServletResponse response = new MockHttpServletResponse(); FilterChain filterChain = mock(FilterChain.class); - filter.doFilter(request, response, filterChain); + this.filter.doFilter(request, response, filterChain); verifyNoInteractions(filterChain); - verify(authenticationManager).authenticate(any()); - - ArgumentCaptor registeredClientCaptor = ArgumentCaptor.forClass(RegisteredClient.class); - verify(registeredClientRepository).saveClient(registeredClientCaptor.capture()); - - RegisteredClient registeredClient = registeredClientCaptor.getValue(); - assertThat(response.getStatus()).isEqualTo(HttpStatus.CREATED.value()); - assertThat(response.getContentType()).isEqualTo(MediaType.APPLICATION_JSON_VALUE); - - ObjectMapper objectMapper = new ObjectMapper(); - Map clientRegistrationResponse = objectMapper.readerFor(Map.class) - .readValue(response.getContentAsString()); - - assertThat(clientRegistrationResponse.get(OidcClientMetadataClaimNames.CLIENT_ID)) - .isEqualTo(registeredClient.getClientId()); - assertThat((String) clientRegistrationResponse.get(OidcClientMetadataClaimNames.CLIENT_SECRET)) - .isEqualTo(registeredClient.getClientSecret()); - assertThat((List) clientRegistrationResponse.get(OidcClientMetadataClaimNames.REDIRECT_URIS)) - .containsAll(registeredClient.getRedirectUris()); - assertThat(clientRegistrationResponse.get(OidcClientMetadataClaimNames.CLIENT_ID_ISSUED_AT)) - .isNotNull(); - assertThat(clientRegistrationResponse.get(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT)) - .isEqualTo(0.0); - assertThat((List) clientRegistrationResponse.get(OidcClientMetadataClaimNames.RESPONSE_TYPES)) - .contains(OAuth2AuthorizationResponseType.CODE.getValue()); - assertThat((List) clientRegistrationResponse.get(OidcClientMetadataClaimNames.GRANT_TYPES)) - .containsAll(grantTypes(registeredClient)); - - assertThat(clientRegistrationResponse.get(OidcClientMetadataClaimNames.SCOPE)) - .isEqualTo(String.join(" ", registeredClient.getScopes())); - assertThat(clientRegistrationResponse.get(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD)) - .isEqualTo(registeredClient.getClientAuthenticationMethods().iterator().next().getValue()); - } - - private List grantTypes(RegisteredClient registeredClient) { - return registeredClient.getAuthorizationGrantTypes().stream() - .map(AuthorizationGrantType::getValue) - .collect(Collectors.toList()); - } - - private static void setSecurityContext(String tokenValue, boolean authenticated, String... authorities) { - Jwt jwt = Jwt.withTokenValue(tokenValue) - .header("alg", "none") - .claim("sub", "client") - .build(); - List grantedAuthorities = AuthorityUtils.createAuthorityList(authorities); - JwtAuthenticationToken jwtAuthenticationToken = new JwtAuthenticationToken(jwt, grantedAuthorities); - jwtAuthenticationToken.setAuthenticated(authenticated); - SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); - securityContext.setAuthentication(jwtAuthenticationToken); - SecurityContextHolder.setContext(securityContext); - } - - private static byte[] convertToByteArray(OidcClientRegistration clientRegistration) throws JsonProcessingException { - ObjectMapper objectMapper = new ObjectMapper(); - - return objectMapper - .writerFor(Map.class) - .writeValueAsBytes(clientRegistration.getClaims()); + OidcClientRegistration clientRegistrationResponse = readClientRegistrationResponse(response); + assertThat(clientRegistrationResponse.getClientId()).isEqualTo(expectedClientRegistrationResponse.getClientId()); + assertThat(clientRegistrationResponse.getClientIdIssuedAt()).isBetween( + expectedClientRegistrationResponse.getClientIdIssuedAt().minusSeconds(1), + expectedClientRegistrationResponse.getClientIdIssuedAt().plusSeconds(1)); + assertThat(clientRegistrationResponse.getClientSecret()).isEqualTo(expectedClientRegistrationResponse.getClientSecret()); + assertThat(clientRegistrationResponse.getClientSecretExpiresAt()).isEqualTo(expectedClientRegistrationResponse.getClientSecretExpiresAt()); + assertThat(clientRegistrationResponse.getClientName()).isEqualTo(expectedClientRegistrationResponse.getClientName()); + assertThat(clientRegistrationResponse.getRedirectUris()) + .containsExactlyInAnyOrderElementsOf(expectedClientRegistrationResponse.getRedirectUris()); + assertThat(clientRegistrationResponse.getGrantTypes()) + .containsExactlyInAnyOrderElementsOf(expectedClientRegistrationResponse.getGrantTypes()); + assertThat(clientRegistrationResponse.getResponseTypes()) + .containsExactlyInAnyOrderElementsOf(expectedClientRegistrationResponse.getResponseTypes()); + assertThat(clientRegistrationResponse.getScopes()) + .containsExactlyInAnyOrderElementsOf(expectedClientRegistrationResponse.getScopes()); + assertThat(clientRegistrationResponse.getTokenEndpointAuthenticationMethod()) + .isEqualTo(expectedClientRegistrationResponse.getTokenEndpointAuthenticationMethod()); + assertThat(clientRegistrationResponse.getIdTokenSignedResponseAlgorithm()) + .isEqualTo(expectedClientRegistrationResponse.getIdTokenSignedResponseAlgorithm()); } private OAuth2Error readError(MockHttpServletResponse response) throws Exception { @@ -283,4 +275,33 @@ public class OidcClientRegistrationEndpointFilterTests { response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus())); return this.errorHttpResponseConverter.read(OAuth2Error.class, httpResponse); } + + private void writeClientRegistrationRequest(MockHttpServletRequest request, + OidcClientRegistration clientRegistration) throws Exception { + MockClientHttpRequest httpRequest = new MockClientHttpRequest(); + this.clientRegistrationHttpMessageConverter.write(clientRegistration, null, httpRequest); + request.setContent(httpRequest.getBodyAsBytes()); + } + + private OidcClientRegistration readClientRegistrationResponse(MockHttpServletResponse response) throws Exception { + MockClientHttpResponse httpResponse = new MockClientHttpResponse( + response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus())); + return this.clientRegistrationHttpMessageConverter.read(OidcClientRegistration.class, httpResponse); + } + + private static Jwt createJwt() { + // @formatter:off + JoseHeader joseHeader = TestJoseHeaders.joseHeader() + .build(); + JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet() + .claim(OAuth2ParameterNames.SCOPE, Collections.singleton("client.create")) + .build(); + Jwt jwt = Jwt.withTokenValue("jwt-access-token") + .headers(headers -> headers.putAll(joseHeader.getHeaders())) + .claims(claims -> claims.putAll(jwtClaimsSet.getClaims())) + .build(); + // @formatter:on + return jwt; + } + } diff --git a/samples/boot/oauth2-integration/authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java b/samples/boot/oauth2-integration/authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java index e5d3ce82..52682412 100644 --- a/samples/boot/oauth2-integration/authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java +++ b/samples/boot/oauth2-integration/authorizationserver/src/main/java/sample/config/AuthorizationServerConfig.java @@ -21,11 +21,6 @@ import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.RSAKey; import com.nimbusds.jose.jwk.source.JWKSource; import com.nimbusds.jose.proc.SecurityContext; -import org.springframework.security.oauth2.core.OAuth2TokenValidator; -import org.springframework.security.oauth2.jwt.Jwt; -import org.springframework.security.oauth2.jwt.JwtDecoder; -import org.springframework.security.oauth2.jwt.JwtValidators; -import org.springframework.security.oauth2.jwt.NimbusJwtDecoder; import sample.jose.Jwks; import org.springframework.context.annotation.Bean; @@ -38,6 +33,7 @@ import org.springframework.security.config.annotation.web.configuration.OAuth2Au import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.server.authorization.client.InMemoryRegisteredClientRepository; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; @@ -86,16 +82,14 @@ public class AuthorizationServerConfig { return (jwkSelector, securityContext) -> jwkSelector.select(jwkSet); } + @Bean + public JwtDecoder jwtDecoder(JWKSource jwkSource) { + return OAuth2AuthorizationServerConfiguration.jwtDecoder(jwkSource); + } + @Bean public ProviderSettings providerSettings() { return new ProviderSettings().issuer("http://auth-server:9000"); } - @Bean - public JwtDecoder jwtDecoder(ProviderSettings providerSettings){ - OAuth2TokenValidator jwtValidator = JwtValidators.createDefaultWithIssuer(providerSettings.issuer()); - NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri("http://auth-server:9000"+providerSettings.jwkSetEndpoint()).build(); - jwtDecoder.setJwtValidator(jwtValidator); - return jwtDecoder; - } }