diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java index b5c0db06bfd..af2f56e0cd5 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurer.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -79,6 +79,7 @@ * * * @author Joe Grandja + * @author Parikshit Dutta * @since 5.1 * @see OAuth2AuthorizationRequestRedirectFilter * @see OAuth2AuthorizationCodeGrantFilter @@ -256,6 +257,10 @@ private OAuth2AuthorizationCodeGrantFilter createAuthorizationCodeGrantFilter(B if (this.authorizationRequestRepository != null) { authorizationCodeGrantFilter.setAuthorizationRequestRepository(this.authorizationRequestRepository); } + RequestCache requestCache = builder.getSharedObject(RequestCache.class); + if (requestCache != null) { + authorizationCodeGrantFilter.setRequestCache(requestCache); + } return authorizationCodeGrantFilter; } diff --git a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java index 86e4b4e3c41..ffc06ee6b02 100644 --- a/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java +++ b/config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2ClientConfigurerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -75,6 +75,7 @@ * Tests for {@link OAuth2ClientConfigurer}. * * @author Joe Grandja + * @author Parikshit Dutta */ public class OAuth2ClientConfigurerTests { private static ClientRegistrationRepository clientRegistrationRepository; @@ -208,6 +209,43 @@ public void configureWhenRequestCacheProvidedAndClientAuthorizationRequiredExcep verify(requestCache).saveRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); } + @Test + public void configureWhenRequestCacheProvidedAndClientAuthorizationSucceedsThenRequestCacheUsed() throws Exception { + this.spring.register(OAuth2ClientConfig.class).autowire(); + + // Setup the Authorization Request in the session + Map attributes = new HashMap<>(); + attributes.put(OAuth2ParameterNames.REGISTRATION_ID, this.registration1.getRegistrationId()); + OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() + .authorizationUri(this.registration1.getProviderDetails().getAuthorizationUri()) + .clientId(this.registration1.getClientId()) + .redirectUri("http://localhost/client-1") + .state("state") + .attributes(attributes) + .build(); + + AuthorizationRequestRepository authorizationRequestRepository = + new HttpSessionOAuth2AuthorizationRequestRepository(); + MockHttpServletRequest request = new MockHttpServletRequest("GET", ""); + MockHttpServletResponse response = new MockHttpServletResponse(); + authorizationRequestRepository.saveAuthorizationRequest(authorizationRequest, request, response); + + MockHttpSession session = (MockHttpSession) request.getSession(); + + String principalName = "user1"; + TestingAuthenticationToken authentication = new TestingAuthenticationToken(principalName, "password"); + + this.mockMvc.perform(get("/client-1") + .param(OAuth2ParameterNames.CODE, "code") + .param(OAuth2ParameterNames.STATE, "state") + .with(authentication(authentication)) + .session(session)) + .andExpect(status().is3xxRedirection()) + .andExpect(redirectedUrl("http://localhost/client-1")); + + verify(requestCache).getRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); + } + // gh-5521 @Test public void configureWhenCustomAuthorizationRequestResolverSetThenAuthorizationRequestIncludesCustomParameters() throws Exception { diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java index 4f8aaefbafa..8d2f157c457 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilter.java @@ -83,6 +83,7 @@ * * * @author Joe Grandja + * @author Parikshit Dutta * @since 5.1 * @see OAuth2AuthorizationCodeAuthenticationToken * @see OAuth2AuthorizationCodeAuthenticationProvider @@ -104,7 +105,7 @@ public class OAuth2AuthorizationCodeGrantFilter extends OncePerRequestFilter { new HttpSessionOAuth2AuthorizationRequestRepository(); private final AuthenticationDetailsSource authenticationDetailsSource = new WebAuthenticationDetailsSource(); private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); - private final RequestCache requestCache = new HttpSessionRequestCache(); + private RequestCache requestCache = new HttpSessionRequestCache(); /** * Constructs an {@code OAuth2AuthorizationCodeGrantFilter} using the provided parameters. @@ -134,6 +135,18 @@ public final void setAuthorizationRequestRepository(AuthorizationRequestReposito this.authorizationRequestRepository = authorizationRequestRepository; } + /** + * Sets the {@link RequestCache} used for loading a previously saved request (if available) + * and replaying it after completing the processing of the OAuth 2.0 Authorization Response. + * + * @since 5.4 + * @param requestCache the cache used for loading a previously saved request (if available) + */ + public final void setRequestCache(RequestCache requestCache) { + Assert.notNull(requestCache, "requestCache cannot be null"); + this.requestCache = requestCache; + } + @Override protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java index 39b3011f03b..d1dba0d8997 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationCodeGrantFilterTests.java @@ -72,6 +72,7 @@ * Tests for {@link OAuth2AuthorizationCodeGrantFilter}. * * @author Joe Grandja + * @author Parikshit Dutta */ public class OAuth2AuthorizationCodeGrantFilterTests { private ClientRegistration registration1; @@ -130,6 +131,12 @@ public void setAuthorizationRequestRepositoryWhenAuthorizationRequestRepositoryI .isInstanceOf(IllegalArgumentException.class); } + @Test + public void setRequestCacheWhenRequestCacheIsNullThenThrowIllegalArgumentException() { + assertThatThrownBy(() -> this.filter.setRequestCache(null)) + .isInstanceOf(IllegalArgumentException.class); + } + @Test public void doFilterWhenNotAuthorizationResponseThenNotProcessed() throws Exception { String requestUri = "/path"; @@ -326,6 +333,28 @@ public void doFilterWhenAuthorizationSucceedsAndHasSavedRequestThenRedirectToSav assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/saved-request"); } + @Test + public void doFilterWhenAuthorizationSucceedsAndRequestCacheConfiguredThenRequestCacheUsed() throws Exception { + MockHttpServletRequest authorizationRequest = createAuthorizationRequest("/callback/client-1"); + MockHttpServletRequest authorizationResponse = createAuthorizationResponse(authorizationRequest); + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChain filterChain = mock(FilterChain.class); + this.setUpAuthorizationRequest(authorizationRequest, response, this.registration1); + this.setUpAuthenticationResult(this.registration1); + + RequestCache requestCache = spy(HttpSessionRequestCache.class); + this.filter.setRequestCache(requestCache); + + authorizationRequest.setRequestURI("/saved-request"); + requestCache.saveRequest(authorizationRequest, response); + + this.filter.doFilter(authorizationResponse, response, filterChain); + + verify(requestCache).getRequest(any(HttpServletRequest.class), any(HttpServletResponse.class)); + assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/saved-request"); + } + @Test public void doFilterWhenAuthorizationSucceedsAndAnonymousAccessThenAuthorizedClientSavedToHttpSession() throws Exception { AnonymousAuthenticationToken anonymousPrincipal =