diff --git a/docs/src/docs/asciidoc/protocol-endpoints.adoc b/docs/src/docs/asciidoc/protocol-endpoints.adoc index 56347119..dc0cf2a9 100644 --- a/docs/src/docs/asciidoc/protocol-endpoints.adoc +++ b/docs/src/docs/asciidoc/protocol-endpoints.adoc @@ -269,9 +269,9 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h == OpenID Connect 1.0 UserInfo Endpoint `OidcUserInfoEndpointConfigurer` provides the ability to customize the https://openid.net/specs/openid-connect-core-1_0.html#UserInfo[OpenID Connect 1.0 UserInfo endpoint]. -It defines extension points that let you customize the https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse[UserInfo response]. +It defines extension points that let you customize the pre-processing, main processing, and post-processing logic for https://openid.net/specs/openid-connect-core-1_0.html#UserInfoRequest[UserInfo requests]. -`OidcUserInfoEndpointConfigurer` provides the following configuration option: +`OidcUserInfoEndpointConfigurer` provides the following configuration options: [source,java] ---- @@ -285,21 +285,37 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h .oidc(oidc -> oidc .userInfoEndpoint(userInfoEndpoint -> - userInfoEndpoint.userInfoMapper(userInfoMapper) <1> + userInfoEndpoint + .userInfoRequestConverter(userInfoRequestConverter) <1> + .userInfoRequestConverters(userInfoRequestConvertersConsumer) <2> + .authenticationProvider(authenticationProvider) <3> + .authenticationProviders(authenticationProvidersConsumer) <4> + .userInfoResponseHandler(userInfoResponseHandler) <5> + .errorResponseHandler(errorResponseHandler) <6> + .userInfoMapper(userInfoMapper) <7> ) ); return http.build(); } ---- -<1> `userInfoMapper()`: The `Function` used to extract claims from `OidcUserInfoAuthenticationContext` to an instance of `OidcUserInfo`. +<1> `userInfoRequestConverter()`: Adds an `AuthenticationConverter` (_pre-processor_) used when attempting to extract an https://openid.net/specs/openid-connect-core-1_0.html#UserInfoRequest[UserInfo request] from `HttpServletRequest` to an instance of `OidcUserInfoAuthenticationToken`. +<2> `userInfoRequestConverters()`: Sets the `Consumer` providing access to the `List` of default and (optionally) added ``AuthenticationConverter``'s allowing the ability to add, remove, or customize a specific `AuthenticationConverter`. +<3> `authenticationProvider()`: Adds an `AuthenticationProvider` (_main processor_) used for authenticating the `OidcUserInfoAuthenticationToken`. +<4> `authenticationProviders()`: Sets the `Consumer` providing access to the `List` of default and (optionally) added ``AuthenticationProvider``'s allowing the ability to add, remove, or customize a specific `AuthenticationProvider`. +<5> `userInfoResponseHandler()`: The `AuthenticationSuccessHandler` (_post-processor_) used for handling an "`authenticated`" `OidcUserInfoAuthenticationToken` and returning the https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse[UserInfo response]. +<6> `errorResponseHandler()`: The `AuthenticationFailureHandler` (_post-processor_) used for handling an `OAuth2AuthenticationException` and returning the https://openid.net/specs/openid-connect-core-1_0.html#UserInfoError[UserInfo Error response]. +<7> `userInfoMapper()`: The `Function` used to extract claims from `OidcUserInfoAuthenticationContext` to an instance of `OidcUserInfo`. `OidcUserInfoEndpointConfigurer` configures the `OidcUserInfoEndpointFilter` and registers it with the OAuth2 authorization server `SecurityFilterChain` `@Bean`. `OidcUserInfoEndpointFilter` is the `Filter` that processes https://openid.net/specs/openid-connect-core-1_0.html#UserInfoRequest[UserInfo requests] and returns the https://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse[OidcUserInfo response]. `OidcUserInfoEndpointFilter` is configured with the following defaults: +* `*AuthenticationConverter*` -- An internal implementation that obtains the `Authentication` from the `SecurityContext` and creates an `OidcUserInfoAuthenticationToken` with the principal. * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OidcUserInfoAuthenticationProvider`, which is associated with an internal implementation of `userInfoMapper` that extracts https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims[standard claims] from the https://openid.net/specs/openid-connect-core-1_0.html#IDToken[ID Token] based on the https://openid.net/specs/openid-connect-core-1_0.html#ScopeClaims[scopes requested] during authorization. +* `*AuthenticationSuccessHandler*` -- An internal implementation that handles an "`authenticated`" `OidcUserInfoAuthenticationToken` and returns the `OidcUserInfo` response. +* `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response. [TIP] You can customize the ID Token by providing an xref:core-model-components.adoc#oauth2-token-customizer[`OAuth2TokenCustomizer`] `@Bean`. @@ -337,8 +353,10 @@ The guide xref:guides/how-to-userinfo.adoc#how-to-userinfo[How-to: Customize the [[oidc-client-registration-endpoint]] == OpenID Connect 1.0 Client Registration Endpoint -`OidcClientRegistrationEndpointConfigurer` configures the https://openid.net/specs/openid-connect-registration-1_0.html#ClientRegistration[OpenID Connect 1.0 Client Registration endpoint]. -The following example shows how to enable (disabled by default) the OpenID Connect 1.0 Client Registration endpoint: +`OidcClientRegistrationEndpointConfigurer` provides the ability to customize the https://openid.net/specs/openid-connect-registration-1_0.html#ClientRegistration[OpenID Connect 1.0 Client Registration endpoint]. +It defines extension points that let you customize the pre-processing, main processing, and post-processing logic for https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationRequest[Client Registration requests] or https://openid.net/specs/openid-connect-registration-1_0.html#ReadRequest[Client Read requests]. + +`OidcClientRegistrationEndpointConfigurer` provides the following configuration options: [source,java] ---- @@ -351,12 +369,26 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h authorizationServerConfigurer .oidc(oidc -> oidc - .clientRegistrationEndpoint(Customizer.withDefaults()) + .clientRegistrationEndpoint(clientRegistrationEndpoint -> + clientRegistrationEndpoint + .clientRegistrationRequestConverter(clientRegistrationRequestConverter) <1> + .clientRegistrationRequestConverters(clientRegistrationRequestConvertersConsumers) <2> + .authenticationProvider(authenticationProvider) <3> + .authenticationProviders(authenticationProvidersConsumer) <4> + .clientRegistrationResponseHandler(clientRegistrationResponseHandler) <5> + .errorResponseHandler(errorResponseHandler) <6> + ) ); return http.build(); } ---- +<1> `clientRegistrationRequestConverter()`: Adds an `AuthenticationConverter` (_pre-processor_) used when attempting to extract a https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationRequest[Client Registration request] or https://openid.net/specs/openid-connect-registration-1_0.html#ReadRequest[Client Read request] from `HttpServletRequest` to an instance of `OidcClientRegistrationAuthenticationToken`. +<2> `clientRegistrationRequestConverters()`: Sets the `Consumer` providing access to the `List` of default and (optionally) added ``AuthenticationConverter``'s allowing the ability to add, remove, or customize a specific `AuthenticationConverter`. +<3> `authenticationProvider()`: Adds an `AuthenticationProvider` (_main processor_) used for authenticating the `OidcClientRegistrationAuthenticationToken`. +<4> `authenticationProviders()`: Sets the `Consumer` providing access to the `List` of default and (optionally) added ``AuthenticationProvider``'s allowing the ability to add, remove, or customize a specific `AuthenticationProvider`. +<5> `clientRegistrationResponseHandler()`: The `AuthenticationSuccessHandler` (_post-processor_) used for handling an "`authenticated`" `OidcClientRegistrationAuthenticationToken` and returning the https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationResponse[Client Registration response] or https://openid.net/specs/openid-connect-registration-1_0.html#ReadResponse[Client Read response]. +<6> `errorResponseHandler()`: The `AuthenticationFailureHandler` (_post-processor_) used for handling an `OAuth2AuthenticationException` and returning the https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationError[Client Registration Error response] or https://openid.net/specs/openid-connect-registration-1_0.html#ReadError[Client Read Error response]. [NOTE] The OpenID Connect 1.0 Client Registration endpoint is disabled by default because many deployments do not require dynamic client registration. @@ -371,6 +403,8 @@ The OpenID Connect 1.0 Client Registration endpoint is disabled by default becau * `*AuthenticationConverter*` -- An `OidcClientRegistrationAuthenticationConverter`. * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OidcClientRegistrationAuthenticationProvider` and `OidcClientConfigurationAuthenticationProvider`. +* `*AuthenticationSuccessHandler*` -- An internal implementation that handles an "`authenticated`" `OidcClientRegistrationAuthenticationToken` and returns the `OidcClientRegistration` response. +* `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response. The OpenID Connect 1.0 Client Registration endpoint is an https://openid.net/specs/openid-connect-registration-1_0.html#ClientRegistration[OAuth2 protected resource], which *REQUIRES* an access token to be sent as a bearer token in the Client Registration (or Client Read) request. diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationEndpointConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationEndpointConfigurer.java index f15ecef9..3a7922dc 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationEndpointConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationEndpointConfigurer.java @@ -15,29 +15,53 @@ */ package org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; + +import jakarta.servlet.http.HttpServletRequest; + import org.springframework.http.HttpMethod; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; +import org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration; import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientConfigurationAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken; import org.springframework.security.oauth2.server.authorization.oidc.web.OidcClientRegistrationEndpointFilter; +import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcClientRegistrationAuthenticationConverter; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; +import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter; import org.springframework.security.web.access.intercept.AuthorizationFilter; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; /** - * Configurer for OpenID Connect Dynamic Client Registration 1.0 Endpoint. + * Configurer for OpenID Connect 1.0 Dynamic Client Registration Endpoint. * * @author Joe Grandja + * @author Daniel Garnier-Moiroux * @since 0.2.0 * @see OidcConfigurer#clientRegistrationEndpoint * @see OidcClientRegistrationEndpointFilter */ public final class OidcClientRegistrationEndpointConfigurer extends AbstractOAuth2Configurer { private RequestMatcher requestMatcher; + private final List clientRegistrationRequestConverters = new ArrayList<>(); + private Consumer> clientRegistrationRequestConvertersConsumer = (clientRegistrationRequestConverters) -> {}; + private final List authenticationProviders = new ArrayList<>(); + private Consumer> authenticationProvidersConsumer = (authenticationProviders) -> {}; + private AuthenticationSuccessHandler clientRegistrationResponseHandler; + private AuthenticationFailureHandler errorResponseHandler; /** * Restrict for internal use only. @@ -46,26 +70,108 @@ public final class OidcClientRegistrationEndpointConfigurer extends AbstractOAut super(objectPostProcessor); } + /** + * Adds an {@link AuthenticationConverter} used when attempting to extract a Client Registration Request from {@link HttpServletRequest} + * to an instance of {@link OidcClientRegistrationAuthenticationToken} used for authenticating the request. + * + * @param clientRegistrationRequestConverter an {@link AuthenticationConverter} used when attempting to extract a Client Registration Request from {@link HttpServletRequest} + * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcClientRegistrationEndpointConfigurer clientRegistrationRequestConverter( + AuthenticationConverter clientRegistrationRequestConverter) { + Assert.notNull(clientRegistrationRequestConverter, "clientRegistrationRequestConverter cannot be null"); + this.clientRegistrationRequestConverters.add(clientRegistrationRequestConverter); + return this; + } + + /** + * Sets the {@code Consumer} providing access to the {@code List} of default + * and (optionally) added {@link #clientRegistrationRequestConverter(AuthenticationConverter) AuthenticationConverter}'s + * allowing the ability to add, remove, or customize a specific {@link AuthenticationConverter}. + * + * @param clientRegistrationRequestConvertersConsumer the {@code Consumer} providing access to the {@code List} of default and (optionally) added {@link AuthenticationConverter}'s + * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcClientRegistrationEndpointConfigurer clientRegistrationRequestConverters( + Consumer> clientRegistrationRequestConvertersConsumer) { + Assert.notNull(clientRegistrationRequestConvertersConsumer, "clientRegistrationRequestConvertersConsumer cannot be null"); + this.clientRegistrationRequestConvertersConsumer = clientRegistrationRequestConvertersConsumer; + return this; + } + + /** + * Adds an {@link AuthenticationProvider} used for authenticating an {@link OidcClientRegistrationAuthenticationToken}. + * + * @param authenticationProvider an {@link AuthenticationProvider} used for authenticating an {@link OidcClientRegistrationAuthenticationToken} + * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcClientRegistrationEndpointConfigurer authenticationProvider(AuthenticationProvider authenticationProvider) { + Assert.notNull(authenticationProvider, "authenticationProvider cannot be null"); + this.authenticationProviders.add(authenticationProvider); + return this; + } + + /** + * Sets the {@code Consumer} providing access to the {@code List} of default + * and (optionally) added {@link #authenticationProvider(AuthenticationProvider) AuthenticationProvider}'s + * allowing the ability to add, remove, or customize a specific {@link AuthenticationProvider}. + * + * @param authenticationProvidersConsumer the {@code Consumer} providing access to the {@code List} of default and (optionally) added {@link AuthenticationProvider}'s + * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcClientRegistrationEndpointConfigurer authenticationProviders( + Consumer> authenticationProvidersConsumer) { + Assert.notNull(authenticationProvidersConsumer, "authenticationProvidersConsumer cannot be null"); + this.authenticationProvidersConsumer = authenticationProvidersConsumer; + return this; + } + + /** + * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcClientRegistrationAuthenticationToken} + * and returning the {@link OidcClientRegistration Client Registration Response}. + * + * @param clientRegistrationResponseHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OidcClientRegistrationAuthenticationToken} + * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcClientRegistrationEndpointConfigurer clientRegistrationResponseHandler(AuthenticationSuccessHandler clientRegistrationResponseHandler) { + this.clientRegistrationResponseHandler = clientRegistrationResponseHandler; + return this; + } + + /** + * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} + * and returning the {@link OAuth2Error Error Response}. + * + * @param errorResponseHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} + * @return the {@link OidcClientRegistrationEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcClientRegistrationEndpointConfigurer errorResponseHandler(AuthenticationFailureHandler errorResponseHandler) { + this.errorResponseHandler = errorResponseHandler; + return this; + } + @Override void init(HttpSecurity httpSecurity) { AuthorizationServerSettings authorizationServerSettings = OAuth2ConfigurerUtils.getAuthorizationServerSettings(httpSecurity); + String clientRegistrationEndpointUri = authorizationServerSettings.getOidcClientRegistrationEndpoint(); this.requestMatcher = new OrRequestMatcher( - new AntPathRequestMatcher(authorizationServerSettings.getOidcClientRegistrationEndpoint(), HttpMethod.POST.name()), - new AntPathRequestMatcher(authorizationServerSettings.getOidcClientRegistrationEndpoint(), HttpMethod.GET.name()) + new AntPathRequestMatcher(clientRegistrationEndpointUri, HttpMethod.POST.name()), + new AntPathRequestMatcher(clientRegistrationEndpointUri, HttpMethod.GET.name()) ); - OidcClientRegistrationAuthenticationProvider oidcClientRegistrationAuthenticationProvider = - new OidcClientRegistrationAuthenticationProvider( - OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity), - OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity), - OAuth2ConfigurerUtils.getTokenGenerator(httpSecurity)); - httpSecurity.authenticationProvider(postProcess(oidcClientRegistrationAuthenticationProvider)); - - OidcClientConfigurationAuthenticationProvider oidcClientConfigurationAuthenticationProvider = - new OidcClientConfigurationAuthenticationProvider( - OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity), - OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity)); - httpSecurity.authenticationProvider(postProcess(oidcClientConfigurationAuthenticationProvider)); + List authenticationProviders = createDefaultAuthenticationProviders(httpSecurity); + if (!this.authenticationProviders.isEmpty()) { + authenticationProviders.addAll(0, this.authenticationProviders); + } + this.authenticationProvidersConsumer.accept(authenticationProviders); + authenticationProviders.forEach(authenticationProvider -> + httpSecurity.authenticationProvider(postProcess(authenticationProvider))); } @Override @@ -77,6 +183,20 @@ public final class OidcClientRegistrationEndpointConfigurer extends AbstractOAut new OidcClientRegistrationEndpointFilter( authenticationManager, authorizationServerSettings.getOidcClientRegistrationEndpoint()); + List authenticationConverters = createDefaultAuthenticationConverters(); + if (!this.clientRegistrationRequestConverters.isEmpty()) { + authenticationConverters.addAll(0, this.clientRegistrationRequestConverters); + } + this.clientRegistrationRequestConvertersConsumer.accept(authenticationConverters); + oidcClientRegistrationEndpointFilter.setAuthenticationConverter( + new DelegatingAuthenticationConverter(authenticationConverters)); + if (this.clientRegistrationResponseHandler != null) { + oidcClientRegistrationEndpointFilter + .setAuthenticationSuccessHandler(this.clientRegistrationResponseHandler); + } + if (this.errorResponseHandler != null) { + oidcClientRegistrationEndpointFilter.setAuthenticationFailureHandler(this.errorResponseHandler); + } httpSecurity.addFilterAfter(postProcess(oidcClientRegistrationEndpointFilter), AuthorizationFilter.class); } @@ -85,4 +205,31 @@ public final class OidcClientRegistrationEndpointConfigurer extends AbstractOAut return this.requestMatcher; } + private static List createDefaultAuthenticationConverters() { + List authenticationConverters = new ArrayList<>(); + + authenticationConverters.add(new OidcClientRegistrationAuthenticationConverter()); + + return authenticationConverters; + } + + private static List createDefaultAuthenticationProviders(HttpSecurity httpSecurity) { + List authenticationProviders = new ArrayList<>(); + + OidcClientRegistrationAuthenticationProvider oidcClientRegistrationAuthenticationProvider = + new OidcClientRegistrationAuthenticationProvider( + OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity), + OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity), + OAuth2ConfigurerUtils.getTokenGenerator(httpSecurity)); + authenticationProviders.add(oidcClientRegistrationAuthenticationProvider); + + OidcClientConfigurationAuthenticationProvider oidcClientConfigurationAuthenticationProvider = + new OidcClientConfigurationAuthenticationProvider( + OAuth2ConfigurerUtils.getRegisteredClientRepository(httpSecurity), + OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity)); + authenticationProviders.add(oidcClientConfigurationAuthenticationProvider); + + return authenticationProviders; + } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcUserInfoEndpointConfigurer.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcUserInfoEndpointConfigurer.java index 9dae96da..0bd08f53 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcUserInfoEndpointConfigurer.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcUserInfoEndpointConfigurer.java @@ -15,13 +15,23 @@ */ package org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; import java.util.function.Function; +import jakarta.servlet.http.HttpServletRequest; + import org.springframework.http.HttpMethod; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.config.annotation.ObjectPostProcessor; import org.springframework.security.config.annotation.web.builders.HttpSecurity; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationContext; @@ -29,21 +39,33 @@ import org.springframework.security.oauth2.server.authorization.oidc.authenticat import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken; import org.springframework.security.oauth2.server.authorization.oidc.web.OidcUserInfoEndpointFilter; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; +import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter; import org.springframework.security.web.access.intercept.AuthorizationFilter; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; +import org.springframework.util.Assert; /** * Configurer for OpenID Connect 1.0 UserInfo Endpoint. * * @author Steve Riesenberg + * @author Daniel Garnier-Moiroux * @since 0.2.1 * @see OidcConfigurer#userInfoEndpoint * @see OidcUserInfoEndpointFilter */ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configurer { private RequestMatcher requestMatcher; + private final List userInfoRequestConverters = new ArrayList<>(); + private Consumer> userInfoRequestConvertersConsumer = (userInfoRequestConverters) -> {}; + private final List authenticationProviders = new ArrayList<>(); + private Consumer> authenticationProvidersConsumer = (authenticationProviders) -> {}; + private AuthenticationSuccessHandler userInfoResponseHandler; + private AuthenticationFailureHandler errorResponseHandler; private Function userInfoMapper; /** @@ -53,6 +75,91 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur super(objectPostProcessor); } + /** + * Adds an {@link AuthenticationConverter} used when attempting to extract an UserInfo Request from {@link HttpServletRequest} + * to an instance of {@link OidcUserInfoAuthenticationToken} used for authenticating the request. + * + * @param userInfoRequestConverter an {@link AuthenticationConverter} used when attempting to extract an UserInfo Request from {@link HttpServletRequest} + * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcUserInfoEndpointConfigurer userInfoRequestConverter(AuthenticationConverter userInfoRequestConverter) { + Assert.notNull(userInfoRequestConverter, "userInfoRequestConverter cannot be null"); + this.userInfoRequestConverters.add(userInfoRequestConverter); + return this; + } + + /** + * Sets the {@code Consumer} providing access to the {@code List} of default + * and (optionally) added {@link #userInfoRequestConverter(AuthenticationConverter) AuthenticationConverter}'s + * allowing the ability to add, remove, or customize a specific {@link AuthenticationConverter}. + * + * @param userInfoRequestConvertersConsumer the {@code Consumer} providing access to the {@code List} of default and (optionally) added {@link AuthenticationConverter}'s + * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcUserInfoEndpointConfigurer userInfoRequestConverters( + Consumer> userInfoRequestConvertersConsumer) { + Assert.notNull(userInfoRequestConvertersConsumer, "userInfoRequestConvertersConsumer cannot be null"); + this.userInfoRequestConvertersConsumer = userInfoRequestConvertersConsumer; + return this; + } + + /** + * Adds an {@link AuthenticationProvider} used for authenticating an {@link OidcUserInfoAuthenticationToken}. + * + * @param authenticationProvider an {@link AuthenticationProvider} used for authenticating an {@link OidcUserInfoAuthenticationToken} + * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcUserInfoEndpointConfigurer authenticationProvider(AuthenticationProvider authenticationProvider) { + Assert.notNull(authenticationProvider, "authenticationProvider cannot be null"); + this.authenticationProviders.add(authenticationProvider); + return this; + } + + /** + * Sets the {@code Consumer} providing access to the {@code List} of default + * and (optionally) added {@link #authenticationProvider(AuthenticationProvider) AuthenticationProvider}'s + * allowing the ability to add, remove, or customize a specific {@link AuthenticationProvider}. + * + * @param authenticationProvidersConsumer the {@code Consumer} providing access to the {@code List} of default and (optionally) added {@link AuthenticationProvider}'s + * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcUserInfoEndpointConfigurer authenticationProviders( + Consumer> authenticationProvidersConsumer) { + Assert.notNull(authenticationProvidersConsumer, "authenticationProvidersConsumer cannot be null"); + this.authenticationProvidersConsumer = authenticationProvidersConsumer; + return this; + } + + /** + * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcUserInfoAuthenticationToken} + * and returning the {@link OidcUserInfo UserInfo Response}. + * + * @param userInfoResponseHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OidcUserInfoAuthenticationToken} + * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcUserInfoEndpointConfigurer userInfoResponseHandler(AuthenticationSuccessHandler userInfoResponseHandler) { + this.userInfoResponseHandler = userInfoResponseHandler; + return this; + } + + /** + * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} + * and returning the {@link OAuth2Error Error Response}. + * + * @param errorResponseHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} + * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration + * @since 0.4.0 + */ + public OidcUserInfoEndpointConfigurer errorResponseHandler(AuthenticationFailureHandler errorResponseHandler) { + this.errorResponseHandler = errorResponseHandler; + return this; + } + /** * Sets the {@link Function} used to extract claims from {@link OidcUserInfoAuthenticationContext} * to an instance of {@link OidcUserInfo} for the UserInfo response. @@ -69,7 +176,8 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur * @param userInfoMapper the {@link Function} used to extract claims from {@link OidcUserInfoAuthenticationContext} to an instance of {@link OidcUserInfo} * @return the {@link OidcUserInfoEndpointConfigurer} for further configuration */ - public OidcUserInfoEndpointConfigurer userInfoMapper(Function userInfoMapper) { + public OidcUserInfoEndpointConfigurer userInfoMapper( + Function userInfoMapper) { this.userInfoMapper = userInfoMapper; return this; } @@ -82,13 +190,13 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur new AntPathRequestMatcher(userInfoEndpointUri, HttpMethod.GET.name()), new AntPathRequestMatcher(userInfoEndpointUri, HttpMethod.POST.name())); - OidcUserInfoAuthenticationProvider oidcUserInfoAuthenticationProvider = - new OidcUserInfoAuthenticationProvider( - OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity)); - if (this.userInfoMapper != null) { - oidcUserInfoAuthenticationProvider.setUserInfoMapper(this.userInfoMapper); + List authenticationProviders = createDefaultAuthenticationProviders(httpSecurity); + if (!this.authenticationProviders.isEmpty()) { + authenticationProviders.addAll(0, this.authenticationProviders); } - httpSecurity.authenticationProvider(postProcess(oidcUserInfoAuthenticationProvider)); + this.authenticationProvidersConsumer.accept(authenticationProviders); + authenticationProviders.forEach(authenticationProvider -> + httpSecurity.authenticationProvider(postProcess(authenticationProvider))); } @Override @@ -100,6 +208,19 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur new OidcUserInfoEndpointFilter( authenticationManager, authorizationServerSettings.getOidcUserInfoEndpoint()); + List authenticationConverters = createDefaultAuthenticationConverters(); + if (!this.userInfoRequestConverters.isEmpty()) { + authenticationConverters.addAll(0, this.userInfoRequestConverters); + } + this.userInfoRequestConvertersConsumer.accept(authenticationConverters); + oidcUserInfoEndpointFilter.setAuthenticationConverter( + new DelegatingAuthenticationConverter(authenticationConverters)); + if (this.userInfoResponseHandler != null) { + oidcUserInfoEndpointFilter.setAuthenticationSuccessHandler(this.userInfoResponseHandler); + } + if (this.errorResponseHandler != null) { + oidcUserInfoEndpointFilter.setAuthenticationFailureHandler(this.errorResponseHandler); + } httpSecurity.addFilterAfter(postProcess(oidcUserInfoEndpointFilter), AuthorizationFilter.class); } @@ -108,4 +229,31 @@ public final class OidcUserInfoEndpointConfigurer extends AbstractOAuth2Configur return this.requestMatcher; } + private static List createDefaultAuthenticationConverters() { + List authenticationConverters = new ArrayList<>(); + + authenticationConverters.add( + (request) -> { + Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); + return new OidcUserInfoAuthenticationToken(authentication); + } + ); + + return authenticationConverters; + } + + private List createDefaultAuthenticationProviders(HttpSecurity httpSecurity) { + List authenticationProviders = new ArrayList<>(); + + OidcUserInfoAuthenticationProvider oidcUserInfoAuthenticationProvider = + new OidcUserInfoAuthenticationProvider( + OAuth2ConfigurerUtils.getAuthorizationService(httpSecurity)); + if (this.userInfoMapper != null) { + oidcUserInfoAuthenticationProvider.setUserInfoMapper(this.userInfoMapper); + } + authenticationProviders.add(oidcUserInfoAuthenticationProvider); + + return authenticationProviders; + } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientConfigurationAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientConfigurationAuthenticationProvider.java index 575e4593..5ce732b7 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientConfigurationAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientConfigurationAuthenticationProvider.java @@ -46,6 +46,7 @@ import org.springframework.util.StringUtils; * @since 0.4.0 * @see RegisteredClientRepository * @see OAuth2AuthorizationService + * @see OidcClientRegistrationAuthenticationToken * @see OidcClientRegistrationAuthenticationProvider * @see 4. Client Configuration Endpoint */ @@ -67,7 +68,7 @@ public final class OidcClientConfigurationAuthenticationProvider implements Auth Assert.notNull(authorizationService, "authorizationService cannot be null"); this.registeredClientRepository = registeredClientRepository; this.authorizationService = authorizationService; - this.clientRegistrationConverter = new OidcClientRegistrationConverter(); + this.clientRegistrationConverter = new RegisteredClientOidcClientRegistrationConverter(); } @Override diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java index e7ec6233..81e88c79 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java @@ -74,6 +74,7 @@ import org.springframework.util.StringUtils; * @see RegisteredClientRepository * @see OAuth2AuthorizationService * @see OAuth2TokenGenerator + * @see OidcClientRegistrationAuthenticationToken * @see OidcClientConfigurationAuthenticationProvider * @see 3. Client Registration Endpoint */ @@ -84,7 +85,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe private final OAuth2AuthorizationService authorizationService; private final OAuth2TokenGenerator tokenGenerator; private final Converter clientRegistrationConverter; - private final Converter registeredClientConverter; + private Converter registeredClientConverter; /** * Constructs an {@code OidcClientRegistrationAuthenticationProvider} using the provided parameters. @@ -102,8 +103,8 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe this.registeredClientRepository = registeredClientRepository; this.authorizationService = authorizationService; this.tokenGenerator = tokenGenerator; - this.clientRegistrationConverter = new OidcClientRegistrationConverter(); - this.registeredClientConverter = new RegisteredClientConverter(); + this.clientRegistrationConverter = new RegisteredClientOidcClientRegistrationConverter(); + this.registeredClientConverter = new OidcClientRegistrationRegisteredClientConverter(); } @Override @@ -147,6 +148,17 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe return OidcClientRegistrationAuthenticationToken.class.isAssignableFrom(authentication); } + /** + * Sets the {@link Converter} used for converting an {@link OidcClientRegistration} to a {@link RegisteredClient}. + * + * @param registeredClientConverter the {@link Converter} used for converting an {@link OidcClientRegistration} to a {@link RegisteredClient} + * @since 0.4.0 + */ + public void setRegisteredClientConverter(Converter registeredClientConverter) { + Assert.notNull(registeredClientConverter, "registeredClientConverter cannot be null"); + this.registeredClientConverter = registeredClientConverter; + } + private OidcClientRegistrationAuthenticationToken registerClient(OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication, OAuth2Authorization authorization) { @@ -293,7 +305,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe throw new OAuth2AuthenticationException(error); } - private static final class RegisteredClientConverter implements Converter { + private static final class OidcClientRegistrationRegisteredClientConverter implements Converter { private static final StringKeyGenerator CLIENT_ID_GENERATOR = new Base64StringKeyGenerator( Base64.getUrlEncoder().withoutPadding(), 32); private static final StringKeyGenerator CLIENT_SECRET_GENERATOR = new Base64StringKeyGenerator( diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationConverter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/RegisteredClientOidcClientRegistrationConverter.java similarity index 96% rename from oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationConverter.java rename to oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/RegisteredClientOidcClientRegistrationConverter.java index b7e16d4e..75aa17c9 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationConverter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/RegisteredClientOidcClientRegistrationConverter.java @@ -31,7 +31,7 @@ import org.springframework.web.util.UriComponentsBuilder; * @author Joe Grandja * @since 0.4.0 */ -final class OidcClientRegistrationConverter implements Converter { +final class RegisteredClientOidcClientRegistrationConverter implements Converter { @Override public OidcClientRegistration convert(RegisteredClient registeredClient) { diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java index d7e89969..5167f797 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilter.java @@ -27,6 +27,8 @@ import org.springframework.http.HttpStatus; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; @@ -40,6 +42,8 @@ import org.springframework.security.oauth2.server.authorization.oidc.authenticat import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter; import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcClientRegistrationAuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.util.matcher.AndRequestMatcher; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; @@ -53,6 +57,7 @@ import org.springframework.web.filter.OncePerRequestFilter; * * @author Ovidiu Popa * @author Joe Grandja + * @author Daniel Garnier-Moiroux * @since 0.1.1 * @see OidcClientRegistration * @see OidcClientRegistrationAuthenticationConverter @@ -73,7 +78,9 @@ public final class OidcClientRegistrationEndpointFilter extends OncePerRequestFi new OidcClientRegistrationHttpMessageConverter(); private final HttpMessageConverter errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter(); - private AuthenticationConverter authenticationConverter; + private AuthenticationConverter authenticationConverter = new OidcClientRegistrationAuthenticationConverter(); + private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendClientRegistrationResponse; + private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; /** * Constructs an {@code OidcClientRegistrationEndpointFilter} using the provided parameters. @@ -99,7 +106,6 @@ public final class OidcClientRegistrationEndpointFilter extends OncePerRequestFi new AntPathRequestMatcher( clientRegistrationEndpointUri, HttpMethod.POST.name()), createClientConfigurationMatcher(clientRegistrationEndpointUri)); - this.authenticationConverter = new OidcClientRegistrationAuthenticationConverter(); } private static RequestMatcher createClientConfigurationMatcher(String clientRegistrationEndpointUri) { @@ -124,39 +130,78 @@ public final class OidcClientRegistrationEndpointFilter extends OncePerRequestFi } try { - OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication = - (OidcClientRegistrationAuthenticationToken) this.authenticationConverter.convert(request); + Authentication clientRegistrationAuthentication = this.authenticationConverter.convert(request); - OidcClientRegistrationAuthenticationToken clientRegistrationAuthenticationResult = - (OidcClientRegistrationAuthenticationToken) this.authenticationManager.authenticate(clientRegistrationAuthentication); - - HttpStatus httpStatus = HttpStatus.OK; - if (clientRegistrationAuthentication.getClientRegistration() != null) { - httpStatus = HttpStatus.CREATED; - } - - sendClientRegistrationResponse(response, httpStatus, clientRegistrationAuthenticationResult.getClientRegistration()); + Authentication clientRegistrationAuthenticationResult = + this.authenticationManager.authenticate(clientRegistrationAuthentication); + this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, clientRegistrationAuthenticationResult); } catch (OAuth2AuthenticationException ex) { - sendErrorResponse(response, ex.getError()); + this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex); } catch (Exception ex) { OAuth2Error error = new OAuth2Error( OAuth2ErrorCodes.INVALID_REQUEST, - "OpenID Client Registration Error: " + ex.getMessage(), + "OpenID Connect 1.0 Client Registration Error: " + ex.getMessage(), "https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationError"); - sendErrorResponse(response, error); + this.authenticationFailureHandler.onAuthenticationFailure(request, response, + new OAuth2AuthenticationException(error)); } finally { SecurityContextHolder.clearContext(); } } - private void sendClientRegistrationResponse(HttpServletResponse response, HttpStatus httpStatus, OidcClientRegistration clientRegistration) throws IOException { + /** + * Sets the {@link AuthenticationConverter} used when attempting to extract a Client Registration Request from {@link HttpServletRequest} + * to an instance of {@link OidcClientRegistrationAuthenticationToken} used for authenticating the request. + * + * @param authenticationConverter an {@link AuthenticationConverter} used when attempting to extract a Client Registration Request from {@link HttpServletRequest} + * @since 0.4.0 + */ + public void setAuthenticationConverter(AuthenticationConverter authenticationConverter) { + Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); + this.authenticationConverter = authenticationConverter; + } + + /** + * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcClientRegistrationAuthenticationToken} + * and returning the {@link OidcClientRegistration Client Registration Response}. + * + * @param authenticationSuccessHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OidcClientRegistrationAuthenticationToken} + * @see 0.4.0 + */ + public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) { + Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null"); + this.authenticationSuccessHandler = authenticationSuccessHandler; + } + + /** + * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} + * and returning the {@link OAuth2Error Error Response}. + * + * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} + * @since 0.4.0 + */ + public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) { + Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); + this.authenticationFailureHandler = authenticationFailureHandler; + } + + private void sendClientRegistrationResponse(HttpServletRequest request, HttpServletResponse response, + Authentication authentication) throws IOException { + OidcClientRegistration clientRegistration = ((OidcClientRegistrationAuthenticationToken) authentication) + .getClientRegistration(); ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - httpResponse.setStatusCode(httpStatus); + if (HttpMethod.POST.name().equals(request.getMethod())) { + httpResponse.setStatusCode(HttpStatus.CREATED); + } else { + httpResponse.setStatusCode(HttpStatus.OK); + } this.clientRegistrationHttpMessageConverter.write(clientRegistration, null, httpResponse); } - private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException { + private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, + AuthenticationException authenticationException) throws IOException { + OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError(); HttpStatus httpStatus = HttpStatus.BAD_REQUEST; if (OAuth2ErrorCodes.INVALID_TOKEN.equals(error.getErrorCode())) { httpStatus = HttpStatus.UNAUTHORIZED; diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java index f2a0746c..32f50ff8 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilter.java @@ -28,14 +28,19 @@ import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; +import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter; import org.springframework.security.oauth2.core.oidc.OidcUserInfo; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken; import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcUserInfoHttpMessageConverter; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.security.web.util.matcher.OrRequestMatcher; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -47,8 +52,10 @@ import org.springframework.web.filter.OncePerRequestFilter; * * @author Ido Salomon * @author Steve Riesenberg + * @author Daniel Garnier-Moiroux * @since 0.2.1 * @see OidcUserInfo + * @see OidcUserInfoAuthenticationProvider * @see 5.3. UserInfo Endpoint */ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter { @@ -60,11 +67,13 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter { private final AuthenticationManager authenticationManager; private final RequestMatcher userInfoEndpointMatcher; - private final HttpMessageConverter userInfoHttpMessageConverter = new OidcUserInfoHttpMessageConverter(); private final HttpMessageConverter errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter(); + private AuthenticationConverter authenticationConverter = this::createAuthentication; + private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendUserInfoResponse; + private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; /** * Constructs an {@code OidcUserInfoEndpointFilter} using the provided parameters. @@ -100,34 +109,77 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter { } try { - Authentication principal = SecurityContextHolder.getContext().getAuthentication(); + Authentication userInfoAuthentication = this.authenticationConverter.convert(request); - OidcUserInfoAuthenticationToken userInfoAuthentication = new OidcUserInfoAuthenticationToken(principal); - - OidcUserInfoAuthenticationToken userInfoAuthenticationResult = - (OidcUserInfoAuthenticationToken) this.authenticationManager.authenticate(userInfoAuthentication); - - sendUserInfoResponse(response, userInfoAuthenticationResult.getUserInfo()); + Authentication userInfoAuthenticationResult = + this.authenticationManager.authenticate(userInfoAuthentication); + this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, userInfoAuthenticationResult); } catch (OAuth2AuthenticationException ex) { - sendErrorResponse(response, ex.getError()); + this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex); } catch (Exception ex) { OAuth2Error error = new OAuth2Error( OAuth2ErrorCodes.INVALID_REQUEST, "OpenID Connect 1.0 UserInfo Error: " + ex.getMessage(), "https://openid.net/specs/openid-connect-core-1_0.html#UserInfoError"); - sendErrorResponse(response, error); + this.authenticationFailureHandler.onAuthenticationFailure(request, response, + new OAuth2AuthenticationException(error)); } finally { SecurityContextHolder.clearContext(); } } - private void sendUserInfoResponse(HttpServletResponse response, OidcUserInfo userInfo) throws IOException { - ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); - this.userInfoHttpMessageConverter.write(userInfo, null, httpResponse); + /** + * Sets the {@link AuthenticationConverter} used when attempting to extract an UserInfo Request from {@link HttpServletRequest} + * to an instance of {@link OidcUserInfoAuthenticationToken} used for authenticating the request. + * + * @param authenticationConverter the {@link AuthenticationConverter} used when attempting to extract an UserInfo Request from {@link HttpServletRequest} + * @since 0.4.0 + */ + public void setAuthenticationConverter(AuthenticationConverter authenticationConverter) { + Assert.notNull(authenticationConverter, "authenticationConverter cannot be null"); + this.authenticationConverter = authenticationConverter; } - private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException { + /** + * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OidcUserInfoAuthenticationToken} + * and returning the {@link OidcUserInfo UserInfo Response}. + * + * @param authenticationSuccessHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OidcUserInfoAuthenticationToken} + * @since 0.4.0 + */ + public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) { + Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null"); + this.authenticationSuccessHandler = authenticationSuccessHandler; + } + + /** + * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} + * and returning the {@link OAuth2Error Error Response}. + * + * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException} + * @since 0.4.0 + */ + public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) { + Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null"); + this.authenticationFailureHandler = authenticationFailureHandler; + } + + private Authentication createAuthentication(HttpServletRequest request) { + Authentication principal = SecurityContextHolder.getContext().getAuthentication(); + return new OidcUserInfoAuthenticationToken(principal); + } + + private void sendUserInfoResponse(HttpServletRequest request, HttpServletResponse response, + Authentication authentication) throws IOException { + OidcUserInfoAuthenticationToken userInfoAuthenticationToken = (OidcUserInfoAuthenticationToken) authentication; + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + this.userInfoHttpMessageConverter.write(userInfoAuthenticationToken.getUserInfo(), null, httpResponse); + } + + private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, + AuthenticationException authenticationException) throws IOException { + OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError(); HttpStatus httpStatus = HttpStatus.BAD_REQUEST; if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_TOKEN)) { httpStatus = HttpStatus.UNAUTHORIZED; @@ -138,4 +190,5 @@ public final class OidcUserInfoEndpointFilter extends OncePerRequestFilter { httpResponse.setStatusCode(httpStatus); this.errorHttpResponseConverter.write(error, null, httpResponse); } + } diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index 9d81d170..52e84d2b 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -18,7 +18,9 @@ package org.springframework.security.oauth2.server.authorization.web; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; import java.util.Set; import jakarta.servlet.FilterChain; @@ -287,10 +289,16 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte UriComponentsBuilder uriBuilder = UriComponentsBuilder .fromUriString(authorizationCodeRequestAuthentication.getRedirectUri()) .queryParam(OAuth2ParameterNames.CODE, authorizationCodeRequestAuthentication.getAuthorizationCode().getTokenValue()); + String redirectUri; if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) { - uriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState()); + uriBuilder.queryParam(OAuth2ParameterNames.STATE, "{state}"); + Map queryParams = new HashMap<>(); + queryParams.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState()); + redirectUri = uriBuilder.build(queryParams).toString(); + } else { + redirectUri = uriBuilder.toUriString(); } - this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString()); + this.redirectStrategy.sendRedirect(request, response, redirectUri); } private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response, @@ -317,10 +325,16 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte if (StringUtils.hasText(error.getUri())) { uriBuilder.queryParam(OAuth2ParameterNames.ERROR_URI, error.getUri()); } + String redirectUri; if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) { - uriBuilder.queryParam(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState()); + uriBuilder.queryParam(OAuth2ParameterNames.STATE, "{state}"); + Map queryParams = new HashMap<>(); + queryParams.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState()); + redirectUri = uriBuilder.build(queryParams).toString(); + } else { + redirectUri = uriBuilder.toUriString(); } - this.redirectStrategy.sendRedirect(request, response, uriBuilder.toUriString()); + this.redirectStrategy.sendRedirect(request, response, redirectUri); } /** diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java index 2d95add5..fce02bd8 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java @@ -84,7 +84,7 @@ public class TestOAuth2Authorizations { .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .authorizedScopes(authorizationRequest.getScopes()) .token(authorizationCode) - .attribute(OAuth2ParameterNames.STATE, "state") + .attribute(OAuth2ParameterNames.STATE, "consent-state") .attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest) .attribute(Principal.class.getName(), new TestingAuthenticationToken("principal", null, "ROLE_A", "ROLE_B")); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java index bd2a5e18..68d22ed3 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java @@ -70,6 +70,7 @@ import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.PkceParameterNames; @@ -160,6 +161,9 @@ public class OAuth2AuthorizationCodeGrantTests { private static final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; private static final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; private static final String AUTHORITIES_CLAIM = "authorities"; + private static final String STATE_URL_UNENCODED = "awrD0fCnEcTUPFgmyy2SU89HZNcnAJ60ZW6l39YI0KyVjmIZ+004pwm9j55li7BoydXYysH4enZMF21Q"; + private static final String STATE_URL_ENCODED = "awrD0fCnEcTUPFgmyy2SU89HZNcnAJ60ZW6l39YI0KyVjmIZ%2B004pwm9j55li7BoydXYysH4enZMF21Q"; + private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE); private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE); @@ -291,7 +295,7 @@ public class OAuth2AuthorizationCodeGrantTests { .andExpect(status().is3xxRedirection()) .andReturn(); String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); - assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state"); + assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED); String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code"); OAuth2Authorization authorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE); @@ -383,7 +387,7 @@ public class OAuth2AuthorizationCodeGrantTests { .andExpect(status().is3xxRedirection()) .andReturn(); String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); - assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state"); + assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED); String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code"); OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE); @@ -427,7 +431,7 @@ public class OAuth2AuthorizationCodeGrantTests { .andExpect(status().is3xxRedirection()) .andReturn(); String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); - assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state"); + assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED); String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code"); OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE); @@ -503,19 +507,27 @@ public class OAuth2AuthorizationCodeGrantTests { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) .principalName("user") .build(); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName()); + OAuth2AuthorizationRequest updatedAuthorizationRequest = + OAuth2AuthorizationRequest.from(authorizationRequest) + .state(STATE_URL_UNENCODED) + .build(); + authorization = OAuth2Authorization.from(authorization) + .attribute(OAuth2AuthorizationRequest.class.getName(), updatedAuthorizationRequest) + .build(); this.authorizationService.save(authorization); MvcResult mvcResult = this.mvc.perform(post(DEFAULT_AUTHORIZATION_ENDPOINT_URI) .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) .param(OAuth2ParameterNames.SCOPE, "message.read") .param(OAuth2ParameterNames.SCOPE, "message.write") - .param(OAuth2ParameterNames.STATE, "state") + .param(OAuth2ParameterNames.STATE, authorization.getAttribute(OAuth2ParameterNames.STATE)) .with(user("user"))) .andExpect(status().is3xxRedirection()) .andReturn(); String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); - assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state"); + assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED); String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code"); OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE); @@ -583,18 +595,26 @@ public class OAuth2AuthorizationCodeGrantTests { OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient) .build(); + OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName()); + OAuth2AuthorizationRequest updatedAuthorizationRequest = + OAuth2AuthorizationRequest.from(authorizationRequest) + .state(STATE_URL_UNENCODED) + .build(); + authorization = OAuth2Authorization.from(authorization) + .attribute(OAuth2AuthorizationRequest.class.getName(), updatedAuthorizationRequest) + .build(); this.authorizationService.save(authorization); MvcResult mvcResult = this.mvc.perform(post(DEFAULT_AUTHORIZATION_ENDPOINT_URI) .param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()) .param("authority", "authority-1 authority-2") - .param(OAuth2ParameterNames.STATE, "state") + .param(OAuth2ParameterNames.STATE, authorization.getAttribute(OAuth2ParameterNames.STATE)) .with(user("principal"))) .andExpect(status().is3xxRedirection()) .andReturn(); String redirectedUrl = mvcResult.getResponse().getRedirectedUrl(); - assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=state"); + assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED); String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code"); OAuth2Authorization authorizationCodeAuthorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE); @@ -632,7 +652,7 @@ public class OAuth2AuthorizationCodeGrantTests { OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = new OAuth2AuthorizationCodeRequestAuthenticationToken( "https://provider.com/oauth2/authorize", registeredClient.getClientId(), principal, authorizationCode, - registeredClient.getRedirectUris().iterator().next(), "state", registeredClient.getScopes()); + registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes()); when(authorizationRequestConverter.convert(any())).thenReturn(authorizationCodeRequestAuthenticationResult); when(authorizationRequestAuthenticationProvider.supports(eq(OAuth2AuthorizationCodeRequestAuthenticationToken.class))).thenReturn(true); when(authorizationRequestAuthenticationProvider.authenticate(any())).thenReturn(authorizationCodeRequestAuthenticationResult); @@ -718,7 +738,7 @@ public class OAuth2AuthorizationCodeGrantTests { parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next()); parameters.set(OAuth2ParameterNames.SCOPE, StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " ")); - parameters.set(OAuth2ParameterNames.STATE, "state"); + parameters.set(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED); return parameters; } diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationTests.java index e2873adb..0cb8a86b 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcClientRegistrationTests.java @@ -18,6 +18,10 @@ package org.springframework.security.oauth2.server.authorization.config.annotati import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Collections; +import java.util.List; +import java.util.function.Consumer; + +import jakarta.servlet.http.HttpServletResponse; import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.source.JWKSource; @@ -30,6 +34,7 @@ import org.junit.Before; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -38,6 +43,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.http.converter.HttpMessageConverter; +import org.springframework.http.server.ServletServerHttpResponse; import org.springframework.jdbc.core.JdbcOperations; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; @@ -46,6 +52,7 @@ import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; import org.springframework.mock.http.MockHttpOutputMessage; import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; @@ -55,6 +62,7 @@ import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; @@ -77,11 +85,18 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientConfigurationAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationProvider; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken; import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter; +import org.springframework.security.oauth2.server.authorization.oidc.web.authentication.OidcClientRegistrationAuthenticationConverter; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; import org.springframework.security.oauth2.server.authorization.settings.ClientSettings; import org.springframework.security.oauth2.server.authorization.test.SpringTestRule; import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.util.matcher.RequestMatcher; import org.springframework.test.web.servlet.MockMvc; import org.springframework.test.web.servlet.MvcResult; @@ -89,6 +104,14 @@ import org.springframework.web.util.UriComponentsBuilder; import static org.assertj.core.api.Assertions.assertThat; import static org.hamcrest.CoreMatchers.containsString; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; +import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.jwt; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header; @@ -128,6 +151,18 @@ public class OidcClientRegistrationTests { @Autowired private AuthorizationServerSettings authorizationServerSettings; + private static AuthenticationConverter authenticationConverter; + + private static Consumer> authenticationConvertersConsumer; + + private static AuthenticationProvider authenticationProvider; + + private static Consumer> authenticationProvidersConsumer; + + private static AuthenticationSuccessHandler authenticationSuccessHandler; + + private static AuthenticationFailureHandler authenticationFailureHandler; + private MockWebServer server; private String clientJwkSetUrl; @@ -145,6 +180,12 @@ public class OidcClientRegistrationTests { .addScript("org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql") .addScript("org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql") .build(); + authenticationConverter = mock(AuthenticationConverter.class); + authenticationConvertersConsumer = mock(Consumer.class); + authenticationProvider = mock(AuthenticationProvider.class); + authenticationProvidersConsumer = mock(Consumer.class); + authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class); + authenticationFailureHandler = mock(AuthenticationFailureHandler.class); } @Before @@ -158,6 +199,7 @@ public class OidcClientRegistrationTests { .setBody(clientJwkSet.toString()); // @formatter:on this.server.enqueue(response); + when(authenticationProvider.supports(OidcClientRegistrationAuthenticationToken.class)).thenReturn(true); } @After @@ -165,6 +207,12 @@ public class OidcClientRegistrationTests { this.server.shutdown(); jdbcOperations.update("truncate table oauth2_authorization"); jdbcOperations.update("truncate table oauth2_registered_client"); + reset(authenticationConverter); + reset(authenticationConvertersConsumer); + reset(authenticationProvider); + reset(authenticationProvidersConsumer); + reset(authenticationSuccessHandler); + reset(authenticationFailureHandler); } @AfterClass @@ -261,6 +309,67 @@ public class OidcClientRegistrationTests { assertThat(clientConfigurationResponse.getRegistrationAccessToken()).isNull(); } + @Test + public void requestWhenClientRegistrationEndpointCustomizedThenUsed() throws Exception { + this.spring.register(CustomClientRegistrationConfiguration.class).autowire(); + + // @formatter:off + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .scope("scope1") + .scope("scope2") + .build(); + // @formatter:on + + doAnswer(invocation -> { + HttpServletResponse response = invocation.getArgument(1, HttpServletResponse.class); + ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response); + httpResponse.setStatusCode(HttpStatus.CREATED); + new OidcClientRegistrationHttpMessageConverter().write(clientRegistration, null, httpResponse); + return null; + }).when(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any()); + + registerClient(clientRegistration); + + verify(authenticationConverter).convert(any()); + ArgumentCaptor> authenticationConvertersCaptor = + ArgumentCaptor.forClass(List.class); + verify(authenticationConvertersConsumer).accept(authenticationConvertersCaptor.capture()); + List authenticationConverters = authenticationConvertersCaptor.getValue(); + assertThat(authenticationConverters).hasSize(2) + .allMatch(converter -> converter == authenticationConverter + || converter instanceof OidcClientRegistrationAuthenticationConverter); + + verify(authenticationProvider).authenticate(any()); + ArgumentCaptor> authenticationProvidersCaptor = + ArgumentCaptor.forClass(List.class); + verify(authenticationProvidersConsumer).accept(authenticationProvidersCaptor.capture()); + List authenticationProviders = authenticationProvidersCaptor.getValue(); + assertThat(authenticationProviders).hasSize(3) + .allMatch(provider -> provider == authenticationProvider + || provider instanceof OidcClientRegistrationAuthenticationProvider + || provider instanceof OidcClientConfigurationAuthenticationProvider); + + verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any()); + verifyNoInteractions(authenticationFailureHandler); + } + + @Test + public void requestWhenClientRegistrationEndpointCustomizedWithAuthenticationFailureHandlerThenUsed() throws Exception { + this.spring.register(CustomClientRegistrationConfiguration.class).autowire(); + + when(authenticationProvider.authenticate(any())).thenThrow(new OAuth2AuthenticationException("error")); + + this.mvc.perform(get(DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI) + .param(OAuth2ParameterNames.CLIENT_ID, "invalid").with(jwt())); + + verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any()); + verifyNoInteractions(authenticationSuccessHandler); + } + private OidcClientRegistration registerClient(OidcClientRegistration clientRegistration) throws Exception { // ***** (1) Obtain the "initial" access token used for registering the client @@ -353,6 +462,44 @@ public class OidcClientRegistrationTests { return clientRegistrationHttpMessageConverter.read(OidcClientRegistration.class, httpResponse); } + @EnableWebSecurity + @Configuration(proxyBeanMethods = false) + static class CustomClientRegistrationConfiguration extends AuthorizationServerConfiguration { + + // @formatter:off + @Bean + @Override + public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception { + OAuth2AuthorizationServerConfigurer authorizationServerConfigurer = + new OAuth2AuthorizationServerConfigurer(); + authorizationServerConfigurer + .oidc(oidc -> + oidc + .clientRegistrationEndpoint(clientRegistration -> + clientRegistration + .clientRegistrationRequestConverter(authenticationConverter) + .clientRegistrationRequestConverters(authenticationConvertersConsumer) + .authenticationProvider(authenticationProvider) + .authenticationProviders(authenticationProvidersConsumer) + .clientRegistrationResponseHandler(authenticationSuccessHandler) + .errorResponseHandler(authenticationFailureHandler) + ) + ); + RequestMatcher endpointsMatcher = authorizationServerConfigurer.getEndpointsMatcher(); + + http + .securityMatcher(endpointsMatcher) + .authorizeHttpRequests(authorize -> + authorize.anyRequest().authenticated() + ) + .csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher)) + .oauth2ResourceServer(OAuth2ResourceServerConfigurer::jwt) + .apply(authorizationServerConfigurer); + return http.build(); + } + // @formatter:on + } + @EnableWebSecurity @Configuration(proxyBeanMethods = false) static class AuthorizationServerConfiguration { diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcUserInfoTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcUserInfoTests.java index 083f40dc..c91ebc18 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcUserInfoTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcUserInfoTests.java @@ -19,9 +19,13 @@ import java.time.Instant; import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Set; +import java.util.function.Consumer; import java.util.function.Function; +import jakarta.servlet.http.HttpServletResponse; + import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.source.ImmutableJWKSet; import com.nimbusds.jose.jwk.source.JWKSource; @@ -30,11 +34,14 @@ import org.junit.Before; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.config.Customizer; import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; @@ -62,11 +69,15 @@ import org.springframework.security.oauth2.server.authorization.client.Registere import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; import org.springframework.security.oauth2.server.authorization.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration; import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationContext; +import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationProvider; import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken; import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings; import org.springframework.security.oauth2.server.authorization.test.SpringTestRule; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; import org.springframework.security.web.SecurityFilterChain; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.context.HttpSessionSecurityContextRepository; import org.springframework.security.web.context.SecurityContextRepository; import org.springframework.security.web.util.matcher.RequestMatcher; @@ -75,8 +86,15 @@ import org.springframework.test.web.servlet.MvcResult; import org.springframework.test.web.servlet.ResultMatcher; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get; import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; @@ -100,17 +118,48 @@ public class OidcUserInfoTests { @Autowired private JwtEncoder jwtEncoder; + @Autowired + private JwtDecoder jwtDecoder; + @Autowired private OAuth2AuthorizationService authorizationService; + private static AuthenticationConverter authenticationConverter; + + private static Consumer> authenticationConvertersConsumer; + + private static AuthenticationProvider authenticationProvider; + + private static Consumer> authenticationProvidersConsumer; + + private static AuthenticationSuccessHandler authenticationSuccessHandler; + + private static AuthenticationFailureHandler authenticationFailureHandler; + + private static Function userInfoMapper; + @BeforeClass public static void init() { securityContextRepository = spy(new HttpSessionSecurityContextRepository()); + authenticationConverter = mock(AuthenticationConverter.class); + authenticationConvertersConsumer = mock(Consumer.class); + authenticationProvider = mock(AuthenticationProvider.class); + authenticationProvidersConsumer = mock(Consumer.class); + authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class); + authenticationFailureHandler = mock(AuthenticationFailureHandler.class); + userInfoMapper = mock(Function.class); } @Before public void setup() { reset(securityContextRepository); + reset(authenticationConverter); + reset(authenticationConvertersConsumer); + reset(authenticationProvider); + reset(authenticationProvidersConsumer); + reset(authenticationSuccessHandler); + reset(authenticationFailureHandler); + reset(userInfoMapper); } @Test @@ -146,19 +195,91 @@ public class OidcUserInfoTests { } @Test - public void requestWhenSignedJwtAndCustomUserInfoMapperThenMapJwtClaimsToUserInfoResponse() throws Exception { + public void requestWhenUserInfoEndpointCustomizedThenUsed() throws Exception { this.spring.register(CustomUserInfoConfiguration.class).autowire(); OAuth2Authorization authorization = createAuthorization(); this.authorizationService.save(authorization); + when(userInfoMapper.apply(any())).thenReturn(createUserInfo()); + OAuth2AccessToken accessToken = authorization.getAccessToken().getToken(); // @formatter:off this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI) .header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue())) - .andExpect(status().is2xxSuccessful()) - .andExpectAll(userInfoResponse()); + .andExpect(status().is2xxSuccessful()); // @formatter:on + + verify(userInfoMapper).apply(any()); + verify(authenticationConverter).convert(any()); + verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any()); + verifyNoInteractions(authenticationFailureHandler); + + ArgumentCaptor> authenticationProvidersCaptor = ArgumentCaptor.forClass(List.class); + verify(authenticationProvidersConsumer).accept(authenticationProvidersCaptor.capture()); + List authenticationProviders = authenticationProvidersCaptor.getValue(); + assertThat(authenticationProviders).hasSize(2).allMatch(provider -> + provider == authenticationProvider || + provider instanceof OidcUserInfoAuthenticationProvider + ); + + ArgumentCaptor> authenticationConvertersCaptor = ArgumentCaptor.forClass(List.class); + verify(authenticationConvertersConsumer).accept(authenticationConvertersCaptor.capture()); + List authenticationConverters = authenticationConvertersCaptor.getValue(); + assertThat(authenticationConverters).hasSize(2).allMatch(AuthenticationConverter.class::isInstance); + } + + @Test + public void requestWhenUserInfoEndpointCustomizedWithAuthenticationProviderThenUsed() throws Exception { + this.spring.register(CustomUserInfoConfiguration.class).autowire(); + + OAuth2Authorization authorization = createAuthorization(); + this.authorizationService.save(authorization); + + when(authenticationProvider.supports(eq(OidcUserInfoAuthenticationToken.class))).thenReturn(true); + String tokenValue = authorization.getAccessToken().getToken().getTokenValue(); + Jwt jwt = this.jwtDecoder.decode(tokenValue); + OidcUserInfoAuthenticationToken oidcUserInfoAuthentication = new OidcUserInfoAuthenticationToken( + new JwtAuthenticationToken(jwt), createUserInfo()); + when(authenticationProvider.authenticate(any())).thenReturn(oidcUserInfoAuthentication); + + OAuth2AccessToken accessToken = authorization.getAccessToken().getToken(); + // @formatter:off + this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI) + .header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue())) + .andExpect(status().is2xxSuccessful()); + // @formatter:on + + verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any()); + verify(authenticationProvider).authenticate(any()); + verifyNoInteractions(authenticationFailureHandler); + verifyNoInteractions(userInfoMapper); + } + + @Test + public void requestWhenUserInfoEndpointCustomizedWithAuthenticationFailureHandlerThenUsed() throws Exception { + this.spring.register(CustomUserInfoConfiguration.class).autowire(); + + when(userInfoMapper.apply(any())).thenReturn(createUserInfo()); + doAnswer( + invocation -> { + HttpServletResponse response = invocation.getArgument(1); + response.setStatus(HttpStatus.UNAUTHORIZED.value()); + response.getWriter().write("unauthorized"); + return null; + } + ).when(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any()); + + OAuth2AccessToken accessToken = createAuthorization().getAccessToken().getToken(); + // @formatter:off + this.mvc.perform(get(DEFAULT_OIDC_USER_INFO_ENDPOINT_URI) + .header(HttpHeaders.AUTHORIZATION, "Bearer " + accessToken.getTokenValue())) + .andExpect(status().is4xxClientError()); + // @formatter:on + + verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any()); + verifyNoInteractions(authenticationSuccessHandler); + verifyNoInteractions(userInfoMapper); } // gh-482 @@ -273,14 +394,6 @@ public class OidcUserInfoTests { RequestMatcher endpointsMatcher = authorizationServerConfigurer .getEndpointsMatcher(); - // Custom User Info Mapper that retrieves claims from a signed JWT - Function userInfoMapper = context -> { - OidcUserInfoAuthenticationToken authentication = context.getAuthentication(); - JwtAuthenticationToken principal = (JwtAuthenticationToken) authentication.getPrincipal(); - - return new OidcUserInfo(principal.getToken().getClaims()); - }; - // @formatter:off http .securityMatcher(endpointsMatcher) @@ -292,6 +405,12 @@ public class OidcUserInfoTests { .apply(authorizationServerConfigurer) .oidc(oidc -> oidc .userInfoEndpoint(userInfo -> userInfo + .userInfoRequestConverter(authenticationConverter) + .userInfoRequestConverters(authenticationConvertersConsumer) + .authenticationProvider(authenticationProvider) + .authenticationProviders(authenticationProvidersConsumer) + .userInfoResponseHandler(authenticationSuccessHandler) + .errorResponseHandler(authenticationFailureHandler) .userInfoMapper(userInfoMapper) ) ); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java index e471bd8a..e5b6210b 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java @@ -134,6 +134,13 @@ public class OidcClientRegistrationAuthenticationProviderTests { .withMessage("tokenGenerator cannot be null"); } + @Test + public void setRegisteredClientConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authenticationProvider.setRegisteredClientConverter(null)) + .withMessage("registeredClientConverter cannot be null"); + } + @Test public void supportsWhenTypeOidcClientRegistrationAuthenticationTokenThenReturnTrue() { assertThat(this.authenticationProvider.supports(OidcClientRegistrationAuthenticationToken.class)).isTrue(); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java index 144879fc..f6ff8f54 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java @@ -15,10 +15,12 @@ */ package org.springframework.security.oauth2.server.authorization.oidc.web; +import java.io.IOException; import java.time.Instant; import java.util.Collections; import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; @@ -33,6 +35,8 @@ import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AuthenticationManager; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.Authentication; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; @@ -54,10 +58,14 @@ import org.springframework.security.oauth2.server.authorization.oidc.OidcClientR import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcClientRegistrationAuthenticationToken; import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -68,6 +76,7 @@ import static org.mockito.Mockito.when; * * @author Ovidiu Popa * @author Joe Grandja + * @author Daniel Garnier-Moiroux */ public class OidcClientRegistrationEndpointFilterTests { private static final String DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI = "/connect/register"; @@ -103,6 +112,27 @@ public class OidcClientRegistrationEndpointFilterTests { .withMessage("clientRegistrationEndpointUri cannot be empty"); } + @Test + public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationConverter(null)) + .withMessage("authenticationConverter cannot be null"); + } + + @Test + public void setAuthenticationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null)) + .withMessage("authenticationSuccessHandler cannot be null"); + } + + @Test + public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null)) + .withMessage("authenticationFailureHandler cannot be null"); + } + @Test public void doFilterWhenNotClientRegistrationRequestThenNotProcessed() throws Exception { String requestUri = "/path"; @@ -203,25 +233,13 @@ public class OidcClientRegistrationEndpointFilterTests { @Test public void doFilterWhenClientRegistrationRequestValidThenSuccessResponse() throws Exception { // @formatter:off - OidcClientRegistration.Builder clientRegistrationBuilder = OidcClientRegistration.builder() - .clientName("client-name") - .redirectUri("https://client.example.com") - .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) - .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) - .scope("scope1") - .scope("scope2"); + OidcClientRegistration expectedClientRegistrationResponse = createClientRegistration(); - OidcClientRegistration clientRegistrationRequest = clientRegistrationBuilder.build(); - - OidcClientRegistration expectedClientRegistrationResponse = clientRegistrationBuilder - .clientId("client-id") - .clientIdIssuedAt(Instant.now()) - .clientSecret("client-secret") - .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()) - .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) - .idTokenSignedResponseAlgorithm(SignatureAlgorithm.RS256.getName()) - .registrationAccessToken("registration-access-token") - .registrationClientUrl("https://auth-server:9000/connect/register?client_id=client-id") + OidcClientRegistration clientRegistrationRequest = OidcClientRegistration.builder() + .clientName(expectedClientRegistrationResponse.getClientName()) + .redirectUris(redirectUris -> redirectUris.addAll(expectedClientRegistrationResponse.getRedirectUris())) + .grantTypes(grantTypes -> grantTypes.addAll(expectedClientRegistrationResponse.getGrantTypes())) + .scopes(scopes -> scopes.addAll(expectedClientRegistrationResponse.getScopes())) .build(); // @formatter:on @@ -384,23 +402,7 @@ public class OidcClientRegistrationEndpointFilterTests { @Test public void doFilterWhenClientConfigurationRequestValidThenSuccessResponse() throws Exception { - // @formatter:off - OidcClientRegistration expectedClientRegistrationResponse = OidcClientRegistration.builder() - .clientId("client-id") - .clientIdIssuedAt(Instant.now()) - .clientSecret("client-secret") - .clientName("client-name") - .redirectUri("https://client.example.com") - .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) - .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) - .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()) - .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) - .idTokenSignedResponseAlgorithm(SignatureAlgorithm.RS256.getName()) - .scope("scope1") - .scope("scope2") - .registrationClientUrl("https://auth-server:9000/connect/register?client_id=client-id") - .build(); - // @formatter:on + OidcClientRegistration expectedClientRegistrationResponse = createClientRegistration(); Jwt jwt = createJwt("client.read"); JwtAuthenticationToken principal = new JwtAuthenticationToken( @@ -452,6 +454,74 @@ public class OidcClientRegistrationEndpointFilterTests { .isEqualTo(expectedClientRegistrationResponse.getRegistrationClientUrl()); } + @Test + public void doFilterWhenCustomAuthenticationConverterThenUsed() throws ServletException, IOException { + AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class); + this.filter.setAuthenticationConverter(authenticationConverter); + + String requestUri = DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client-id"); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(authenticationConverter).convert(request); + } + + @Test + public void doFilterWhenCustomAuthenticationSuccessHandlerThenUsed() throws Exception { + OidcClientRegistration expectedClientRegistrationResponse = createClientRegistration(); + Authentication principal = new TestingAuthenticationToken("principal", "Credentials"); + + OidcClientRegistrationAuthenticationToken clientRegistrationAuthenticationResult = + new OidcClientRegistrationAuthenticationToken(principal, expectedClientRegistrationResponse); + + when(this.authenticationManager.authenticate(any())).thenReturn(clientRegistrationAuthenticationResult); + AuthenticationSuccessHandler successHandler = mock(AuthenticationSuccessHandler.class); + this.filter.setAuthenticationSuccessHandler(successHandler); + + SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); + securityContext.setAuthentication(principal); + SecurityContextHolder.setContext(securityContext); + + String requestUri = DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, expectedClientRegistrationResponse.getClientId()); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(successHandler).onAuthenticationSuccess(request, response, clientRegistrationAuthenticationResult); + } + + @Test + public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exception { + AuthenticationFailureHandler authenticationFailureHandler = mock(AuthenticationFailureHandler.class); + this.filter.setAuthenticationFailureHandler(authenticationFailureHandler); + + when(this.authenticationManager.authenticate(any())) + .thenThrow(new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN)); + + String requestUri = DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client1"); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(authenticationFailureHandler).onAuthenticationFailure(eq(request), eq(response), + any(OAuth2AuthenticationException.class)); + } + private OAuth2Error readError(MockHttpServletResponse response) throws Exception { MockClientHttpResponse httpResponse = new MockClientHttpResponse( response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus())); @@ -471,6 +541,27 @@ public class OidcClientRegistrationEndpointFilterTests { return this.clientRegistrationHttpMessageConverter.read(OidcClientRegistration.class, httpResponse); } + private static OidcClientRegistration createClientRegistration() { + // @formatter:off + OidcClientRegistration clientRegistration = OidcClientRegistration.builder() + .clientId("client-id") + .clientIdIssuedAt(Instant.now()) + .clientSecret("client-secret") + .clientName("client-name") + .redirectUri("https://client.example.com") + .grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue()) + .grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()) + .tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue()) + .responseType(OAuth2AuthorizationResponseType.CODE.getValue()) + .idTokenSignedResponseAlgorithm(SignatureAlgorithm.RS256.getName()) + .scope("scope1") + .scope("scope2") + .registrationClientUrl("https://auth-server:9000/connect/register?client_id=client-id") + .build(); + return clientRegistration; + // @formatter:on + } + private static Jwt createJwt(String scope) { // @formatter:off JwsHeader jwsHeader = TestJwsHeaders.jwsHeader() diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilterTests.java index ccfa9179..f7da0545 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcUserInfoEndpointFilterTests.java @@ -44,6 +44,9 @@ import org.springframework.security.oauth2.jwt.JoseHeaderNames; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.server.authorization.oidc.authentication.OidcUserInfoAuthenticationToken; import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken; +import org.springframework.security.web.authentication.AuthenticationConverter; +import org.springframework.security.web.authentication.AuthenticationFailureHandler; +import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -84,6 +87,27 @@ public class OidcUserInfoEndpointFilterTests { .withMessage("userInfoEndpointUri cannot be empty"); } + @Test + public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationConverter(null)) + .withMessage("authenticationConverter cannot be null"); + } + + @Test + public void setAuthenticationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null)) + .withMessage("authenticationSuccessHandler cannot be null"); + } + + @Test + public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null)) + .withMessage("authenticationFailureHandler cannot be null"); + } + @Test public void doFilterWhenNotUserInfoRequestThenNotProcessed() throws Exception { String requestUri = "/path"; @@ -145,11 +169,21 @@ public class OidcUserInfoEndpointFilterTests { @Test public void doFilterWhenUserInfoRequestInvalidTokenThenUnauthorizedError() throws Exception { + doFilterWhenAuthenticationExceptionThenError(OAuth2ErrorCodes.INVALID_TOKEN, HttpStatus.UNAUTHORIZED); + } + + @Test + public void doFilterWhenUserInfoRequestInsufficientScopeThenForbiddenError() throws Exception { + doFilterWhenAuthenticationExceptionThenError(OAuth2ErrorCodes.INSUFFICIENT_SCOPE, HttpStatus.FORBIDDEN); + } + + private void doFilterWhenAuthenticationExceptionThenError(String oauth2ErrorCode, HttpStatus httpStatus) + throws Exception { Authentication principal = new TestingAuthenticationToken("principal", "credentials"); SecurityContextHolder.getContext().setAuthentication(principal); when(this.authenticationManager.authenticate(any())) - .thenThrow(new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN)); + .thenThrow(new OAuth2AuthenticationException(oauth2ErrorCode)); String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); @@ -161,9 +195,82 @@ public class OidcUserInfoEndpointFilterTests { verifyNoInteractions(filterChain); - assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value()); + assertThat(response.getStatus()).isEqualTo(httpStatus.value()); OAuth2Error error = readError(response); - assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_TOKEN); + assertThat(error.getErrorCode()).isEqualTo(oauth2ErrorCode); + } + + @Test + public void doFilterWhenCustomAuthenticationConverterThenUsed() throws Exception { + Authentication principal = new TestingAuthenticationToken("principal", "credentials"); + OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal); + AuthenticationConverter authenticationConverter = mock(AuthenticationConverter.class); + this.filter.setAuthenticationConverter(authenticationConverter); + + when(authenticationConverter.convert(any())).thenReturn(authentication); + when(this.authenticationManager.authenticate(any())).thenReturn( + new OidcUserInfoAuthenticationToken(principal, createUserInfo()) + ); + + String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + verify(authenticationConverter).convert(request); + verify(this.authenticationManager).authenticate(authentication); + assertUserInfoResponse(response.getContentAsString()); + } + + @Test + public void doFilterWhenCustomAuthenticationSuccessHandlerThenUsed() throws Exception { + AuthenticationSuccessHandler successHandler = mock(AuthenticationSuccessHandler.class); + this.filter.setAuthenticationSuccessHandler(successHandler); + + Authentication principal = new TestingAuthenticationToken("principal", "credentials"); + SecurityContextHolder.getContext().setAuthentication(principal); + + OidcUserInfoAuthenticationToken authentication = new OidcUserInfoAuthenticationToken(principal, createUserInfo()); + when(this.authenticationManager.authenticate(any())).thenReturn(authentication); + + String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + verify(successHandler).onAuthenticationSuccess(request, response, authentication); + } + + @Test + public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exception { + AuthenticationFailureHandler failureHandler = mock(AuthenticationFailureHandler.class); + this.filter.setAuthenticationFailureHandler(failureHandler); + + Authentication principal = new TestingAuthenticationToken("principal", "credentials"); + SecurityContextHolder.getContext().setAuthentication(principal); + + OAuth2AuthenticationException authenticationException = + new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); + when(this.authenticationManager.authenticate(any())).thenThrow(authenticationException); + + String requestUri = DEFAULT_OIDC_USER_INFO_ENDPOINT_URI; + MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); + request.setServletPath(requestUri); + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verifyNoInteractions(filterChain); + verify(failureHandler).onAuthenticationFailure(request, response, authenticationException); } private OAuth2Error readError(MockHttpServletResponse response) throws Exception {