Skip to content

Commit

Permalink
Support header filtering in web data binding
Browse files Browse the repository at this point in the history
Closes gh-34039
  • Loading branch information
rstoyanchev committed Dec 11, 2024
1 parent 70c326e commit 8aeced9
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;

import reactor.core.publisher.Mono;

Expand All @@ -41,12 +43,40 @@
*/
public class ExtendedWebExchangeDataBinder extends WebExchangeDataBinder {

private static final Set<String> FILTERED_HEADER_NAMES = Set.of("Priority");


private Predicate<String> headerPredicate = name -> !FILTERED_HEADER_NAMES.contains(name);


public ExtendedWebExchangeDataBinder(@Nullable Object target, String objectName) {
super(target, objectName);
}


/**
* Add a Predicate that filters the header names to use for data binding.
* Multiple predicates are combined with {@code AND}.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void addHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = this.headerPredicate.and(headerPredicate);
}

/**
* Set the Predicate that filters the header names to use for data binding.
* <p>Note that this method resets any previous predicates that may have been
* set, including headers excluded by default such as the RFC 9218 defined
* "Priority" header.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void setHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = headerPredicate;
}


@Override
public Mono<Map<String, Object>> getValuesToBind(ServerWebExchange exchange) {
return super.getValuesToBind(exchange).doOnNext(map -> {
Expand All @@ -56,10 +86,13 @@ public Mono<Map<String, Object>> getValuesToBind(ServerWebExchange exchange) {
}
HttpHeaders headers = exchange.getRequest().getHeaders();
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
String name = entry.getKey();
if (!this.headerPredicate.test(entry.getKey())) {
continue;
}
List<String> values = entry.getValue();
if (!CollectionUtils.isEmpty(values)) {
// For constructor args with @BindParam mapped to the actual header name
String name = entry.getKey();
addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values));
// Also adapt to Java conventions for setters
name = StringUtils.uncapitalize(entry.getKey().replace("-", ""));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,24 @@ void bindUriVarsAndHeadersAddedConditionally() throws Exception {
assertThat(target.getAge()).isEqualTo(25);
}

@Test
void headerPredicate() throws Exception {
MockServerHttpRequest request = MockServerHttpRequest.get("/path")
.header("Priority", "u1")
.header("Some-Int-Array", "1")
.header("Another-Int-Array", "1")
.build();

MockServerWebExchange exchange = MockServerWebExchange.from(request);

BindingContext context = createBindingContext("initBinderWithAttributeName", WebDataBinder.class);
ExtendedWebExchangeDataBinder binder = (ExtendedWebExchangeDataBinder) context.createDataBinder(exchange, null, "", null);
binder.addHeaderPredicate(name -> !name.equalsIgnoreCase("Another-Int-Array"));

Map<String, Object> map = binder.getValuesToBind(exchange).block();
assertThat(map).containsExactlyInAnyOrderEntriesOf(Map.of("someIntArray", "1", "Some-Int-Array", "1"));
}

private BindingContext createBindingContext(String methodName, Class<?>... parameterTypes) throws Exception {
Object handler = new InitBinderHandler();
Method method = handler.getClass().getMethod(methodName, parameterTypes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;

import jakarta.servlet.ServletRequest;
import jakarta.servlet.http.HttpServletRequest;

import org.springframework.beans.MutablePropertyValues;
import org.springframework.lang.Nullable;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.ServletRequestDataBinder;
import org.springframework.web.bind.WebDataBinder;
import org.springframework.web.servlet.HandlerMapping;
Expand All @@ -51,6 +53,12 @@
*/
public class ExtendedServletRequestDataBinder extends ServletRequestDataBinder {

private static final Set<String> FILTERED_HEADER_NAMES = Set.of("Priority");


private Predicate<String> headerPredicate = name -> !FILTERED_HEADER_NAMES.contains(name);


/**
* Create a new instance, with default object name.
* @param target the target object to bind onto (or {@code null}
Expand All @@ -73,6 +81,29 @@ public ExtendedServletRequestDataBinder(@Nullable Object target, String objectNa
}


/**
* Add a Predicate that filters the header names to use for data binding.
* Multiple predicates are combined with {@code AND}.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void addHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = this.headerPredicate.and(headerPredicate);
}

/**
* Set the Predicate that filters the header names to use for data binding.
* <p>Note that this method resets any previous predicates that may have been
* set, including headers excluded by default such as the RFC 9218 defined
* "Priority" header.
* @param headerPredicate the predicate to add
* @since 6.2.1
*/
public void setHeaderPredicate(Predicate<String> headerPredicate) {
this.headerPredicate = headerPredicate;
}


@Override
protected ServletRequestValueResolver createValueResolver(ServletRequest request) {
return new ExtendedServletRequestValueResolver(request, this);
Expand All @@ -93,7 +124,7 @@ protected void addBindValues(MutablePropertyValues mpvs, ServletRequest request)
String name = names.nextElement();
Object value = getHeaderValue(httpRequest, name);
if (value != null) {
name = name.replace("-", "");
name = StringUtils.uncapitalize(name.replace("-", ""));
addValueIfNotPresent(mpvs, "Header", name, value);
}
}
Expand All @@ -118,7 +149,11 @@ private static void addValueIfNotPresent(MutablePropertyValues mpvs, String labe
}

@Nullable
private static Object getHeaderValue(HttpServletRequest request, String name) {
private Object getHeaderValue(HttpServletRequest request, String name) {
if (!this.headerPredicate.test(name)) {
return null;
}

Enumeration<String> valuesEnum = request.getHeaders(name);
if (!valuesEnum.hasMoreElements()) {
return null;
Expand All @@ -141,7 +176,7 @@ private static Object getHeaderValue(HttpServletRequest request, String name) {
/**
* Resolver of values that looks up URI path variables.
*/
private static class ExtendedServletRequestValueResolver extends ServletRequestValueResolver {
private class ExtendedServletRequestValueResolver extends ServletRequestValueResolver {

ExtendedServletRequestValueResolver(ServletRequest request, WebDataBinder dataBinder) {
super(request, dataBinder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@

import java.util.Map;

import jakarta.servlet.ServletRequest;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.testfixture.beans.TestBean;
import org.springframework.core.ResolvableType;
import org.springframework.web.bind.ServletRequestDataBinder;
Expand Down Expand Up @@ -102,6 +104,22 @@ void uriVarsAndHeadersAddedConditionally() {
assertThat(target.getAge()).isEqualTo(25);
}

@Test
void headerPredicate() {
TestBinder binder = new TestBinder();
binder.addHeaderPredicate(name -> !name.equalsIgnoreCase("Another-Int-Array"));

MutablePropertyValues mpvs = new MutablePropertyValues();
request.addHeader("Priority", "u1");
request.addHeader("Some-Int-Array", "1");
request.addHeader("Another-Int-Array", "1");

binder.addBindValues(mpvs, request);

assertThat(mpvs.size()).isEqualTo(1);
assertThat(mpvs.get("someIntArray")).isEqualTo("1");
}

@Test
void noUriTemplateVars() {
TestBean target = new TestBean();
Expand All @@ -116,4 +134,17 @@ void noUriTemplateVars() {
private record DataBean(String name, int age, @BindParam("Some-Int-Array") Integer[] someIntArray) {
}


private static class TestBinder extends ExtendedServletRequestDataBinder {

public TestBinder() {
super(null);
}

@Override
public void addBindValues(MutablePropertyValues mpvs, ServletRequest request) {
super.addBindValues(mpvs, request);
}
}

}

0 comments on commit 8aeced9

Please sign in to comment.