client_id authentication parameter must have printable ASCII characters

Closes gh-889
This commit is contained in:
Joe Grandja
2022-11-17 14:51:44 -05:00
parent fcbb5c1197
commit 8ed0194744
2 changed files with 68 additions and 0 deletions

View File

@@ -118,6 +118,7 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
this.authenticationDetailsSource.buildDetails(request));
}
if (authenticationRequest != null) {
validateClientIdentifier(authenticationRequest);
Authentication authenticationResult = this.authenticationManager.authenticate(authenticationRequest);
this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, authenticationResult);
}
@@ -201,4 +202,25 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
this.errorHttpResponseConverter.write(errorResponse, null, httpResponse);
}
private static void validateClientIdentifier(Authentication authentication) {
if (!(authentication instanceof OAuth2ClientAuthenticationToken)) {
return;
}
// As per spec, in Appendix A.1.
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-07#appendix-A.1
// The syntax for client_id is *VSCHAR (%x20-7E):
// -> Hex 20 -> ASCII 32 -> space
// -> Hex 7E -> ASCII 126 -> tilde
OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication;
String clientId = (String) clientAuthentication.getPrincipal();
for (int i = 0; i < clientId.length(); i++) {
char charAt = clientId.charAt(i);
if (!(charAt >= 32 && charAt <= 126)) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
}
}
}
}

View File

@@ -15,6 +15,8 @@
*/
package org.springframework.security.oauth2.server.authorization.web;
import java.nio.charset.StandardCharsets;
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@@ -33,6 +35,7 @@ import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.crypto.codec.Hex;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
@@ -130,6 +133,7 @@ public class OAuth2ClientAuthenticationFilterTests {
this.filter.doFilter(request, response, filterChain);
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
verifyNoInteractions(this.authenticationConverter);
}
@Test
@@ -142,6 +146,7 @@ public class OAuth2ClientAuthenticationFilterTests {
this.filter.doFilter(request, response, filterChain);
verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
verifyNoInteractions(this.authenticationManager);
}
@Test
@@ -164,6 +169,46 @@ public class OAuth2ClientAuthenticationFilterTests {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
}
// gh-889
@Test
public void doFilterWhenRequestMatchesAndClientIdContainsNonPrintableASCIIThenInvalidRequestError() throws Exception {
// Hex 00 -> null
String clientId = new String(Hex.decode("00"), StandardCharsets.UTF_8);
assertWhenInvalidClientIdThenInvalidRequestError(clientId);
// Hex 0a61 -> line feed + a
clientId = new String(Hex.decode("0a61"), StandardCharsets.UTF_8);
assertWhenInvalidClientIdThenInvalidRequestError(clientId);
// Hex 1b -> escape
clientId = new String(Hex.decode("1b"), StandardCharsets.UTF_8);
assertWhenInvalidClientIdThenInvalidRequestError(clientId);
// Hex 1b61 -> escape + a
clientId = new String(Hex.decode("1b61"), StandardCharsets.UTF_8);
assertWhenInvalidClientIdThenInvalidRequestError(clientId);
}
private void assertWhenInvalidClientIdThenInvalidRequestError(String clientId) throws Exception {
when(this.authenticationConverter.convert(any(HttpServletRequest.class))).thenReturn(
new OAuth2ClientAuthenticationToken(clientId, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, "secret", null));
MockHttpServletRequest request = new MockHttpServletRequest("POST", this.filterProcessesUrl);
request.setServletPath(this.filterProcessesUrl);
MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class);
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
verifyNoInteractions(this.authenticationManager);
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
OAuth2Error error = readError(response);
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
}
@Test
public void doFilterWhenRequestMatchesAndBadCredentialsThenInvalidClientError() throws Exception {
when(this.authenticationConverter.convert(any(HttpServletRequest.class))).thenReturn(
@@ -179,6 +224,7 @@ public class OAuth2ClientAuthenticationFilterTests {
this.filter.doFilter(request, response, filterChain);
verifyNoInteractions(filterChain);
verify(this.authenticationManager).authenticate(any());
assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
assertThat(response.getStatus()).isEqualTo(HttpStatus.UNAUTHORIZED.value());