Add OAuth2RefreshTokenGenerator

Closes gh-638
This commit is contained in:
Joe Grandja
2022-02-24 06:43:45 -05:00
parent c799261a72
commit cdb48f510e
11 changed files with 451 additions and 49 deletions

View File

@@ -29,12 +29,14 @@ import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.jwt.NimbusJwsEncoder;
import org.springframework.security.oauth2.server.authorization.DelegatingOAuth2TokenGenerator;
import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationConsentService;
import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.JwtGenerator;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2RefreshTokenGenerator;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
@@ -96,7 +98,8 @@ final class OAuth2ConfigurerUtils {
if (jwtCustomizer != null) {
jwtGenerator.setJwtCustomizer(jwtCustomizer);
}
tokenGenerator = jwtGenerator;
OAuth2RefreshTokenGenerator refreshTokenGenerator = new OAuth2RefreshTokenGenerator();
tokenGenerator = new DelegatingOAuth2TokenGenerator(jwtGenerator, refreshTokenGenerator);
}
builder.setSharedObject(OAuth2TokenGenerator.class, tokenGenerator);
}

View File

@@ -0,0 +1,78 @@
/*
* Copyright 2020-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.springframework.lang.Nullable;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.util.Assert;
/**
* An {@link OAuth2TokenGenerator} that simply delegates to it's
* internal {@code List} of {@link OAuth2TokenGenerator}(s).
* <p>
* Each {@link OAuth2TokenGenerator} is given a chance to
* {@link OAuth2TokenGenerator#generate(OAuth2TokenContext)}
* with the first {@code non-null} {@link OAuth2Token} being returned.
*
* @author Joe Grandja
* @since 0.2.3
* @see OAuth2TokenGenerator
* @see JwtGenerator
* @see OAuth2RefreshTokenGenerator
*/
public final class DelegatingOAuth2TokenGenerator implements OAuth2TokenGenerator<OAuth2Token> {
private final List<OAuth2TokenGenerator<OAuth2Token>> tokenGenerators;
/**
* Constructs a {@code DelegatingOAuth2TokenGenerator} using the provided parameters.
*
* @param tokenGenerators an array of {@link OAuth2TokenGenerator}(s)
*/
@SafeVarargs
public DelegatingOAuth2TokenGenerator(OAuth2TokenGenerator<? extends OAuth2Token>... tokenGenerators) {
Assert.notEmpty(tokenGenerators, "tokenGenerators cannot be empty");
Assert.noNullElements(tokenGenerators, "tokenGenerator cannot be null");
this.tokenGenerators = Collections.unmodifiableList(asList(tokenGenerators));
}
@Nullable
@Override
public OAuth2Token generate(OAuth2TokenContext context) {
for (OAuth2TokenGenerator<OAuth2Token> tokenGenerator : this.tokenGenerators) {
OAuth2Token token = tokenGenerator.generate(context);
if (token != null) {
return token;
}
}
return null;
}
@SuppressWarnings("unchecked")
private static List<OAuth2TokenGenerator<OAuth2Token>> asList(
OAuth2TokenGenerator<? extends OAuth2Token>... tokenGenerators) {
List<OAuth2TokenGenerator<OAuth2Token>> tokenGeneratorList = new ArrayList<>();
for (OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator : tokenGenerators) {
tokenGeneratorList.add((OAuth2TokenGenerator<OAuth2Token>) tokenGenerator);
}
return tokenGeneratorList;
}
}

View File

@@ -0,0 +1,50 @@
/*
* Copyright 2020-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization;
import java.time.Instant;
import java.util.Base64;
import org.springframework.lang.Nullable;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2TokenType;
/**
* An {@link OAuth2TokenGenerator} that generates an {@link OAuth2RefreshToken}.
*
* @author Joe Grandja
* @since 0.2.3
* @see OAuth2TokenGenerator
* @see OAuth2RefreshToken
*/
public final class OAuth2RefreshTokenGenerator implements OAuth2TokenGenerator<OAuth2RefreshToken> {
private final StringKeyGenerator refreshTokenGenerator =
new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
@Nullable
@Override
public OAuth2RefreshToken generate(OAuth2TokenContext context) {
if (!OAuth2TokenType.REFRESH_TOKEN.equals(context.getTokenType())) {
return null;
}
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(context.getRegisteredClient().getTokenSettings().getRefreshTokenTimeToLive());
return new OAuth2RefreshToken(this.refreshTokenGenerator.generateKey(), issuedAt, expiresAt);
}
}

