diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 7d4327910b4..3e61eb51154 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -40,6 +40,8 @@ import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeReactiveAuthenticationManager; import org.springframework.security.oauth2.client.authentication.OAuth2LoginReactiveAuthenticationManager; +import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; import org.springframework.security.oauth2.client.endpoint.WebClientReactiveAuthorizationCodeTokenResponseClient; import org.springframework.security.oauth2.client.oidc.authentication.OidcAuthorizationCodeReactiveAuthenticationManager; import org.springframework.security.oauth2.client.oidc.userinfo.OidcReactiveOAuth2UserService; @@ -619,14 +621,13 @@ private ReactiveAuthenticationManager getAuthenticationManager() { } private ReactiveAuthenticationManager createDefault() { - WebClientReactiveAuthorizationCodeTokenResponseClient client = new WebClientReactiveAuthorizationCodeTokenResponseClient(); - ReactiveAuthenticationManager result = new OAuth2LoginReactiveAuthenticationManager(client, getOauth2UserService()); + ReactiveAuthenticationManager result = new OAuth2LoginReactiveAuthenticationManager(getAccessTokenResponseClient(), getOauth2UserService()); boolean oidcAuthenticationProviderEnabled = ClassUtils.isPresent( "org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader()); if (oidcAuthenticationProviderEnabled) { OidcAuthorizationCodeReactiveAuthenticationManager oidc = - new OidcAuthorizationCodeReactiveAuthenticationManager(client, getOidcUserService()); + new OidcAuthorizationCodeReactiveAuthenticationManager(getAccessTokenResponseClient(), getOidcUserService()); ResolvableType type = ResolvableType.forClassWithGenerics( ReactiveJwtDecoderFactory.class, ClientRegistration.class); ReactiveJwtDecoderFactory jwtDecoderFactory = getBeanOrNull(type); @@ -786,6 +787,15 @@ private Map getLinks() { return result; } + private ReactiveOAuth2AccessTokenResponseClient getAccessTokenResponseClient() { + ResolvableType type = ResolvableType.forClassWithGenerics(ReactiveOAuth2AccessTokenResponseClient.class, OAuth2AuthorizationCodeGrantRequest.class); + ReactiveOAuth2AccessTokenResponseClient bean = getBeanOrNull(type); + if (bean == null) { + return new WebClientReactiveAuthorizationCodeTokenResponseClient(); + } + return bean; + } + private ReactiveClientRegistrationRepository getClientRegistrationRepository() { if (this.clientRegistrationRepository == null) { this.clientRegistrationRepository = getBeanOrNull(ReactiveClientRegistrationRepository.class); diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java index aca16ac5e01..0b667e4796c 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java @@ -258,6 +258,7 @@ public void oauth2LoginWhenCustomJwtDecoderFactoryThenUsed() { .expectStatus().is3xxRedirection(); verify(config.jwtDecoderFactory).createDecoder(any()); + verify(tokenResponseClient).getTokenResponse(any()).thenReturn(Mono.just(accessTokenResponse)); } @Configuration @@ -298,6 +299,11 @@ public ReactiveJwtDecoderFactory jwtDecoderFactory() { return jwtDecoderFactory; } + @Bean + public ReactiveOAuth2AccessTokenResponseClient oAuth2AccessTokenResponseClient() { + return tokenResponseClient; + } + private static class JwtDecoderFactory implements ReactiveJwtDecoderFactory { @Override