diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java index 3efb79fd..3c1232b0 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java @@ -1,5 +1,5 @@ /* - * Copyright 2020-2021 the original author or authors. + * Copyright 2020-2022 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; +import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; @@ -45,6 +46,7 @@ import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.security.web.authentication.WebAuthenticationDetailsSource; import org.springframework.security.web.util.RedirectUrlBuilder; import org.springframework.security.web.util.UrlUtils; import org.springframework.security.web.util.matcher.AndRequestMatcher; @@ -82,6 +84,7 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte private final AuthenticationManager authenticationManager; private final RequestMatcher authorizationEndpointMatcher; private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); + private AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); private AuthenticationConverter authenticationConverter; private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAuthorizationResponse; private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse; @@ -144,6 +147,7 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte try { OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = (OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationConverter.convert(request); + authorizationCodeRequestAuthentication.setDetails(this.authenticationDetailsSource.buildDetails(request)); OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult = (OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationManager.authenticate(authorizationCodeRequestAuthentication); @@ -169,6 +173,17 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte } } + /** + * Sets the {@link AuthenticationDetailsSource} used for building an authentication details instance from {@link HttpServletRequest}. + * + * @param authenticationDetailsSource the {@link AuthenticationDetailsSource} used for building an authentication details instance from {@link HttpServletRequest} + * @since 0.3.1 + */ + public void setAuthenticationDetailsSource(AuthenticationDetailsSource authenticationDetailsSource) { + Assert.notNull(authenticationDetailsSource, "authenticationDetailsSource cannot be null"); + this.authenticationDetailsSource = authenticationDetailsSource; + } + /** * Sets the {@link AuthenticationConverter} used when attempting to extract an Authorization Request (or Consent) from {@link HttpServletRequest} * to an instance of {@link OAuth2AuthorizationCodeRequestAuthenticationToken} used for authenticating the request. diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java index 232d6e23..c2135f5c 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java @@ -32,10 +32,12 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; import org.springframework.http.HttpStatus; import org.springframework.http.MediaType; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.authentication.AuthenticationDetailsSource; import org.springframework.security.authentication.AuthenticationManager; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.Authentication; @@ -55,10 +57,12 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis import org.springframework.security.web.authentication.AuthenticationConverter; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; +import org.springframework.security.web.authentication.WebAuthenticationDetails; import org.springframework.util.StringUtils; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.InstanceOfAssertFactories.type; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.mock; @@ -78,6 +82,7 @@ import static org.mockito.Mockito.when; */ public class OAuth2AuthorizationEndpointFilterTests { private static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize"; + private static final String REMOTE_ADDRESS = "remote-address"; private AuthenticationManager authenticationManager; private OAuth2AuthorizationEndpointFilter filter; private TestingAuthenticationToken principal; @@ -116,6 +121,13 @@ public class OAuth2AuthorizationEndpointFilterTests { .hasMessage("authorizationEndpointUri cannot be empty"); } + @Test + public void setAuthenticationDetailsSourceWhenNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.filter.setAuthenticationDetailsSource(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authenticationDetailsSource cannot be null"); + } + @Test public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() { assertThatThrownBy(() -> this.filter.setAuthenticationConverter(null)) @@ -364,6 +376,32 @@ public class OAuth2AuthorizationEndpointFilterTests { verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), same(authenticationException)); } + @Test + public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed() throws Exception { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build(); + OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication = + authorizationCodeRequestAuthentication(registeredClient, this.principal).build(); + MockHttpServletRequest request = createAuthorizationRequest(registeredClient); + + AuthenticationDetailsSource authenticationDetailsSource = + mock(AuthenticationDetailsSource.class); + WebAuthenticationDetails webAuthenticationDetails = new WebAuthenticationDetails(request); + when(authenticationDetailsSource.buildDetails(request)).thenReturn(webAuthenticationDetails); + this.filter.setAuthenticationDetailsSource(authenticationDetailsSource); + + when(this.authenticationManager.authenticate(any())) + .thenReturn(authorizationCodeRequestAuthentication); + + MockHttpServletResponse response = new MockHttpServletResponse(); + FilterChain filterChain = mock(FilterChain.class); + + this.filter.doFilter(request, response, filterChain); + + verify(authenticationDetailsSource).buildDetails(any()); + verify(this.authenticationManager).authenticate(any()); + verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + @Test public void doFilterWhenAuthorizationRequestPrincipalNotAuthenticatedThenCommenceAuthentication() throws Exception { this.principal.setAuthenticated(false); @@ -507,9 +545,15 @@ public class OAuth2AuthorizationEndpointFilterTests { this.filter.doFilter(request, response, filterChain); - verify(this.authenticationManager).authenticate(any()); + ArgumentCaptor authorizationCodeRequestAuthenticationCaptor = + ArgumentCaptor.forClass(OAuth2AuthorizationCodeRequestAuthenticationToken.class); + verify(this.authenticationManager).authenticate(authorizationCodeRequestAuthenticationCaptor.capture()); verifyNoInteractions(filterChain); + assertThat(authorizationCodeRequestAuthenticationCaptor.getValue().getDetails()) + .asInstanceOf(type(WebAuthenticationDetails.class)) + .extracting(WebAuthenticationDetails::getRemoteAddress) + .isEqualTo(REMOTE_ADDRESS); assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value()); assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=state"); } @@ -578,6 +622,7 @@ public class OAuth2AuthorizationEndpointFilterTests { String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI; MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); request.setServletPath(requestUri); + request.setRemoteAddr(REMOTE_ADDRESS); request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue()); request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); @@ -593,6 +638,7 @@ public class OAuth2AuthorizationEndpointFilterTests { String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI; MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri); request.setServletPath(requestUri); + request.setRemoteAddr(REMOTE_ADDRESS); request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId()); registeredClient.getScopes().forEach((scope) -> request.addParameter(OAuth2ParameterNames.SCOPE, scope));