client_id authentication parameter must have printable ASCII characters
Closes gh-889
This commit is contained in:
@@ -118,6 +118,7 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
|
||||
this.authenticationDetailsSource.buildDetails(request));
|
||||
}
|
||||
if (authenticationRequest != null) {
|
||||
validateClientIdentifier(authenticationRequest);
|
||||
Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest);
|
||||
this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult);
|
||||
}
|
||||
@@ -201,4 +202,25 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
|
||||
this.errorHttpResponseConverter.write(errorResponse, null, httpResponse);
|
||||
}
|
||||
|
||||
private static void validateClientIdentifier(Authentication authentication) {
|
||||
if (!(authentication instanceof OAuth2ClientAuthenticationToken)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// As per spec, in Appendix A.1.
|
||||
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-07#appendix-A.1
|
||||
// The syntax for client_id is *VSCHAR (%x20-7E):
|
||||
// -> Hex 20 -> ASCII 32 -> space
|
||||
// -> Hex 7E -> ASCII 126 -> tilde
|
||||
|
||||
OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication;
|
||||
String clientId = (String) clientAuthentication.getPrincipal();
|
||||
for (int i = 0; i < clientId.length(); i++) {
|
||||
char charAt = clientId.charAt(i);
|
||||
if (!(charAt >= 32 && charAt <= 126)) {
|
||||
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
*/
|
||||
package org.springframework.security.oauth2.server.authorization.web;
|
||||
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
import javax.servlet.FilterChain;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
@@ -33,6 +35,7 @@ import org.springframework.mock.web.MockHttpServletResponse;
|
||||
import org.springframework.security.authentication.AuthenticationManager;
|
||||
import org.springframework.security.core.Authentication;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.security.crypto.codec.Hex;
|
||||
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
|
||||
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
|
||||
import org.springframework.security.oauth2.core.OAuth2Error;
|
||||
@@ -130,6 +133,7 @@ public class OAuth2ClientAuthenticationFilterTests {
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
verifyNoInteractions(this.authenticationConverter);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -142,6 +146,7 @@ public class OAuth2ClientAuthenticationFilterTests {
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
|
||||
verifyNoInteractions(this.authenticationManager);
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -164,6 +169,46 @@ public class OAuth2ClientAuthenticationFilterTests {
|
||||
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
|
||||
}
|
||||
|
||||
// gh-889
|
||||
@Test
|
||||
public void doFilterWhenRequestMatchesAndClientIdContainsNonPrintableASCIIThenInvalidRequestError() throws Exception {
|
||||
// Hex 00 -> null
|
||||
String clientId = new String(Hex.decode("00"), StandardCharsets.UTF_8);
|
||||
assertWhenInvalidClientIdThenInvalidRequestError(clientId);
|
||||
|
||||
// Hex 0a61 -> line feed + a
|
||||
clientId = new String(Hex.decode("0a61"), StandardCharsets.UTF_8);
|
||||
assertWhenInvalidClientIdThenInvalidRequestError(clientId);
|
||||
|
||||
// Hex 1b -> escape
|
||||
clientId = new String(Hex.decode("1b"), StandardCharsets.UTF_8);
|
||||
assertWhenInvalidClientIdThenInvalidRequestError(clientId);
|
||||
|
||||
// Hex 1b61 -> escape + a
|
||||
clientId = new String(Hex.decode("1b61"), StandardCharsets.UTF_8);
|
||||
assertWhenInvalidClientIdThenInvalidRequestError(clientId);
|
||||
}
|
||||
|
||||
private void assertWhenInvalidClientIdThenInvalidRequestError(String clientId) throws Exception {
|
||||
when(this.authenticationConverter.convert(any(HttpServletRequest.class))).thenReturn(
|
||||
new OAuth2ClientAuthenticationToken(clientId, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, "secret", null));
|
||||
|
||||
MockHttpServletRequest request = new MockHttpServletRequest("POST", this.filterProcessesUrl);
|
||||
request.setServletPath(this.filterProcessesUrl);
|
||||
MockHttpServletResponse response = new MockHttpServletResponse();
|
||||
FilterChain filterChain = mock(FilterChain.class);
|
||||
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verifyNoInteractions(filterChain);
|
||||
verifyNoInteractions(this.authenticationManager);
|
||||
|
||||
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
|
||||
OAuth2Error error = readError(response);
|
||||
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void doFilterWhenRequestMatchesAndBadCredentialsThenInvalidClientError() throws Exception {
|
||||
when(this.authenticationConverter.convert(any(HttpServletRequest.class))).thenReturn(
|
||||
@@ -179,6 +224,7 @@ public class OAuth2ClientAuthenticationFilterTests {
|
||||
this.filter.doFilter(request, response, filterChain);
|
||||
|
||||
verifyNoInteractions(filterChain);
|
||||
verify(this.authenticationManager).authenticate(any());
|
||||
|
||||
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
|
||||
assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());
|
||||
|
||||
Reference in New Issue
Block a user