diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java index 56db12a0..006d9da2 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java @@ -19,6 +19,9 @@ import java.security.Principal; import java.time.Duration; import java.time.Instant; import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.util.Set; import org.springframework.beans.factory.annotation.Autowired; @@ -35,17 +38,20 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken2; import org.springframework.security.oauth2.core.OAuth2TokenType; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.jwt.JoseHeader; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtClaimsSet; import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.config.ProviderSettings; import org.springframework.security.oauth2.server.authorization.config.TokenSettings; -import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; -import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; import org.springframework.util.Assert; import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient; @@ -55,6 +61,7 @@ import static org.springframework.security.oauth2.server.authorization.authentic * * @author Alexey Nesterov * @author Joe Grandja + * @author Anoop Garlapati * @since 0.0.3 * @see OAuth2RefreshTokenAuthenticationToken * @see OAuth2AccessTokenAuthenticationToken @@ -66,6 +73,7 @@ import static org.springframework.security.oauth2.server.authorization.authentic * @see Section 6 Refreshing an Access Token */ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider { + private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = new OAuth2TokenType(OidcParameterNames.ID_TOKEN); private static final StringKeyGenerator TOKEN_GENERATOR = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96); private final OAuth2AuthorizationService authorizationService; private final JwtEncoder jwtEncoder; @@ -174,19 +182,64 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP currentRefreshToken = generateRefreshToken(tokenSettings.refreshTokenTimeToLive()); } + Jwt jwtIdToken = null; + if (authorizedScopes.contains(OidcScopes.OPENID)) { + headersBuilder = JwtUtils.headers(); + claimsBuilder = JwtUtils.idTokenClaims( + registeredClient, issuer, authorization.getPrincipalName(), null); + + // @formatter:off + context = JwtEncodingContext.with(headersBuilder, claimsBuilder) + .registeredClient(registeredClient) + .principal(authorization.getAttribute(Principal.class.getName())) + .authorization(authorization) + .authorizedScopes(authorizedScopes) + .tokenType(ID_TOKEN_TOKEN_TYPE) + .authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN) + .authorizationGrant(refreshTokenAuthentication) + .build(); + // @formatter:on + + this.jwtCustomizer.customize(context); + + headers = context.getHeaders().build(); + claims = context.getClaims().build(); + jwtIdToken = this.jwtEncoder.encode(headers, claims); + } + + OidcIdToken idToken; + if (jwtIdToken != null) { + idToken = new OidcIdToken(jwtIdToken.getTokenValue(), jwtIdToken.getIssuedAt(), + jwtIdToken.getExpiresAt(), jwtIdToken.getClaims()); + } else { + idToken = null; + } + // @formatter:off - authorization = OAuth2Authorization.from(authorization) + OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization) .token(accessToken, (metadata) -> metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, jwtAccessToken.getClaims())) - .refreshToken(currentRefreshToken) - .build(); + .refreshToken(currentRefreshToken); + if (idToken != null) { + authorizationBuilder + .token(idToken, + (metadata) -> + metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims())); + } + authorization = authorizationBuilder.build(); // @formatter:on this.authorizationService.save(authorization); + Map additionalParameters = Collections.emptyMap(); + if (idToken != null) { + additionalParameters = new HashMap<>(); + additionalParameters.put(OidcParameterNames.ID_TOKEN, idToken.getTokenValue()); + } + return new OAuth2AccessTokenAuthenticationToken( - registeredClient, clientPrincipal, accessToken, currentRefreshToken); + registeredClient, clientPrincipal, accessToken, currentRefreshToken, additionalParameters); } @Override diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java index 65a1fc1e..6823b0b3 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java @@ -19,7 +19,9 @@ import java.security.Principal; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; import java.util.Set; import org.junit.Before; @@ -36,23 +38,28 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken2; import org.springframework.security.oauth2.core.OAuth2TokenType; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; +import org.springframework.security.oauth2.core.oidc.OidcIdToken; +import org.springframework.security.oauth2.core.oidc.OidcScopes; +import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames; import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm; import org.springframework.security.oauth2.jwt.JoseHeaderNames; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.JwtEncoder; +import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; +import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; -import org.springframework.security.oauth2.server.authorization.JwtEncodingContext; -import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer; +import static org.assertj.core.api.Assertions.entry; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -61,6 +68,7 @@ import static org.mockito.Mockito.when; * * @author Alexey Nesterov * @author Joe Grandja + * @author Anoop Garlapati * @since 0.0.3 */ public class OAuth2RefreshTokenAuthenticationProviderTests { @@ -156,6 +164,72 @@ public class OAuth2RefreshTokenAuthenticationProviderTests { assertThat(updatedAuthorization.getRefreshToken()).isEqualTo(authorization.getRefreshToken()); } + @Test + public void authenticateWhenValidRefreshTokenThenReturnIdToken() { + RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build(); + OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build(); + when(this.authorizationService.findByToken( + eq(authorization.getRefreshToken().getToken().getTokenValue()), + eq(OAuth2TokenType.REFRESH_TOKEN))) + .thenReturn(authorization); + + OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient); + OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken( + authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null); + + OAuth2AccessTokenAuthenticationToken accessTokenAuthentication = + (OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication); + + ArgumentCaptor jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class); + verify(this.jwtCustomizer, times(2)).customize(jwtEncodingContextCaptor.capture()); + // Access Token context + JwtEncodingContext accessTokenContext = jwtEncodingContextCaptor.getAllValues().get(0); + assertThat(accessTokenContext.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(accessTokenContext.getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName())); + assertThat(accessTokenContext.getAuthorization()).isEqualTo(authorization); + assertThat(accessTokenContext.getAuthorizedScopes()) + .isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME)); + assertThat(accessTokenContext.getTokenType()).isEqualTo(OAuth2TokenType.ACCESS_TOKEN); + assertThat(accessTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN); + assertThat(accessTokenContext.getAuthorizationGrant()).isEqualTo(authentication); + assertThat(accessTokenContext.getHeaders()).isNotNull(); + assertThat(accessTokenContext.getClaims()).isNotNull(); + Map claims = new HashMap<>(); + accessTokenContext.getClaims().claims(claims::putAll); + assertThat(claims).flatExtracting(OAuth2ParameterNames.SCOPE) + .containsExactlyInAnyOrder(OidcScopes.OPENID, "scope1"); + // ID Token context + JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1); + assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient); + assertThat(idTokenContext.getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName())); + assertThat(idTokenContext.getAuthorization()).isEqualTo(authorization); + assertThat(idTokenContext.getAuthorizedScopes()) + .isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME)); + assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN); + assertThat(idTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN); + assertThat(idTokenContext.getAuthorizationGrant()).isEqualTo(authentication); + assertThat(idTokenContext.getHeaders()).isNotNull(); + assertThat(idTokenContext.getClaims()).isNotNull(); + + verify(this.jwtEncoder, times(2)).encode(any(), any()); // Access token and ID Token + + ArgumentCaptor authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class); + verify(this.authorizationService).save(authorizationCaptor.capture()); + OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue(); + + assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId()); + assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal); + assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken()); + assertThat(updatedAuthorization.getAccessToken()).isNotEqualTo(authorization.getAccessToken()); + OAuth2Authorization.Token idToken = updatedAuthorization.getToken(OidcIdToken.class); + assertThat(idToken).isNotNull(); + assertThat(accessTokenAuthentication.getAdditionalParameters()) + .containsExactly(entry(OidcParameterNames.ID_TOKEN, idToken.getToken().getTokenValue())); + assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken()); + // By default, refresh token is reused + assertThat(updatedAuthorization.getRefreshToken()).isEqualTo(authorization.getRefreshToken()); + } + @Test public void authenticateWhenReuseRefreshTokensFalseThenReturnNewRefreshToken() { RegisteredClient registeredClient = TestRegisteredClients.registeredClient()