Skip to content

Commit

Permalink
Merge pull request #1044 from moulalis/master_UNDERTOW-1837
Browse files Browse the repository at this point in the history
[UNDERTOW-1837] ServletRequest#getLocalPort(), getLocalAddr() and get…
  • Loading branch information
fl4via authored Mar 30, 2021
2 parents 013802f + 0a80b5b commit 53bc0c0
Show file tree
Hide file tree
Showing 7 changed files with 420 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,20 @@ public class ForwardedHandler implements HttpHandler {
public static final String HOST = "host";
public static final String PROTO = "proto";
private static final String UNKNOWN = "unknown";


private static final boolean DEFAULT_CHANGE_LOCAL_ADDR_PORT = Boolean.getBoolean("io.undertow.forwarded.change-local-addr-port");
private static final String CHANGE_LOCAL_ADDR_PORT = "change-local-addr-port";
private final boolean isChangeLocalAddrPort;
private final HttpHandler next;

public ForwardedHandler(HttpHandler next) {
this(next, DEFAULT_CHANGE_LOCAL_ADDR_PORT);
}
public ForwardedHandler(HttpHandler next, boolean isChangeLocalAddrPort) {
this.next = next;
this.isChangeLocalAddrPort = isChangeLocalAddrPort;
}


@Override
public void handleRequest(HttpServerExchange exchange) throws Exception {
HeaderValues forwarded = exchange.getRequestHeaders().get(Headers.FORWARDED);
Expand All @@ -54,11 +60,13 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {

if (host != null) {
exchange.getRequestHeaders().put(Headers.HOST, host);
exchange.setDestinationAddress(InetSocketAddress.createUnresolved(exchange.getHostName(), exchange.getHostPort()));
if (isChangeLocalAddrPort) {
exchange.setDestinationAddress(InetSocketAddress.createUnresolved(exchange.getHostName(), exchange.getHostPort()));
}
} else if (by != null) {
//we only use 'by' if the host is null
InetSocketAddress destAddress = parseAddress(by);
if (destAddress != null) {
if (destAddress != null && isChangeLocalAddrPort) {
exchange.setDestinationAddress(destAddress);
}
}
Expand Down Expand Up @@ -242,12 +250,6 @@ private enum SearchingFor {
START_OF_NAME, EQUALS_SIGN, START_OF_VALUE, LAST_QUOTE, END_OF_VALUE;
}

public static final HandlerWrapper WRAPPER = new HandlerWrapper() {
@Override
public HttpHandler wrap(HttpHandler handler) {
return new ForwardedHandler(handler);
}
};

@Override
public String toString() {
Expand All @@ -264,7 +266,9 @@ public String name() {

@Override
public Map<String, Class<?>> parameters() {
return Collections.emptyMap();
Map<String, Class<?>> params = new HashMap<>();
params.put(CHANGE_LOCAL_ADDR_PORT, boolean.class);
return params;
}

@Override
Expand All @@ -274,12 +278,26 @@ public Set<String> requiredParameters() {

@Override
public String defaultParameter() {
return null;
return CHANGE_LOCAL_ADDR_PORT;
}

@Override
public HandlerWrapper build(Map<String, Object> config) {
return WRAPPER;
Boolean isChangeLocalAddrPort = (Boolean) config.get(CHANGE_LOCAL_ADDR_PORT);
return new Wrapper(isChangeLocalAddrPort == null ? DEFAULT_CHANGE_LOCAL_ADDR_PORT : isChangeLocalAddrPort);
}
}
private static class Wrapper implements HandlerWrapper {

private final boolean isChangeLocalAddrPort;

private Wrapper(boolean isChangeLocalAddrPort) {
this.isChangeLocalAddrPort = isChangeLocalAddrPort;
}

@Override
public HttpHandler wrap(HttpHandler handler) {
return new ForwardedHandler(handler, isChangeLocalAddrPort);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
Expand All @@ -48,8 +49,19 @@ public class ProxyPeerAddressHandler implements HttpHandler {

private final HttpHandler next;

private static final boolean DEFAULT_CHANGE_LOCAL_ADDR_PORT = Boolean.getBoolean("io.undertow.forwarded.change-local-addr-port");

private static final String CHANGE_LOCAL_ADDR_PORT = "change-local-addr-port";

private final boolean isChangeLocalAddrPort;

public ProxyPeerAddressHandler(HttpHandler next) {
this(next, DEFAULT_CHANGE_LOCAL_ADDR_PORT);
}

public ProxyPeerAddressHandler(HttpHandler next, boolean isChangeLocalAddrPort) {
this.next = next;
this.isChangeLocalAddrPort = isChangeLocalAddrPort;
}

@Override
Expand Down Expand Up @@ -110,7 +122,9 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {
}
}
exchange.getRequestHeaders().put(Headers.HOST, hostHeader);
exchange.setDestinationAddress(InetSocketAddress.createUnresolved(value, port));
if (isChangeLocalAddrPort) {
exchange.setDestinationAddress(InetSocketAddress.createUnresolved(value, port));
}
}
next.handleRequest(exchange);
}
Expand Down Expand Up @@ -142,7 +156,9 @@ public String name() {

@Override
public Map<String, Class<?>> parameters() {
return Collections.emptyMap();
Map<String, Class<?>> params = new HashMap<>();
params.put(CHANGE_LOCAL_ADDR_PORT, boolean.class);
return params;
}

@Override
Expand All @@ -152,20 +168,27 @@ public Set<String> requiredParameters() {

@Override
public String defaultParameter() {
return null;
return CHANGE_LOCAL_ADDR_PORT;
}

@Override
public HandlerWrapper build(Map<String, Object> config) {
return new Wrapper();
Boolean isChangeLocalAddrPort = (Boolean) config.get(CHANGE_LOCAL_ADDR_PORT);
return new Wrapper(isChangeLocalAddrPort == null ? DEFAULT_CHANGE_LOCAL_ADDR_PORT : isChangeLocalAddrPort);
}

}

private static class Wrapper implements HandlerWrapper {
private final boolean isChangeLocalAddrPort;

private Wrapper(boolean isChangeLocalAddrPort) {
this.isChangeLocalAddrPort = isChangeLocalAddrPort;
}

@Override
public HttpHandler wrap(HttpHandler handler) {
return new ProxyPeerAddressHandler(handler);
return new ProxyPeerAddressHandler(handler, isChangeLocalAddrPort);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class ForwardedHandlerTestCase {

@BeforeClass
public static void setup() {
final boolean DEFAULT_CHANGE_LOCAL_ADDR_PORT = Boolean.TRUE;
DefaultServer.setRootHandler(new ForwardedHandler(new HttpHandler() {
@Override
public void handleRequest(HttpServerExchange exchange) throws Exception {
Expand All @@ -44,7 +45,7 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {
+ "|" + toJreNormalizedString(exchange.getDestinationAddress())
+ "|" + toJreNormalizedString(exchange.getSourceAddress()));
}
}));
}, DEFAULT_CHANGE_LOCAL_ADDR_PORT));
}

private static String toJreNormalizedString(InetSocketAddress address) {
Expand Down Expand Up @@ -147,6 +148,7 @@ public void testForwardedHandler() throws IOException {
Assert.assertEquals( "foo.com", res[1]);
Assert.assertEquals( "foo.com:80", res[2]);
Assert.assertEquals( "/9.9.9.9:2343", res[3]);

}

private static String[] run(String ... headers) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* JBoss, Home of Professional Open Source.
* Copyright 2021 Red Hat, Inc., and individual contributors
* as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.undertow.servlet.test;

import io.undertow.server.HttpHandler;
import io.undertow.server.handlers.ForwardedHandler;
import io.undertow.server.handlers.PathHandler;
import io.undertow.servlet.api.DeploymentInfo;
import io.undertow.servlet.api.DeploymentManager;
import io.undertow.servlet.api.ServletContainer;
import io.undertow.servlet.api.ServletInfo;
import io.undertow.servlet.test.constant.GenericServletConstants;
import io.undertow.servlet.test.util.ProxyPeerXForwardedHandlerServlet;
import io.undertow.servlet.test.util.TestClassIntrospector;
import io.undertow.testutils.DefaultServer;
import io.undertow.testutils.ProxyIgnore;
import io.undertow.testutils.TestHttpClient;
import io.undertow.util.Headers;
import io.undertow.util.StatusCodes;
import org.apache.http.HttpEntity;
import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.util.EntityUtils;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;

import javax.servlet.ServletException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

/**
* @author Moulali Shikalwadi
*/
@RunWith(DefaultServer.class)
@ProxyIgnore
public class ProxyForwardedTestCase {
protected static int PORT;

@BeforeClass
public static void setup() throws ServletException {
PORT = DefaultServer.getHostPort("default");
final PathHandler root = new PathHandler();
final ServletContainer container = ServletContainer.Factory.newInstance();

ServletInfo s = new ServletInfo("servlet", ProxyPeerXForwardedHandlerServlet.class)
.addMapping("/forwardedHandler");

DeploymentInfo builder = new DeploymentInfo()
.setClassLoader(SimpleServletTestCase.class.getClassLoader())
.setContextPath("/servletContext")
.setClassIntrospecter(TestClassIntrospector.INSTANCE)
.setDeploymentName("servletContext.war")
.addServlet(s);

DeploymentManager manager = container.addDeployment(builder);
manager.deploy();
HttpHandler startHandler = manager.start();
startHandler = new ForwardedHandler(startHandler, false);
root.addPrefixPath(builder.getContextPath(), startHandler);

DefaultServer.setRootHandler(root);
}


@Test
public void testForwardedHandler() throws IOException {
TestHttpClient client = new TestHttpClient();
try {

HttpGet getForwardedHandler = new HttpGet(DefaultServer.getDefaultServerURL() + "/servletContext/forwardedHandler");
getForwardedHandler.addHeader(Headers.FORWARDED_STRING, "for=192.0.2.43");
getForwardedHandler.addHeader(Headers.FORWARDED_STRING, "by=203.0.113.60");
getForwardedHandler.addHeader(Headers.FORWARDED_STRING, "proto=http");
getForwardedHandler.addHeader(Headers.FORWARDED_STRING, "host=192.0.2.10:8888");
HttpResponse result = client.execute(getForwardedHandler);
HttpEntity entity = result.getEntity();
String results = EntityUtils.toString(entity);
Map<String, String> map = convertWithStream(results);
Socket socket = new Socket();
socket.connect(new InetSocketAddress(DefaultServer.getHostAddress(), PORT));

Assert.assertEquals(StatusCodes.OK, result.getStatusLine().getStatusCode());
Assert.assertEquals(socket.getLocalAddress().getHostAddress(), map.get(GenericServletConstants.LOCAL_ADDR));
Assert.assertEquals(socket.getLocalAddress().getHostName(), map.get(GenericServletConstants.LOCAL_NAME));
Assert.assertEquals(PORT, Integer.parseInt(map.get(GenericServletConstants.LOCAL_PORT)));
Assert.assertEquals("192.0.2.10", map.get(GenericServletConstants.SERVER_NAME));
Assert.assertEquals("8888", map.get(GenericServletConstants.SERVER_PORT));
Assert.assertEquals("192.0.2.43", map.get(GenericServletConstants.REMOTE_ADDR));
Assert.assertEquals("0", map.get(GenericServletConstants.REMOTE_PORT));

} finally {
client.getConnectionManager().shutdown();
}
}

private Map<String, String> convertWithStream(String mapAsString) {
Map<String, String> map = new HashMap<String, String>();
if (mapAsString != null) {
mapAsString = mapAsString.substring(1, mapAsString.length() - 1);
map = Arrays.stream(mapAsString.split(","))
.map(entry -> entry.split("="))
.collect(Collectors.toMap(entry -> entry[0].trim(), entry -> entry[1].trim()));
}
return map;
}
}
Loading

0 comments on commit 53bc0c0

Please sign in to comment.