From e3155c0ce24e3acebd8e757613d01f018d050d19 Mon Sep 17 00:00:00 2001 From: Ludovic Orban Date: Mon, 26 Aug 2024 16:52:53 +0200 Subject: [PATCH] Improve ThreadLimitHandler Signed-off-by: Ludovic Orban --- jetty-server/pom.xml | 5 ++ .../server/handler/ThreadLimitHandler.java | 72 ++++++++++++++++--- .../handler/ThreadLimitHandlerTest.java | 25 +++++-- 3 files changed, 87 insertions(+), 15 deletions(-) diff --git a/jetty-server/pom.xml b/jetty-server/pom.xml index 79261a9d1927..62fd6523cbc0 100644 --- a/jetty-server/pom.xml +++ b/jetty-server/pom.xml @@ -98,5 +98,10 @@ ${project.version} test + + org.awaitility + awaitility + test + diff --git a/jetty-server/src/main/java/org/eclipse/jetty/server/handler/ThreadLimitHandler.java b/jetty-server/src/main/java/org/eclipse/jetty/server/handler/ThreadLimitHandler.java index 9829cbf6a70f..97429a89e7fa 100644 --- a/jetty-server/src/main/java/org/eclipse/jetty/server/handler/ThreadLimitHandler.java +++ b/jetty-server/src/main/java/org/eclipse/jetty/server/handler/ThreadLimitHandler.java @@ -26,10 +26,12 @@ import java.util.Deque; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; import javax.servlet.AsyncContext; import javax.servlet.ServletException; +import javax.servlet.ServletRequestEvent; +import javax.servlet.ServletRequestListener; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -79,7 +81,7 @@ public class ThreadLimitHandler extends HandlerWrapper private final boolean _rfc7239; private final String _forwardedHeader; private final IncludeExcludeSet _includeExcludeSet = new IncludeExcludeSet<>(InetAddressSet.class); - private final ConcurrentMap _remotes = new ConcurrentHashMap<>(); + private final ConcurrentHashMap _remotes = new ConcurrentHashMap<>(); private volatile boolean _enabled; private int _threadLimit = 10; @@ -165,6 +167,22 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques } else { + request.getServletContext().addListener(new ServletRequestListener() + { + @Override + public void requestInitialized(ServletRequestEvent servletRequestEvent) + { + } + + @Override + public void requestDestroyed(ServletRequestEvent sre) + { + // Use a compute method to remove the Remote instance as it is necessary for + // the ref counter release and the removal to be atomic. + _remotes.computeIfPresent(remote._ip, (k, v) -> v._referenceCounter.release() ? null : v); + } + }); + // Do we already have a future permit from a previous invocation? Closeable permit = (Closeable)baseRequest.getAttribute(PERMIT); try @@ -256,14 +274,18 @@ protected Remote getRemote(Request baseRequest) if (limit <= 0) return null; - remote = _remotes.get(ip); - if (remote == null) + // Use a compute method to create or retain the Remote instance as it is necessary for + // the ref counter increment or the instance creation to be mutually exclusive. + // The map MUST be a CHM as it guarantees the remapping function is only called once. + remote = _remotes.compute(ip, (k, v) -> { - Remote r = new Remote(ip, limit); - remote = _remotes.putIfAbsent(ip, r); - if (remote == null) - remote = r; - } + if (v != null) + { + v._referenceCounter.retain(); + return v; + } + return new Remote(k, limit); + }); baseRequest.setAttribute(REMOTE, remote); @@ -282,7 +304,7 @@ protected String getRemoteIP(Request baseRequest) } // If no remote IP from a header, determine it directly from the channel - // Do not use the request methods, as they may have been lied to by the + // Do not use the request methods, as they may have been lied to by the // RequestCustomizer! InetSocketAddress inetAddr = baseRequest.getHttpChannel().getRemoteAddress(); if (inetAddr != null && inetAddr.getAddress() != null) @@ -329,11 +351,17 @@ private String getXForwardedFor(Request request) return (comma >= 0) ? forwardedFor.substring(comma + 1).trim() : forwardedFor; } + int getRemoteCount() + { + return _remotes.size(); + } + private final class Remote implements Closeable { private final String _ip; private final int _limit; private final Locker _locker = new Locker(); + private final ReferenceCounter _referenceCounter = new ReferenceCounter(); private int _permits; private Deque> _queue = new ArrayDeque<>(); private final CompletableFuture _permitted = CompletableFuture.completedFuture(this); @@ -357,7 +385,7 @@ public CompletableFuture acquire() return _permitted; // TODO is it OK to share/reuse this? } - // No pass available, so queue a new future + // No pass available, so queue a new future CompletableFuture pass = new CompletableFuture(); _queue.addLast(pass); return pass; @@ -437,4 +465,26 @@ protected void parsedParam(StringBuffer buffer, int valueLength, int paramName, } } } + + private static class ReferenceCounter + { + private final AtomicInteger references = new AtomicInteger(1); + + public void retain() + { + if (references.getAndUpdate(c -> c == 0 ? 0 : c + 1) == 0) + throw new IllegalStateException("released " + this); + } + + public boolean release() + { + int ref = references.updateAndGet(c -> + { + if (c == 0) + throw new IllegalStateException("already released " + this); + return c - 1; + }); + return ref == 0; + } + } } diff --git a/jetty-server/src/test/java/org/eclipse/jetty/server/handler/ThreadLimitHandlerTest.java b/jetty-server/src/test/java/org/eclipse/jetty/server/handler/ThreadLimitHandlerTest.java index 9e0119755d43..fde8c43bde9b 100644 --- a/jetty-server/src/test/java/org/eclipse/jetty/server/handler/ThreadLimitHandlerTest.java +++ b/jetty-server/src/test/java/org/eclipse/jetty/server/handler/ThreadLimitHandlerTest.java @@ -40,6 +40,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import static org.awaitility.Awaitility.await; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; @@ -88,7 +89,9 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques response.setStatus(HttpStatus.OK_200); } }); - _server.setHandler(handler); + ContextHandler contextHandler = new ContextHandler("/"); + contextHandler.setHandler(handler); + _server.setHandler(contextHandler); _server.start(); last.set(null); @@ -102,6 +105,8 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques last.set(null); _local.getResponse("GET / HTTP/1.0\r\nForwarded: for=1.2.3.4\r\n\r\n"); assertThat(last.get(), Matchers.is("0.0.0.0")); + + await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0)); } @Test @@ -117,7 +122,9 @@ protected int getThreadLimit(String ip) return super.getThreadLimit(ip); } }; - _server.setHandler(handler); + ContextHandler contextHandler = new ContextHandler("/"); + contextHandler.setHandler(handler); + _server.setHandler(contextHandler); _server.start(); last.set(null); @@ -135,6 +142,8 @@ protected int getThreadLimit(String ip) last.set(null); _local.getResponse("GET / HTTP/1.0\r\nX-Forwarded-For: 1.1.1.1\r\nX-Forwarded-For: 6.6.6.6,1.2.3.4\r\nForwarded: for=1.2.3.4\r\n\r\n"); assertThat(last.get(), Matchers.is("1.2.3.4")); + + await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0)); } @Test @@ -150,7 +159,9 @@ protected int getThreadLimit(String ip) return super.getThreadLimit(ip); } }; - _server.setHandler(handler); + ContextHandler contextHandler = new ContextHandler("/"); + contextHandler.setHandler(handler); + _server.setHandler(contextHandler); _server.start(); last.set(null); @@ -168,6 +179,8 @@ protected int getThreadLimit(String ip) last.set(null); _local.getResponse("GET / HTTP/1.0\r\nX-Forwarded-For: 1.1.1.1\r\nForwarded: for=6.6.6.6; for=1.2.3.4\r\nX-Forwarded-For: 6.6.6.6\r\nForwarded: proto=https\r\n\r\n"); assertThat(last.get(), Matchers.is("1.2.3.4")); + + await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0)); } @Test @@ -206,7 +219,9 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques } } }); - _server.setHandler(handler); + ContextHandler contextHandler = new ContextHandler("/"); + contextHandler.setHandler(handler); + _server.setHandler(contextHandler); _server.start(); Socket[] client = new Socket[10]; @@ -241,5 +256,7 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques Thread.sleep(10); } assertThat(count.get(), is(0)); + + await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0)); } }