diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java index 318cf6a6..b48f6694 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java @@ -89,11 +89,14 @@ public final class JwtGenerator implements OAuth2TokenGenerator { Instant issuedAt = Instant.now(); Instant expiresAt; + JwsHeader.Builder headersBuilder; if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) { // TODO Allow configuration for ID Token time-to-live expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES); + headersBuilder = JwsHeader.with(registeredClient.getTokenSettings().getIdTokenSignatureAlgorithm()); } else { expiresAt = issuedAt.plus(registeredClient.getTokenSettings().getAccessTokenTimeToLive()); + headersBuilder = JwsHeader.with(SignatureAlgorithm.RS256); } // @formatter:off @@ -125,11 +128,9 @@ public final class JwtGenerator implements OAuth2TokenGenerator { } // @formatter:on - JwsHeader.Builder jwsHeaderBuilder = JwsHeader.with(SignatureAlgorithm.RS256); - if (this.jwtCustomizer != null) { // @formatter:off - JwtEncodingContext.Builder jwtContextBuilder = JwtEncodingContext.with(jwsHeaderBuilder, claimsBuilder) + JwtEncodingContext.Builder jwtContextBuilder = JwtEncodingContext.with(headersBuilder, claimsBuilder) .registeredClient(context.getRegisteredClient()) .principal(context.getPrincipal()) .authorizationServerContext(context.getAuthorizationServerContext()) @@ -148,7 +149,7 @@ public final class JwtGenerator implements OAuth2TokenGenerator { this.jwtCustomizer.customize(jwtContext); } - JwsHeader jwsHeader = jwsHeaderBuilder.build(); + JwsHeader jwsHeader = headersBuilder.build(); JwtClaimsSet claims = claimsBuilder.build(); Jwt jwt = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims)); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java index b4f672cf..8534ee60 100644 --- a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java @@ -201,9 +201,6 @@ public class JwtGeneratorTests { ArgumentCaptor jwtEncoderParametersCaptor = ArgumentCaptor.forClass(JwtEncoderParameters.class); verify(this.jwtEncoder).encode(jwtEncoderParametersCaptor.capture()); - JwsHeader jwsHeader = jwtEncoderParametersCaptor.getValue().getJwsHeader(); - assertThat(jwsHeader.getAlgorithm()).isEqualTo(SignatureAlgorithm.RS256); - JwtClaimsSet jwtClaimsSet = jwtEncoderParametersCaptor.getValue().getClaims(); assertThat(jwtClaimsSet.getIssuer().toExternalForm()).isEqualTo(tokenContext.getAuthorizationServerContext().getIssuer()); assertThat(jwtClaimsSet.getSubject()).isEqualTo(tokenContext.getAuthorization().getPrincipalName()); @@ -211,14 +208,20 @@ public class JwtGeneratorTests { Instant issuedAt = Instant.now(); Instant expiresAt; + JwsHeader.Builder headersBuilder; if (tokenContext.getTokenType().equals(OAuth2TokenType.ACCESS_TOKEN)) { expiresAt = issuedAt.plus(tokenContext.getRegisteredClient().getTokenSettings().getAccessTokenTimeToLive()); + headersBuilder = JwsHeader.with(SignatureAlgorithm.RS256); } else { expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES); + headersBuilder = JwsHeader.with(tokenContext.getRegisteredClient().getTokenSettings().getIdTokenSignatureAlgorithm()); } assertThat(jwtClaimsSet.getIssuedAt()).isBetween(issuedAt.minusSeconds(1), issuedAt.plusSeconds(1)); assertThat(jwtClaimsSet.getExpiresAt()).isBetween(expiresAt.minusSeconds(1), expiresAt.plusSeconds(1)); + JwsHeader jwsHeader = jwtEncoderParametersCaptor.getValue().getJwsHeader(); + assertThat(jwsHeader.getAlgorithm()).isEqualTo(headersBuilder.build().getAlgorithm()); + if (tokenContext.getTokenType().equals(OAuth2TokenType.ACCESS_TOKEN)) { assertThat(jwtClaimsSet.getNotBefore()).isBetween(issuedAt.minusSeconds(1), issuedAt.plusSeconds(1));