View File

@@ -16,9 +16,7 @@
package org.springframework.security.oauth2.server.authorization.authentication;
import java.security.Principal;
import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@@ -28,8 +26,6 @@ import java.util.function.Supplier;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -48,10 +44,12 @@ import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.DefaultOAuth2TokenContext;
import org.springframework.security.oauth2.server.authorization.DelegatingOAuth2TokenGenerator;
import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.JwtGenerator;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2RefreshTokenGenerator;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
@@ -83,11 +81,14 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
new OAuth2TokenType(OAuth2ParameterNames.CODE);
private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE =
new OAuth2TokenType(OidcParameterNames.ID_TOKEN);
private static final StringKeyGenerator DEFAULT_REFRESH_TOKEN_GENERATOR =
new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
private final OAuth2AuthorizationService authorizationService;
private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
private Supplier<String> refreshTokenGenerator = DEFAULT_REFRESH_TOKEN_GENERATOR::generateKey;
// TODO Remove after removing @Deprecated OAuth2AuthorizationCodeAuthenticationProvider(OAuth2AuthorizationService, JwtEncoder)
private JwtGenerator jwtGenerator;
@Deprecated
private Supplier<String> refreshTokenGenerator;
/**
* Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters.
@@ -101,7 +102,9 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
Assert.notNull(authorizationService, "authorizationService cannot be null");
Assert.notNull(jwtEncoder, "jwtEncoder cannot be null");
this.authorizationService = authorizationService;
this.tokenGenerator = new JwtGenerator(jwtEncoder);
this.jwtGenerator = new JwtGenerator(jwtEncoder);
this.tokenGenerator = new DelegatingOAuth2TokenGenerator(
this.jwtGenerator, new OAuth2RefreshTokenGenerator());
}
/**
@@ -130,16 +133,18 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
@Deprecated
public void setJwtCustomizer(OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer) {
Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
if (this.tokenGenerator instanceof JwtGenerator) {
((JwtGenerator) this.tokenGenerator).setJwtCustomizer(jwtCustomizer);
if (this.jwtGenerator != null) {
this.jwtGenerator.setJwtCustomizer(jwtCustomizer);
}
}
/**
* Sets the {@code Supplier<String>} that generates the value for the {@link OAuth2RefreshToken}.
*
* @deprecated Use {@link OAuth2RefreshTokenGenerator} instead
* @param refreshTokenGenerator the {@code Supplier<String>} that generates the value for the {@link OAuth2RefreshToken}
*/
@Deprecated
public void setRefreshTokenGenerator(Supplier<String> refreshTokenGenerator) {
Assert.notNull(refreshTokenGenerator, "refreshTokenGenerator cannot be null");
this.refreshTokenGenerator = refreshTokenGenerator;
@@ -223,7 +228,21 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN) &&
// Do not issue refresh token to public client
!clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE)) {
refreshToken = generateRefreshToken(registeredClient.getTokenSettings().getRefreshTokenTimeToLive());
if (this.refreshTokenGenerator != null) {
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(registeredClient.getTokenSettings().getRefreshTokenTimeToLive());
refreshToken = new OAuth2RefreshToken(this.refreshTokenGenerator.get(), issuedAt, expiresAt);
} else {
tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.REFRESH_TOKEN).build();
OAuth2Token generatedRefreshToken = this.tokenGenerator.generate(tokenContext);
if (!(generatedRefreshToken instanceof OAuth2RefreshToken)) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
"The token generator failed to generate the refresh token.", ERROR_URI);
throw new OAuth2AuthenticationException(error);
}
refreshToken = (OAuth2RefreshToken) generatedRefreshToken;
}
authorizationBuilder.refreshToken(refreshToken);
}
@@ -267,10 +286,4 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
return OAuth2AuthorizationCodeAuthenticationToken.class.isAssignableFrom(authentication);
}
private OAuth2RefreshToken generateRefreshToken(Duration tokenTimeToLive) {
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(tokenTimeToLive);
return new OAuth2RefreshToken(this.refreshTokenGenerator.get(), issuedAt, expiresAt);
}
}

