Skip to content

Commit

Permalink
Improve ThreadLimitHandler
Browse files Browse the repository at this point in the history
Signed-off-by: Ludovic Orban <lorban@bitronix.be>
  • Loading branch information
lorban committed Aug 26, 2024
1 parent dd6c9a2 commit e3155c0
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 15 deletions.
5 changes: 5 additions & 0 deletions jetty-server/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,10 @@
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
import java.util.Deque;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import javax.servlet.AsyncContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequestEvent;
import javax.servlet.ServletRequestListener;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

Expand Down Expand Up @@ -79,7 +81,7 @@ public class ThreadLimitHandler extends HandlerWrapper
private final boolean _rfc7239;
private final String _forwardedHeader;
private final IncludeExcludeSet<String, InetAddress> _includeExcludeSet = new IncludeExcludeSet<>(InetAddressSet.class);
private final ConcurrentMap<String, Remote> _remotes = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, Remote> _remotes = new ConcurrentHashMap<>();
private volatile boolean _enabled;
private int _threadLimit = 10;

Expand Down Expand Up @@ -165,6 +167,22 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques
}
else
{
request.getServletContext().addListener(new ServletRequestListener()
{
@Override
public void requestInitialized(ServletRequestEvent servletRequestEvent)
{
}

@Override
public void requestDestroyed(ServletRequestEvent sre)
{
// Use a compute method to remove the Remote instance as it is necessary for
// the ref counter release and the removal to be atomic.
_remotes.computeIfPresent(remote._ip, (k, v) -> v._referenceCounter.release() ? null : v);
}
});

// Do we already have a future permit from a previous invocation?
Closeable permit = (Closeable)baseRequest.getAttribute(PERMIT);
try
Expand Down Expand Up @@ -256,14 +274,18 @@ protected Remote getRemote(Request baseRequest)
if (limit <= 0)
return null;

remote = _remotes.get(ip);
if (remote == null)
// Use a compute method to create or retain the Remote instance as it is necessary for
// the ref counter increment or the instance creation to be mutually exclusive.
// The map MUST be a CHM as it guarantees the remapping function is only called once.
remote = _remotes.compute(ip, (k, v) ->
{
Remote r = new Remote(ip, limit);
remote = _remotes.putIfAbsent(ip, r);
if (remote == null)
remote = r;
}
if (v != null)
{
v._referenceCounter.retain();
return v;
}
return new Remote(k, limit);
});

baseRequest.setAttribute(REMOTE, remote);

Expand All @@ -282,7 +304,7 @@ protected String getRemoteIP(Request baseRequest)
}

// If no remote IP from a header, determine it directly from the channel
// Do not use the request methods, as they may have been lied to by the
// Do not use the request methods, as they may have been lied to by the
// RequestCustomizer!
InetSocketAddress inetAddr = baseRequest.getHttpChannel().getRemoteAddress();
if (inetAddr != null && inetAddr.getAddress() != null)
Expand Down Expand Up @@ -329,11 +351,17 @@ private String getXForwardedFor(Request request)
return (comma >= 0) ? forwardedFor.substring(comma + 1).trim() : forwardedFor;
}

int getRemoteCount()
{
return _remotes.size();
}

private final class Remote implements Closeable
{
private final String _ip;
private final int _limit;
private final Locker _locker = new Locker();
private final ReferenceCounter _referenceCounter = new ReferenceCounter();
private int _permits;
private Deque<CompletableFuture<Closeable>> _queue = new ArrayDeque<>();
private final CompletableFuture<Closeable> _permitted = CompletableFuture.completedFuture(this);
Expand All @@ -357,7 +385,7 @@ public CompletableFuture<Closeable> acquire()
return _permitted; // TODO is it OK to share/reuse this?
}

// No pass available, so queue a new future
// No pass available, so queue a new future
CompletableFuture<Closeable> pass = new CompletableFuture<Closeable>();
_queue.addLast(pass);
return pass;
Expand Down Expand Up @@ -437,4 +465,26 @@ protected void parsedParam(StringBuffer buffer, int valueLength, int paramName,
}
}
}

private static class ReferenceCounter
{
private final AtomicInteger references = new AtomicInteger(1);

public void retain()
{
if (references.getAndUpdate(c -> c == 0 ? 0 : c + 1) == 0)
throw new IllegalStateException("released " + this);
}

public boolean release()
{
int ref = references.updateAndGet(c ->
{
if (c == 0)
throw new IllegalStateException("already released " + this);
return c - 1;
});
return ref == 0;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.awaitility.Awaitility.await;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;

Expand Down Expand Up @@ -88,7 +89,9 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques
response.setStatus(HttpStatus.OK_200);
}
});
_server.setHandler(handler);
ContextHandler contextHandler = new ContextHandler("/");
contextHandler.setHandler(handler);
_server.setHandler(contextHandler);
_server.start();

last.set(null);
Expand All @@ -102,6 +105,8 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nForwarded: for=1.2.3.4\r\n\r\n");
assertThat(last.get(), Matchers.is("0.0.0.0"));

await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0));
}

@Test
Expand All @@ -117,7 +122,9 @@ protected int getThreadLimit(String ip)
return super.getThreadLimit(ip);
}
};
_server.setHandler(handler);
ContextHandler contextHandler = new ContextHandler("/");
contextHandler.setHandler(handler);
_server.setHandler(contextHandler);
_server.start();

last.set(null);
Expand All @@ -135,6 +142,8 @@ protected int getThreadLimit(String ip)
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nX-Forwarded-For: 1.1.1.1\r\nX-Forwarded-For: 6.6.6.6,1.2.3.4\r\nForwarded: for=1.2.3.4\r\n\r\n");
assertThat(last.get(), Matchers.is("1.2.3.4"));

await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0));
}

@Test
Expand All @@ -150,7 +159,9 @@ protected int getThreadLimit(String ip)
return super.getThreadLimit(ip);
}
};
_server.setHandler(handler);
ContextHandler contextHandler = new ContextHandler("/");
contextHandler.setHandler(handler);
_server.setHandler(contextHandler);
_server.start();

last.set(null);
Expand All @@ -168,6 +179,8 @@ protected int getThreadLimit(String ip)
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nX-Forwarded-For: 1.1.1.1\r\nForwarded: for=6.6.6.6; for=1.2.3.4\r\nX-Forwarded-For: 6.6.6.6\r\nForwarded: proto=https\r\n\r\n");
assertThat(last.get(), Matchers.is("1.2.3.4"));

await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0));
}

@Test
Expand Down Expand Up @@ -206,7 +219,9 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques
}
}
});
_server.setHandler(handler);
ContextHandler contextHandler = new ContextHandler("/");
contextHandler.setHandler(handler);
_server.setHandler(contextHandler);
_server.start();

Socket[] client = new Socket[10];
Expand Down Expand Up @@ -241,5 +256,7 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques
Thread.sleep(10);
}
assertThat(count.get(), is(0));

await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0));
}
}

0 comments on commit e3155c0

Please sign in to comment.