From 887cb9992638dc2b29d4cbc1ab18d87c56c80030 Mon Sep 17 00:00:00 2001 From: Josh Cummings Date: Wed, 8 Apr 2020 16:14:14 -0600 Subject: [PATCH] Saml2AuthenticationRequestFilter Tests To confirm behavior still works as expected after making related changes. Issue gh-8359 --- ...ebSsoAuthenticationRequestFilterTests.java | 71 +++++++++++++++++-- 1 file changed, 67 insertions(+), 4 deletions(-) diff --git a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java index 4b5bb37bb4b..8c51ce7b31a 100644 --- a/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java +++ b/saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java @@ -16,22 +16,29 @@ package org.springframework.security.saml2.provider.service.servlet.filter; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import javax.servlet.ServletException; + import org.junit.Before; import org.junit.Test; + import org.springframework.mock.web.MockFilterChain; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory; +import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration; import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository; import org.springframework.web.util.HtmlUtils; import org.springframework.web.util.UriUtils; -import javax.servlet.ServletException; -import java.io.IOException; -import java.nio.charset.StandardCharsets; - import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST; import static org.springframework.security.saml2.provider.service.servlet.filter.TestSaml2SigningCredentials.signingCredential; @@ -41,6 +48,7 @@ public class Saml2WebSsoAuthenticationRequestFilterTests { private static final String IDP_SSO_URL = "https://sso-url.example.com/IDP/SSO"; private Saml2WebSsoAuthenticationRequestFilter filter; private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class); + private Saml2AuthenticationRequestFactory factory = mock(Saml2AuthenticationRequestFactory.class); private MockHttpServletRequest request; private MockHttpServletResponse response; private MockFilterChain filterChain; @@ -147,4 +155,59 @@ public void doFilterWhenPostFormDataIsPresent() throws Exception { .contains("value=\""+relayStateEncoded+"\""); } + @Test + public void doFilterWhenSetAuthenticationRequestFactoryThenUses() throws Exception { + RelyingPartyRegistration relyingParty = this.rpBuilder + .providerDetails(c -> c.binding(POST)) + .build(); + Saml2PostAuthenticationRequest authenticationRequest = mock(Saml2PostAuthenticationRequest.class); + when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri"); + when(authenticationRequest.getRelayState()).thenReturn("relay"); + when(authenticationRequest.getSamlRequest()).thenReturn("saml"); + when(this.repository.findByRegistrationId("registration-id")).thenReturn(relyingParty); + when(this.factory.createPostAuthenticationRequest(any())) + .thenReturn(authenticationRequest); + + Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter + (this.repository); + filter.setAuthenticationRequestFactory(this.factory); + filter.doFilterInternal(this.request, this.response, this.filterChain); + assertThat(this.response.getContentAsString()) + .contains("
") + .contains(" filter.setRedirectMatcher(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void setAuthenticationRequestFactoryWhenNullThenException() { + Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.repository); + assertThatCode(() -> filter.setAuthenticationRequestFactory(null)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void doFilterWhenRequestMatcherFailsThenSkipsFilter() throws Exception { + Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter + (this.repository); + filter.setRedirectMatcher(request -> false); + filter.doFilter(this.request, this.response, this.filterChain); + verifyNoInteractions(this.repository); + } + + @Test + public void doFilterWhenRelyingPartyRegistrationNotFoundThenUnauthorized() throws Exception { + Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter + (this.repository); + filter.doFilter(this.request, this.response, this.filterChain); + assertThat(this.response.getStatus()).isEqualTo(401); + } }