Add authenticationDetailsSource to AuthorizationEndpointFilter
Closes gh-768
This commit is contained in:
committed by
Joe Grandja
parent
fdf0a2f94c
commit
ec7ab5c956
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2020-2021 the original author or authors.
|
||||
* Copyright 2020-2022 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -28,6 +28,7 @@ import javax.servlet.http.HttpServletResponse;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.security.authentication.AuthenticationDetailsSource;
|
||||
import org.springframework.security.authentication.AuthenticationManager;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.AuthenticationException;
|
||||
@@ -45,6 +46,7 @@ import org.springframework.security.web.RedirectStrategy;
|
||||
import org.springframework.security.web.authentication.AuthenticationConverter;
|
||||
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
||||
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
|
||||
import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
|
||||
import org.springframework.security.web.util.RedirectUrlBuilder;
|
||||
import org.springframework.security.web.util.UrlUtils;
|
||||
import org.springframework.security.web.util.matcher.AndRequestMatcher;
|
||||
@@ -82,6 +84,7 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
|
||||
private final AuthenticationManager authenticationManager;
|
||||
private final RequestMatcher authorizationEndpointMatcher;
|
||||
private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
|
||||
private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
|
||||
private AuthenticationConverter authenticationConverter;
|
||||
private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAuthorizationResponse;
|
||||
private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
|
||||
@@ -144,6 +147,7 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
|
||||
try {
|
||||
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
|
||||
(OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationConverter.convert(request);
|
||||
authorizationCodeRequestAuthentication.setDetails(this.authenticationDetailsSource.buildDetails(request));
|
||||
|
||||
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult =
|
||||
(OAuth2AuthorizationCodeRequestAuthenticationToken) this.authenticationManager.authenticate(authorizationCodeRequestAuthentication);
|
||||
@@ -169,6 +173,17 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link AuthenticationDetailsSource} used for building an authentication details instance from {@link HttpServletRequest}.
|
||||
*
|
||||
* @param authenticationDetailsSource the {@link AuthenticationDetailsSource} used for building an authentication details instance from {@link HttpServletRequest}
|
||||
* @since 0.3.1
|
||||
*/
|
||||
public void setAuthenticationDetailsSource(AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource) {
|
||||
Assert.notNull(authenticationDetailsSource, "authenticationDetailsSource cannot be null");
|
||||
this.authenticationDetailsSource = authenticationDetailsSource;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets the {@link AuthenticationConverter} used when attempting to extract an Authorization Request (or Consent) from {@link HttpServletRequest}
|
||||
* to an instance of {@link OAuth2AuthorizationCodeRequestAuthenticationToken} used for authenticating the request.
|
||||
|
||||
@@ -32,10 +32,12 @@ import org.junit.After;
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import org.mockito.ArgumentCaptor;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.mock.web.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.MockHttpServletResponse;
|
||||
import org.springframework.security.authentication.AuthenticationDetailsSource;
|
||||
import org.springframework.security.authentication.AuthenticationManager;
|
||||
import org.springframework.security.authentication.TestingAuthenticationToken;
|
||||
import org.springframework.security.core.Authentication;
|
||||
@@ -55,10 +57,12 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis
|
||||
import org.springframework.security.web.authentication.AuthenticationConverter;
|
||||
import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
||||
import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
|
||||
import org.springframework.security.web.authentication.WebAuthenticationDetails;
|
||||
import org.springframework.util.StringUtils;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
import static org.assertj.core.api.InstanceOfAssertFactories.type;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.same;
|
||||
import static org.mockito.Mockito.mock;
|
||||
@@ -78,6 +82,7 @@ import static org.mockito.Mockito.when;
|
||||
*/
|
||||
public class OAuth2AuthorizationEndpointFilterTests {
|
||||
private static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize";
|
||||
private static final String REMOTE_ADDRESS = "remote-address";
|
||||
private AuthenticationManager authenticationManager;
|
||||
private OAuth2AuthorizationEndpointFilter filter;
|
||||
private TestingAuthenticationToken principal;
|
||||
@@ -116,6 +121,13 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
||||
.hasMessage("authorizationEndpointUri cannot be empty");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setAuthenticationDetailsSourceWhenNullThenThrowIllegalArgumentException() {
|
||||
assertThatThrownBy(() -> this.filter.setAuthenticationDetailsSource(null))
|
||||
.isInstanceOf(IllegalArgumentException.class)
|
||||
.hasMessage("authenticationDetailsSource cannot be null");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void setAuthenticationConverterWhenNullThenThrowIllegalArgumentException() {
|
||||
assertThatThrownBy(() -> this.filter.setAuthenticationConverter(null))
|
||||
@@ -364,6 +376,32 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
||||
verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), same(authenticationException));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenCustomAuthenticationDetailsSourceThenUsed() throws Exception {
|
||||
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
|
||||
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
|
||||
authorizationCodeRequestAuthentication(registeredClient, this.principal).build();
|
||||
MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
|
||||
|
||||
AuthenticationDetailsSource<HttpServletRequest, WebAuthenticationDetails> authenticationDetailsSource =
|
||||
mock(AuthenticationDetailsSource.class);
|
||||
WebAuthenticationDetails webAuthenticationDetails = new WebAuthenticationDetails(request);
|
||||
when(authenticationDetailsSource.buildDetails(request)).thenReturn(webAuthenticationDetails);
|
||||
this.filter.setAuthenticationDetailsSource(authenticationDetailsSource);
|
||||
|
||||
when(this.authenticationManager.authenticate(any()))
|
||||
.thenReturn(authorizationCodeRequestAuthentication);
|
||||
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(authenticationDetailsSource).buildDetails(any());
|
||||
verify(this.authenticationManager).authenticate(any());
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenAuthorizationRequestPrincipalNotAuthenticatedThenCommenceAuthentication() throws Exception {
|
||||
this.principal.setAuthenticated(false);
|
||||
@@ -507,9 +545,15 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(this.authenticationManager).authenticate(any());
|
||||
ArgumentCaptor<OAuth2AuthorizationCodeRequestAuthenticationToken> authorizationCodeRequestAuthenticationCaptor =
|
||||
ArgumentCaptor.forClass(OAuth2AuthorizationCodeRequestAuthenticationToken.class);
|
||||
verify(this.authenticationManager).authenticate(authorizationCodeRequestAuthenticationCaptor.capture());
|
||||
verifyNoInteractions(filterChain);
|
||||
|
||||
assertThat(authorizationCodeRequestAuthenticationCaptor.getValue().getDetails())
|
||||
.asInstanceOf(type(WebAuthenticationDetails.class))
|
||||
.extracting(WebAuthenticationDetails::getRemoteAddress)
|
||||
.isEqualTo(REMOTE_ADDRESS);
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
|
||||
assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=state");
|
||||
}
|
||||
@@ -578,6 +622,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
||||
String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
request.setRemoteAddr(REMOTE_ADDRESS);
|
||||
|
||||
request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
|
||||
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
|
||||
@@ -593,6 +638,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
|
||||
String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
|
||||
request.setServletPath(requestUri);
|
||||
request.setRemoteAddr(REMOTE_ADDRESS);
|
||||
|
||||
request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
|
||||
registeredClient.getScopes().forEach((scope) -> request.addParameter(OAuth2ParameterNames.SCOPE, scope));
|
||||
|
||||
Reference in New Issue
Block a user