View File

@@ -16,9 +16,7 @@
package org.springframework.security.oauth2.server.authorization.authentication;
import java.security.Principal;
import java.time.Duration;
import java.time.Instant;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@@ -29,8 +27,6 @@ import java.util.function.Supplier;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
@@ -45,10 +41,12 @@ import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.DefaultOAuth2TokenContext;
import org.springframework.security.oauth2.server.authorization.DelegatingOAuth2TokenGenerator;
import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.JwtGenerator;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2RefreshTokenGenerator;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
@@ -76,11 +74,14 @@ import static org.springframework.security.oauth2.server.authorization.authentic
public final class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider {
private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = new OAuth2TokenType(OidcParameterNames.ID_TOKEN);
private static final StringKeyGenerator DEFAULT_REFRESH_TOKEN_GENERATOR =
new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
private final OAuth2AuthorizationService authorizationService;
private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
private Supplier<String> refreshTokenGenerator = DEFAULT_REFRESH_TOKEN_GENERATOR::generateKey;
// TODO Remove after removing @Deprecated OAuth2RefreshTokenAuthenticationProvider(OAuth2AuthorizationService, JwtEncoder)
private JwtGenerator jwtGenerator;
@Deprecated
private Supplier<String> refreshTokenGenerator;
/**
* Constructs an {@code OAuth2RefreshTokenAuthenticationProvider} using the provided parameters.
@@ -95,7 +96,9 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
Assert.notNull(authorizationService, "authorizationService cannot be null");
Assert.notNull(jwtEncoder, "jwtEncoder cannot be null");
this.authorizationService = authorizationService;
this.tokenGenerator = new JwtGenerator(jwtEncoder);
this.jwtGenerator = new JwtGenerator(jwtEncoder);
this.tokenGenerator = new DelegatingOAuth2TokenGenerator(
this.jwtGenerator, new OAuth2RefreshTokenGenerator());
}
/**
@@ -124,16 +127,18 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
@Deprecated
public void setJwtCustomizer(OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer) {
Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
if (this.tokenGenerator instanceof JwtGenerator) {
((JwtGenerator) this.tokenGenerator).setJwtCustomizer(jwtCustomizer);
if (this.jwtGenerator != null) {
this.jwtGenerator.setJwtCustomizer(jwtCustomizer);
}
}
/**
* Sets the {@code Supplier<String>} that generates the value for the {@link OAuth2RefreshToken}.
*
* @deprecated Use {@link OAuth2RefreshTokenGenerator} instead
* @param refreshTokenGenerator the {@code Supplier<String>} that generates the value for the {@link OAuth2RefreshToken}
*/
@Deprecated
public void setRefreshTokenGenerator(Supplier<String> refreshTokenGenerator) {
Assert.notNull(refreshTokenGenerator, "refreshTokenGenerator cannot be null");
this.refreshTokenGenerator = refreshTokenGenerator;
@@ -222,7 +227,20 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
// ----- Refresh token -----
OAuth2RefreshToken currentRefreshToken = refreshToken.getToken();
if (!registeredClient.getTokenSettings().isReuseRefreshTokens()) {
currentRefreshToken = generateRefreshToken(registeredClient.getTokenSettings().getRefreshTokenTimeToLive());
if (this.refreshTokenGenerator != null) {
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(registeredClient.getTokenSettings().getRefreshTokenTimeToLive());
currentRefreshToken = new OAuth2RefreshToken(this.refreshTokenGenerator.get(), issuedAt, expiresAt);
} else {
tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.REFRESH_TOKEN).build();
OAuth2Token generatedRefreshToken = this.tokenGenerator.generate(tokenContext);
if (!(generatedRefreshToken instanceof OAuth2RefreshToken)) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
"The token generator failed to generate the refresh token.", ERROR_URI);
throw new OAuth2AuthenticationException(error);
}
currentRefreshToken = (OAuth2RefreshToken) generatedRefreshToken;
}
authorizationBuilder.refreshToken(currentRefreshToken);
}
@@ -263,10 +281,4 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
return OAuth2RefreshTokenAuthenticationToken.class.isAssignableFrom(authentication);
}
private OAuth2RefreshToken generateRefreshToken(Duration tokenTimeToLive) {
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(tokenTimeToLive);
return new OAuth2RefreshToken(this.refreshTokenGenerator.get(), issuedAt, expiresAt);
}
}

