Skip to content

Commit

Permalink
MockMvc supports FilterRegistration filter init
Browse files Browse the repository at this point in the history
Closes gh-33252
  • Loading branch information
rstoyanchev committed Jul 30, 2024
1 parent 3845391 commit d2225c2
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.mock.web;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import jakarta.servlet.DispatcherType;
import jakarta.servlet.FilterRegistration;

import org.springframework.lang.Nullable;

/**
* Mock implementation of {@link FilterRegistration}.
*
* @author Rossen Stoyanchev
* @since 6.2
*/
public class MockFilterRegistration implements FilterRegistration {

private final String name;

private final String className;

private final Map<String, String> initParameters = new LinkedHashMap<>();

private final List<String> servletNames = new ArrayList<>();

private final List<String> urlPatterns = new ArrayList<>();


public MockFilterRegistration(String className) {
this(className, "");
}

public MockFilterRegistration(String className, String name) {
this.name = name;
this.className = className;
}


@Override
public String getName() {
return this.name;
}

@Nullable
@Override
public String getClassName() {
return this.className;
}

@Override
public boolean setInitParameter(String name, String value) {
return (this.initParameters.putIfAbsent(name, value) != null);
}

@Nullable
@Override
public String getInitParameter(String name) {
return this.initParameters.get(name);
}

@Override
public Set<String> setInitParameters(Map<String, String> initParameters) {
Set<String> existingParameterNames = new LinkedHashSet<>();
for (Map.Entry<String, String> entry : initParameters.entrySet()) {
if (this.initParameters.get(entry.getKey()) != null) {
existingParameterNames.add(entry.getKey());
}
}
if (existingParameterNames.isEmpty()) {
this.initParameters.putAll(initParameters);
}
return existingParameterNames;
}

@Override
public Map<String, String> getInitParameters() {
return Collections.unmodifiableMap(this.initParameters);
}

@Override
public void addMappingForServletNames(
EnumSet<DispatcherType> dispatcherTypes, boolean isMatchAfter, String... servletNames) {

this.servletNames.addAll(Arrays.asList(servletNames));
}

@Override
public Collection<String> getServletNameMappings() {
return Collections.unmodifiableCollection(this.servletNames);
}

@Override
public void addMappingForUrlPatterns(
EnumSet<DispatcherType> dispatcherTypes, boolean isMatchAfter, String... urlPatterns) {

this.urlPatterns.addAll(Arrays.asList(urlPatterns));
}

@Override
public Collection<String> getUrlPatternMappings() {
return Collections.unmodifiableCollection(this.urlPatterns);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ public class MockServletContext implements ServletContext {
@Nullable
private String responseCharacterEncoding;

private final Map<String, FilterRegistration> filterRegistrations = new LinkedHashMap<>();

private final Map<String, MediaType> mimeTypes = new LinkedHashMap<>();


Expand Down Expand Up @@ -604,6 +606,25 @@ public String getResponseCharacterEncoding() {
return this.responseCharacterEncoding;
}

/**
* Add a {@link FilterRegistration}.
* @since 6.2
*/
public void addFilterRegistration(FilterRegistration registration) {
this.filterRegistrations.put(registration.getName(), registration);
}

@Override
@Nullable
public FilterRegistration getFilterRegistration(String filterName) {
return this.filterRegistrations.get(filterName);
}

@Override
public Map<String, ? extends FilterRegistration> getFilterRegistrations() {
return Collections.unmodifiableMap(this.filterRegistrations);
}


//---------------------------------------------------------------------
// Unsupported Servlet 3.0 registration methods
Expand Down Expand Up @@ -678,25 +699,6 @@ public <T extends Filter> T createFilter(Class<T> c) throws ServletException {
throw new UnsupportedOperationException();
}

/**
* This method always returns {@code null}.
* @see jakarta.servlet.ServletContext#getFilterRegistration(java.lang.String)
*/
@Override
@Nullable
public FilterRegistration getFilterRegistration(String filterName) {
return null;
}

/**
* This method always returns an {@linkplain Collections#emptyMap empty map}.
* @see jakarta.servlet.ServletContext#getFilterRegistrations()
*/
@Override
public Map<String, ? extends FilterRegistration> getFilterRegistrations() {
return Collections.emptyMap();
}

@Override
public void addListener(Class<? extends EventListener> listenerClass) {
throw new UnsupportedOperationException();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,8 @@

import org.springframework.lang.Nullable;
import org.springframework.mock.web.MockFilterConfig;
import org.springframework.mock.web.MockFilterRegistration;
import org.springframework.mock.web.MockServletContext;
import org.springframework.util.Assert;
import org.springframework.web.util.UrlPathHelper;

Expand Down Expand Up @@ -98,17 +100,27 @@ public MockMvcFilterDecorator(
Assert.notNull(delegate, "filter cannot be null");
Assert.notNull(urlPatterns, "urlPatterns cannot be null");
this.delegate = delegate;
this.filterConfigInitializer = getFilterConfigInitializer(filterName, initParams);
this.filterConfigInitializer = getFilterConfigInitializer(delegate, filterName, initParams);
this.dispatcherTypes = dispatcherTypes;
this.hasPatterns = initPatterns(urlPatterns);
}

private static Function<ServletContext, FilterConfig> getFilterConfigInitializer(
@Nullable String filterName, @Nullable Map<String, String> initParams) {
Filter delegate, @Nullable String filterName, @Nullable Map<String, String> initParams) {

String className = delegate.getClass().getName();

return servletContext -> {
MockFilterConfig filterConfig = (filterName != null ?
new MockFilterConfig(servletContext, filterName) : new MockFilterConfig(servletContext));
MockServletContext mockServletContext = (MockServletContext) servletContext;
MockFilterConfig filterConfig;
if (filterName != null) {
filterConfig = new MockFilterConfig(servletContext, filterName);
mockServletContext.addFilterRegistration(new MockFilterRegistration(className, filterName));
}
else {
filterConfig = new MockFilterConfig(servletContext);
mockServletContext.addFilterRegistration(new MockFilterRegistration(className));
}
if (initParams != null) {
initParams.forEach(filterConfig::addInitParameter);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,7 @@
import org.springframework.http.converter.json.SpringHandlerInstantiator;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.stereotype.Controller;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;
Expand Down Expand Up @@ -133,12 +134,17 @@ void addFilterWithInitParams() throws ServletException {
Filter filter = mock(Filter.class);
ArgumentCaptor<FilterConfig> captor = ArgumentCaptor.forClass(FilterConfig.class);

MockMvcBuilders.standaloneSetup(new PersonController())
.addFilter(filter, null, Map.of("p", "v"), EnumSet.of(DispatcherType.REQUEST), "/")
MockMvc mockMvc = MockMvcBuilders.standaloneSetup(new PersonController())
.addFilter(filter, "testFilter", Map.of("p", "v"), EnumSet.of(DispatcherType.REQUEST), "/")
.build();

verify(filter, times(1)).init(captor.capture());
assertThat(captor.getValue().getInitParameter("p")).isEqualTo("v");

// gh-33252

assertThat(mockMvc.getDispatcherServlet().getServletContext().getFilterRegistrations())
.hasSize(1).containsKey("testFilter");
}

@Test // SPR-13375
Expand Down

0 comments on commit d2225c2

Please sign in to comment.