Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jetty 10.0.x: Improve ThreadLimitHandler #12201

Merged
merged 1 commit into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
import java.util.Deque;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import javax.servlet.AsyncContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequestEvent;
import javax.servlet.ServletRequestListener;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

Expand Down Expand Up @@ -73,7 +75,7 @@ public class ThreadLimitHandler extends HandlerWrapper
private final boolean _rfc7239;
private final String _forwardedHeader;
private final IncludeExcludeSet<String, InetAddress> _includeExcludeSet = new IncludeExcludeSet<>(InetAddressSet.class);
private final ConcurrentMap<String, Remote> _remotes = new ConcurrentHashMap<>();
private final ConcurrentHashMap<String, Remote> _remotes = new ConcurrentHashMap<>();
private volatile boolean _enabled;
private int _threadLimit = 10;

Expand Down Expand Up @@ -179,6 +181,17 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques
}
else
{
request.getServletContext().addListener(new ServletRequestListener()
{
@Override
public void requestDestroyed(ServletRequestEvent sre)
{
// Use a compute method to remove the Remote instance as it is necessary for
// the ref counter release and the removal to be atomic.
_remotes.computeIfPresent(remote._ip, (k, v) -> v._referenceCounter.release() ? null : v);
}
});

// Do we already have a future permit from a previous invocation?
Closeable permit = (Closeable)baseRequest.getAttribute(PERMIT);
try
Expand Down Expand Up @@ -250,14 +263,18 @@ private Remote getRemote(Request baseRequest)
if (limit <= 0)
return null;

remote = _remotes.get(ip);
if (remote == null)
// Use a compute method to create or retain the Remote instance as it is necessary for
// the ref counter increment or the instance creation to be mutually exclusive.
// The map MUST be a CHM as it guarantees the remapping function is only called once.
remote = _remotes.compute(ip, (k, v) ->
{
Remote r = new Remote(ip, limit);
remote = _remotes.putIfAbsent(ip, r);
if (remote == null)
remote = r;
}
if (v != null)
{
v._referenceCounter.retain();
return v;
}
return new Remote(k, limit);
});

baseRequest.setAttribute(REMOTE, remote);

Expand All @@ -276,7 +293,7 @@ protected String getRemoteIP(Request baseRequest)
}

// If no remote IP from a header, determine it directly from the channel
// Do not use the request methods, as they may have been lied to by the
// Do not use the request methods, as they may have been lied to by the
// RequestCustomizer!
InetSocketAddress inetAddr = baseRequest.getHttpChannel().getRemoteAddress();
if (inetAddr != null && inetAddr.getAddress() != null)
Expand Down Expand Up @@ -321,11 +338,17 @@ private String getXForwardedFor(Request request)
return (comma >= 0) ? forwardedFor.substring(comma + 1).trim() : forwardedFor;
}

int getRemoteCount()
{
return _remotes.size();
}

private static final class Remote implements Closeable
{
private final String _ip;
private final int _limit;
private final AutoLock _lock = new AutoLock();
private final ReferenceCounter _referenceCounter = new ReferenceCounter();
private int _permits;
private Deque<CompletableFuture<Closeable>> _queue = new ArrayDeque<>();
private final CompletableFuture<Closeable> _permitted = CompletableFuture.completedFuture(this);
Expand All @@ -349,7 +372,7 @@ public CompletableFuture<Closeable> acquire()
return _permitted; // TODO is it OK to share/reuse this?
}

// No pass available, so queue a new future
// No pass available, so queue a new future
CompletableFuture<Closeable> pass = new CompletableFuture<>();
_queue.addLast(pass);
return pass;
Expand Down Expand Up @@ -429,4 +452,26 @@ protected void parsedParam(StringBuffer buffer, int valueLength, int paramName,
}
}
}

private static class ReferenceCounter
{
private final AtomicInteger references = new AtomicInteger(1);

public void retain()
{
if (references.getAndUpdate(c -> c == 0 ? 0 : c + 1) == 0)
throw new IllegalStateException("released " + this);
}

public boolean release()
{
int ref = references.updateAndGet(c ->
{
if (c == 0)
throw new IllegalStateException("already released " + this);
return c - 1;
});
return ref == 0;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.io.IOException;
import java.net.Socket;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import javax.servlet.ServletException;
Expand All @@ -35,6 +36,7 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.awaitility.Awaitility.await;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;

Expand Down Expand Up @@ -83,7 +85,9 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques
response.setStatus(HttpStatus.OK_200);
}
});
_server.setHandler(handler);
ContextHandler contextHandler = new ContextHandler("/");
contextHandler.setHandler(handler);
_server.setHandler(contextHandler);
_server.start();

last.set(null);
Expand All @@ -97,6 +101,8 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nForwarded: for=1.2.3.4\r\n\r\n");
assertThat(last.get(), Matchers.is("0.0.0.0"));

await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0));
}

@Test
Expand All @@ -112,7 +118,9 @@ protected int getThreadLimit(String ip)
return super.getThreadLimit(ip);
}
};
_server.setHandler(handler);
ContextHandler contextHandler = new ContextHandler("/");
contextHandler.setHandler(handler);
_server.setHandler(contextHandler);
_server.start();

last.set(null);
Expand All @@ -130,6 +138,8 @@ protected int getThreadLimit(String ip)
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nX-Forwarded-For: 1.1.1.1\r\nX-Forwarded-For: 6.6.6.6,1.2.3.4\r\nForwarded: for=1.2.3.4\r\n\r\n");
assertThat(last.get(), Matchers.is("1.2.3.4"));

await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0));
}

@Test
Expand All @@ -145,7 +155,9 @@ protected int getThreadLimit(String ip)
return super.getThreadLimit(ip);
}
};
_server.setHandler(handler);
ContextHandler contextHandler = new ContextHandler("/");
contextHandler.setHandler(handler);
_server.setHandler(contextHandler);
_server.start();

last.set(null);
Expand All @@ -163,6 +175,8 @@ protected int getThreadLimit(String ip)
last.set(null);
_local.getResponse("GET / HTTP/1.0\r\nX-Forwarded-For: 1.1.1.1\r\nForwarded: for=6.6.6.6; for=1.2.3.4\r\nX-Forwarded-For: 6.6.6.6\r\nForwarded: proto=https\r\n\r\n");
assertThat(last.get(), Matchers.is("1.2.3.4"));

await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0));
}

@Test
Expand Down Expand Up @@ -201,7 +215,9 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques
}
}
});
_server.setHandler(handler);
ContextHandler contextHandler = new ContextHandler("/");
contextHandler.setHandler(handler);
_server.setHandler(contextHandler);
_server.start();

Socket[] client = new Socket[10];
Expand Down Expand Up @@ -237,5 +253,7 @@ public void handle(String target, Request baseRequest, HttpServletRequest reques
Thread.sleep(10);
}
assertThat(count.get(), is(0));

await().atMost(5, TimeUnit.SECONDS).until(handler::getRemoteCount, is(0));
}
}
Loading