View File

@@ -70,6 +70,7 @@ import org.springframework.security.crypto.password.NoOpPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.OAuth2TokenType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
@@ -81,6 +82,7 @@ import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.jwt.NimbusJwsEncoder;
import org.springframework.security.oauth2.server.authorization.DelegatingOAuth2TokenGenerator;
import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationConsentService;
import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
@@ -89,6 +91,7 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2RefreshTokenGenerator;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
@@ -445,7 +448,7 @@ public class OAuth2AuthorizationCodeGrantTests {
.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient)))
.andExpect(status().isOk());
verify(this.tokenGenerator).generate(any());
verify(this.tokenGenerator, times(2)).generate(any());
}
@Test
@@ -842,10 +845,13 @@ public class OAuth2AuthorizationCodeGrantTests {
OAuth2TokenGenerator<?> tokenGenerator() {
JwtGenerator jwtGenerator = new JwtGenerator(jwtEncoder());
jwtGenerator.setJwtCustomizer(jwtCustomizer());
return spy(new OAuth2TokenGenerator<Jwt>() {
OAuth2RefreshTokenGenerator refreshTokenGenerator = new OAuth2RefreshTokenGenerator();
OAuth2TokenGenerator<OAuth2Token> delegatingTokenGenerator =
new DelegatingOAuth2TokenGenerator(jwtGenerator, refreshTokenGenerator);
return spy(new OAuth2TokenGenerator<OAuth2Token>() {
@Override
public Jwt generate(OAuth2TokenContext context) {
return jwtGenerator.generate(context);
public OAuth2Token generate(OAuth2TokenContext context) {
return delegatingTokenGenerator.generate(context);
}
});
}

View File

@@ -58,6 +58,7 @@ import org.springframework.security.crypto.password.NoOpPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.OAuth2TokenType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
@@ -69,11 +70,13 @@ import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.NimbusJwsEncoder;
import org.springframework.security.oauth2.server.authorization.DelegatingOAuth2TokenGenerator;
import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.JwtGenerator;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2RefreshTokenGenerator;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
@@ -262,7 +265,7 @@ public class OidcTests {
registeredClient.getClientId(), registeredClient.getClientSecret())))
.andExpect(status().isOk());
verify(this.tokenGenerator, times(2)).generate(any());
verify(this.tokenGenerator, times(3)).generate(any());
}
private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) {
@@ -404,10 +407,13 @@ public class OidcTests {
OAuth2TokenGenerator<?> tokenGenerator() {
JwtGenerator jwtGenerator = new JwtGenerator(new NimbusJwsEncoder(jwkSource()));
jwtGenerator.setJwtCustomizer(jwtCustomizer());
return spy(new OAuth2TokenGenerator<Jwt>() {
OAuth2RefreshTokenGenerator refreshTokenGenerator = new OAuth2RefreshTokenGenerator();
OAuth2TokenGenerator<OAuth2Token> delegatingTokenGenerator =
new DelegatingOAuth2TokenGenerator(jwtGenerator, refreshTokenGenerator);
return spy(new OAuth2TokenGenerator<OAuth2Token>() {
@Override
public Jwt generate(OAuth2TokenContext context) {
return jwtGenerator.generate(context);
public OAuth2Token generate(OAuth2TokenContext context) {
return delegatingTokenGenerator.generate(context);
}
});
}

View File

@@ -0,0 +1,86 @@
/*
* Copyright 2020-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization;
import java.time.Instant;
import org.junit.Test;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
* Tests for {@link DelegatingOAuth2TokenGenerator}.
*
* @author Joe Grandja
*/
public class DelegatingOAuth2TokenGeneratorTests {
@Test
@SuppressWarnings("unchecked")
public void constructorWhenTokenGeneratorsEmptyThenThrowIllegalArgumentException() {
OAuth2TokenGenerator<OAuth2Token>[] tokenGenerators = new OAuth2TokenGenerator[0];
assertThatThrownBy(() -> new DelegatingOAuth2TokenGenerator(tokenGenerators))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("tokenGenerators cannot be empty");
}
@Test
public void constructorWhenTokenGeneratorsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new DelegatingOAuth2TokenGenerator(null, null))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("tokenGenerator cannot be null");
}
@Test
@SuppressWarnings("unchecked")
public void generateWhenTokenGeneratorSupportedThenReturnToken() {
OAuth2TokenGenerator<OAuth2Token> tokenGenerator1 = mock(OAuth2TokenGenerator.class);
OAuth2TokenGenerator<OAuth2Token> tokenGenerator2 = mock(OAuth2TokenGenerator.class);
OAuth2TokenGenerator<OAuth2Token> tokenGenerator3 = mock(OAuth2TokenGenerator.class);
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
"access-token", Instant.now(), Instant.now().plusSeconds(300));
when(tokenGenerator3.generate(any())).thenReturn(accessToken);
DelegatingOAuth2TokenGenerator delegatingTokenGenerator =
new DelegatingOAuth2TokenGenerator(tokenGenerator1, tokenGenerator2, tokenGenerator3);
OAuth2Token token = delegatingTokenGenerator.generate(DefaultOAuth2TokenContext.builder().build());
assertThat(token).isEqualTo(accessToken);
}
@Test
@SuppressWarnings("unchecked")
public void generateWhenTokenGeneratorNotSupportedThenReturnNull() {
OAuth2TokenGenerator<OAuth2Token> tokenGenerator1 = mock(OAuth2TokenGenerator.class);
OAuth2TokenGenerator<OAuth2Token> tokenGenerator2 = mock(OAuth2TokenGenerator.class);
OAuth2TokenGenerator<OAuth2Token> tokenGenerator3 = mock(OAuth2TokenGenerator.class);
DelegatingOAuth2TokenGenerator delegatingTokenGenerator =
new DelegatingOAuth2TokenGenerator(tokenGenerator1, tokenGenerator2, tokenGenerator3);
OAuth2Token token = delegatingTokenGenerator.generate(DefaultOAuth2TokenContext.builder().build());
assertThat(token).isNull();
}
}

View File

@@ -0,0 +1,68 @@
/*
* Copyright 2020-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization;
import java.time.Instant;
import org.junit.Test;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2TokenType;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@link OAuth2RefreshTokenGenerator}.
*
* @author Joe Grandja
*/
public class OAuth2RefreshTokenGeneratorTests {
private final OAuth2RefreshTokenGenerator tokenGenerator = new OAuth2RefreshTokenGenerator();
@Test
public void generateWhenUnsupportedTokenTypeThenReturnNull() {
// @formatter:off
OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder()
.tokenType(OAuth2TokenType.ACCESS_TOKEN)
.build();
// @formatter:on
assertThat(this.tokenGenerator.generate(tokenContext)).isNull();
}
@Test
public void generateWhenRefreshTokenTypeThenReturnRefreshToken() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
// @formatter:off
OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder()
.registeredClient(registeredClient)
.tokenType(OAuth2TokenType.REFRESH_TOKEN)
.build();
// @formatter:on
OAuth2RefreshToken refreshToken = this.tokenGenerator.generate(tokenContext);
assertThat(refreshToken).isNotNull();
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(tokenContext.getRegisteredClient().getTokenSettings().getRefreshTokenTimeToLive());
assertThat(refreshToken.getIssuedAt()).isBetween(issuedAt.minusSeconds(1), issuedAt.plusSeconds(1));
assertThat(refreshToken.getExpiresAt()).isBetween(expiresAt.minusSeconds(1), expiresAt.plusSeconds(1));
}
}

View File

@@ -37,6 +37,7 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.OAuth2TokenType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -48,10 +49,12 @@ import org.springframework.security.oauth2.jwt.JoseHeaderNames;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimsSet;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.DelegatingOAuth2TokenGenerator;
import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.JwtGenerator;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2RefreshTokenGenerator;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
@@ -97,10 +100,13 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
this.jwtCustomizer = mock(OAuth2TokenCustomizer.class);
JwtGenerator jwtGenerator = new JwtGenerator(this.jwtEncoder);
jwtGenerator.setJwtCustomizer(this.jwtCustomizer);
this.tokenGenerator = spy(new OAuth2TokenGenerator<Jwt>() {
OAuth2RefreshTokenGenerator refreshTokenGenerator = new OAuth2RefreshTokenGenerator();
OAuth2TokenGenerator<OAuth2Token> delegatingTokenGenerator =
new DelegatingOAuth2TokenGenerator(jwtGenerator, refreshTokenGenerator);
this.tokenGenerator = spy(new OAuth2TokenGenerator<OAuth2Token>() {
@Override
public Jwt generate(OAuth2TokenContext context) {
return jwtGenerator.generate(context);
public OAuth2Token generate(OAuth2TokenContext context) {
return delegatingTokenGenerator.generate(context);
}
});
this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(
@@ -326,6 +332,40 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
});
}
@Test
public void authenticateWhenRefreshTokenNotGeneratedThenThrowOAuth2AuthenticationException() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)))
.thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
OAuth2AuthorizationRequest.class.getName());
OAuth2AuthorizationCodeAuthenticationToken authentication =
new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null);
when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());
doAnswer(answer -> {
OAuth2TokenContext context = answer.getArgument(0);
if (OAuth2TokenType.REFRESH_TOKEN.equals(context.getTokenType())) {
return null;
} else {
return answer.callRealMethod();
}
}).when(this.tokenGenerator).generate(any());
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
assertThat(error.getDescription()).contains("The token generator failed to generate the refresh token.");
});
}
@Test
public void authenticateWhenIdTokenNotGeneratedThenThrowOAuth2AuthenticationException() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();

View File

@@ -37,6 +37,7 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.OAuth2TokenType;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
@@ -46,10 +47,12 @@ import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.JoseHeaderNames;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.server.authorization.DelegatingOAuth2TokenGenerator;
import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.JwtGenerator;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2RefreshTokenGenerator;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
@@ -96,10 +99,13 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
this.jwtCustomizer = mock(OAuth2TokenCustomizer.class);
JwtGenerator jwtGenerator = new JwtGenerator(this.jwtEncoder);
jwtGenerator.setJwtCustomizer(this.jwtCustomizer);
this.tokenGenerator = spy(new OAuth2TokenGenerator<Jwt>() {
OAuth2RefreshTokenGenerator refreshTokenGenerator = new OAuth2RefreshTokenGenerator();
OAuth2TokenGenerator<OAuth2Token> delegatingTokenGenerator =
new DelegatingOAuth2TokenGenerator(jwtGenerator, refreshTokenGenerator);
this.tokenGenerator = spy(new OAuth2TokenGenerator<OAuth2Token>() {
@Override
public Jwt generate(OAuth2TokenContext context) {
return jwtGenerator.generate(context);
public OAuth2Token generate(OAuth2TokenContext context) {
return delegatingTokenGenerator.generate(context);
}
});
this.authenticationProvider = new OAuth2RefreshTokenAuthenticationProvider(
@@ -551,6 +557,40 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
});
}
@Test
public void authenticateWhenRefreshTokenNotGeneratedThenThrowOAuth2AuthenticationException() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
.tokenSettings(TokenSettings.builder().reuseRefreshTokens(false).build())
.build();
OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
when(this.authorizationService.findByToken(
eq(authorization.getRefreshToken().getToken().getTokenValue()),
eq(OAuth2TokenType.REFRESH_TOKEN)))
.thenReturn(authorization);
OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
doAnswer(answer -> {
OAuth2TokenContext context = answer.getArgument(0);
if (OAuth2TokenType.REFRESH_TOKEN.equals(context.getTokenType())) {
return null;
} else {
return answer.callRealMethod();
}
}).when(this.tokenGenerator).generate(any());
assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
.isInstanceOf(OAuth2AuthenticationException.class)
.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
.satisfies(error -> {
assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
assertThat(error.getDescription()).contains("The token generator failed to generate the refresh token.");
});
}
@Test
public void authenticateWhenIdTokenNotGeneratedThenThrowOAuth2AuthenticationException() {
RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();