From a8ae3f94769fcbd6db39adb9f6bffac260ad6a39 Mon Sep 17 00:00:00 2001 From: Joakim Erdfelt Date: Mon, 24 Aug 2020 09:50:50 -0500 Subject: [PATCH] Issue #5185 - Add DoSFilter Listener to allow extensible behavior + Currently there's no way to respond to rejected/throttled/delayed requests that the DoSFilter impacts. A Listener has been added to allow for any behaviors needed by a user of the DoSFilter on requests that have been impacted by the DoSFilter. + Introducing OverLimit and RateType to DoSFilter internals Signed-off-by: Joakim Erdfelt --- .../org/eclipse/jetty/servlets/DoSFilter.java | 362 +++++++++++++----- .../eclipse/jetty/servlets/DoSFilterTest.java | 4 +- 2 files changed, 271 insertions(+), 95 deletions(-) diff --git a/jetty-servlets/src/main/java/org/eclipse/jetty/servlets/DoSFilter.java b/jetty-servlets/src/main/java/org/eclipse/jetty/servlets/DoSFilter.java index d1ee5261212f..2ba7bd32008a 100644 --- a/jetty-servlets/src/main/java/org/eclipse/jetty/servlets/DoSFilter.java +++ b/jetty-servlets/src/main/java/org/eclipse/jetty/servlets/DoSFilter.java @@ -20,9 +20,13 @@ import java.io.IOException; import java.io.Serializable; +import java.time.Duration; import java.util.ArrayList; +import java.util.HashMap; import java.util.Iterator; import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.Queue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; @@ -161,10 +165,13 @@ public class DoSFilter implements Filter static final String ENABLED_INIT_PARAM = "enabled"; static final String TOO_MANY_CODE = "tooManyCode"; - private static final int USER_AUTH = 2; - private static final int USER_SESSION = 2; - private static final int USER_IP = 1; - private static final int USER_UNKNOWN = 0; + public enum RateType + { + AUTH, + SESSION, + IP, + UNKNOWN + } private final String _suspended = "DoSFilter@" + Integer.toHexString(hashCode()) + ".SUSPENDED"; private final String _resumed = "DoSFilter@" + Integer.toHexString(hashCode()) + ".RESUMED"; @@ -181,23 +188,22 @@ public class DoSFilter implements Filter private volatile boolean _remotePort; private volatile boolean _enabled; private volatile String _name; + private DoSFilter.Listener _listener = new Listener(); private Semaphore _passes; private volatile int _throttledRequests; private volatile int _maxRequestsPerSec; - private Queue[] _queues; - private AsyncListener[] _listeners; + private Map> _queues = new HashMap<>(); + private Map _listeners = new HashMap<>(); private Scheduler _scheduler; private ServletContext _context; @Override public void init(FilterConfig filterConfig) throws ServletException { - _queues = new Queue[getMaxPriority() + 1]; - _listeners = new AsyncListener[_queues.length]; - for (int p = 0; p < _queues.length; p++) + for (RateType rateType : RateType.values()) { - _queues[p] = new ConcurrentLinkedQueue<>(); - _listeners[p] = new DoSAsyncListener(p); + _queues.put(rateType, new ConcurrentLinkedQueue<>()); + _listeners.put(rateType, new DoSAsyncListener(rateType)); } _rateTrackers.clear(); @@ -305,67 +311,76 @@ protected void doFilter(HttpServletRequest request, HttpServletResponse response // Look for the rate tracker for this request. RateTracker tracker = (RateTracker)request.getAttribute(__TRACKER); - if (tracker == null) + if (tracker != null) + { + // Redispatched, RateTracker present in request attributes. + throttleRequest(request, response, filterChain, tracker); + return; + } + + // This is the first time we have seen this request. + if (LOG.isDebugEnabled()) + LOG.debug("Filtering {}", request); + + // Get a rate tracker associated with this request, and record one hit. + tracker = getRateTracker(request); + + // Calculate the rate and check if it is over the allowed limit + final OverLimit overLimit = tracker.isRateExceeded(System.currentTimeMillis()); + + // Pass it through if we are not currently over the rate limit. + if (overLimit == null) { - // This is the first time we have seen this request. if (LOG.isDebugEnabled()) - LOG.debug("Filtering {}", request); + LOG.debug("Allowing {}", request); + doFilterChain(filterChain, request, response); + return; + } - // Get a rate tracker associated with this request, and record one hit. - tracker = getRateTracker(request); + // We are over the limit. - // Calculate the rate and check if it is over the allowed limit - final boolean overRateLimit = tracker.isRateExceeded(System.currentTimeMillis()); + // Ask listener what to perform. + Action action = _listener.onRequestOverLimit(request, overLimit, this); - // Pass it through if we are not currently over the rate limit. - if (!overRateLimit) - { + // Perform action + long delayMs = getDelayMs(); + boolean insertHeaders = isInsertHeaders(); + switch (action) + { + case NO_ACTION: if (LOG.isDebugEnabled()) - LOG.debug("Allowing {}", request); + LOG.debug("Allowing over-limit request {}", request); doFilterChain(filterChain, request, response); + break; + case ABORT: + if (LOG.isDebugEnabled()) + LOG.debug("Aborting over-limit request {}", request); + response.sendError(-1); return; - } - - // We are over the limit. - - // So either reject it, delay it or throttle it. - long delayMs = getDelayMs(); - boolean insertHeaders = isInsertHeaders(); - switch ((int)delayMs) - { - case -1: - { - // Reject this request. - LOG.warn("DOS ALERT: Request rejected ip={}, session={}, user={}", request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal()); - if (insertHeaders) - response.addHeader("DoSFilter", "unavailable"); - response.sendError(getTooManyCode()); - return; - } - case 0: - { - // Fall through to throttle the request. - LOG.warn("DOS ALERT: Request throttled ip={}, session={}, user={}", request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal()); - request.setAttribute(__TRACKER, tracker); - break; - } - default: - { - // Insert a delay before throttling the request, - // using the suspend+timeout mechanism of AsyncContext. - LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, session={}, user={}", delayMs, request.getRemoteAddr(), request.getRequestedSessionId(), request.getUserPrincipal()); - if (insertHeaders) - response.addHeader("DoSFilter", "delayed"); - request.setAttribute(__TRACKER, tracker); - AsyncContext asyncContext = request.startAsync(); - if (delayMs > 0) - asyncContext.setTimeout(delayMs); - asyncContext.addListener(new DoSTimeoutAsyncListener()); - return; - } - } + case REJECT: + if (insertHeaders) + response.addHeader("DoSFilter", "unavailable"); + response.sendError(getTooManyCode()); + return; + case DELAY: + // Insert a delay before throttling the request, + // using the suspend+timeout mechanism of AsyncContext. + if (insertHeaders) + response.addHeader("DoSFilter", "delayed"); + request.setAttribute(__TRACKER, tracker); + AsyncContext asyncContext = request.startAsync(); + if (delayMs > 0) + asyncContext.setTimeout(delayMs); + asyncContext.addListener(new DoSTimeoutAsyncListener()); + break; + case THROTTLE: + throttleRequest(request, response, filterChain, tracker); + break; } + } + private void throttleRequest(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain, RateTracker tracker) throws IOException, ServletException + { if (LOG.isDebugEnabled()) LOG.debug("Throttling {}", request); @@ -383,15 +398,15 @@ protected void doFilter(HttpServletRequest request, HttpServletResponse response long throttleMs = getThrottleMs(); if (!Boolean.TRUE.equals(throttled) && throttleMs > 0) { - int priority = getPriority(request, tracker); + RateType priority = getPriority(request, tracker); request.setAttribute(__THROTTLED, Boolean.TRUE); if (isInsertHeaders()) response.addHeader("DoSFilter", "throttled"); AsyncContext asyncContext = request.startAsync(); request.setAttribute(_suspended, Boolean.TRUE); asyncContext.setTimeout(throttleMs); - asyncContext.addListener(_listeners[priority]); - _queues[priority].add(asyncContext); + asyncContext.addListener(_listeners.get(priority)); + _queues.get(priority).add(asyncContext); if (LOG.isDebugEnabled()) LOG.debug("Throttled {}, {}ms", request, throttleMs); return; @@ -436,9 +451,9 @@ protected void doFilter(HttpServletRequest request, HttpServletResponse response try { // Wake up the next highest priority request. - for (int p = _queues.length - 1; p >= 0; --p) + for (RateType rateType : RateType.values()) { - AsyncContext asyncContext = _queues[p].poll(); + AsyncContext asyncContext = _queues.get(rateType).poll(); if (asyncContext != null) { ServletRequest candidate = asyncContext.getRequest(); @@ -530,21 +545,31 @@ protected void closeConnection(HttpServletRequest request, HttpServletResponse r * @param tracker the rate tracker for this request * @return the priority for this request */ - private int getPriority(HttpServletRequest request, RateTracker tracker) + private RateType getPriority(HttpServletRequest request, RateTracker tracker) { if (extractUserId(request) != null) - return USER_AUTH; + return RateType.AUTH; if (tracker != null) return tracker.getType(); - return USER_UNKNOWN; + return RateType.UNKNOWN; } /** * @return the maximum priority that we can assign to a request */ - protected int getMaxPriority() + protected RateType getMaxPriority() + { + return RateType.AUTH; + } + + public void setListener(DoSFilter.Listener listener) + { + _listener = Objects.requireNonNull(listener, "Listener may not be null"); + } + + public DoSFilter.Listener getListener() { - return USER_AUTH; + return _listener; } private void schedule(RateTracker tracker) @@ -573,22 +598,22 @@ RateTracker getRateTracker(ServletRequest request) HttpSession session = ((HttpServletRequest)request).getSession(false); String loadId = extractUserId(request); - final int type; + final RateType type; if (loadId != null) { - type = USER_AUTH; + type = RateType.AUTH; } else { if (isTrackSessions() && session != null && !session.isNew()) { loadId = session.getId(); - type = USER_SESSION; + type = RateType.SESSION; } else { loadId = isRemotePort() ? createRemotePortId(request) : request.getRemoteAddr(); - type = USER_IP; + type = RateType.IP; } } @@ -605,7 +630,7 @@ RateTracker getRateTracker(ServletRequest request) if (existing != null) tracker = existing; - if (type == USER_IP) + if (type == RateType.IP) { // USER_IP expiration from _rateTrackers is handled by the _scheduler _scheduler.schedule(tracker, getMaxIdleTrackerMs(), TimeUnit.MILLISECONDS); @@ -1062,6 +1087,11 @@ public void setEnabled(boolean enabled) _enabled = enabled; } + /** + * Status code for Rejected for too many requests. + * + * @return the configured status code (default: 429 - Too Many Requests) + */ public int getTooManyCode() { return _tooManyCode; @@ -1150,6 +1180,13 @@ public boolean removeWhitelistAddress(@Name("address") String address) return _whitelist.remove(address); } + private String createRemotePortId(ServletRequest request) + { + String addr = request.getRemoteAddr(); + int port = request.getRemotePort(); + return addr + ":" + port; + } + /** * A RateTracker is associated with a connection, and stores request rate * data. @@ -1161,17 +1198,19 @@ static class RateTracker implements Runnable, HttpSessionBindingListener, HttpSe protected final String _filterName; protected transient ServletContext _context; protected final String _id; - protected final int _type; + protected final RateType _type; + protected final int _maxRequestsPerSecond; protected final long[] _timestamps; protected int _next; - public RateTracker(ServletContext context, String filterName, String id, int type, int maxRequestsPerSecond) + public RateTracker(ServletContext context, String filterName, String id, RateType type, int maxRequestsPerSecond) { _context = context; _filterName = filterName; _id = id; _type = type; + _maxRequestsPerSecond = maxRequestsPerSecond; _timestamps = new long[maxRequestsPerSecond]; _next = 0; } @@ -1180,7 +1219,7 @@ public RateTracker(ServletContext context, String filterName, String id, int typ * @param now the time now (in milliseconds) * @return the current calculated request rate over the last second */ - public boolean isRateExceeded(long now) + public OverLimit isRateExceeded(long now) { final long last; synchronized (this) @@ -1190,7 +1229,17 @@ public boolean isRateExceeded(long now) _next = (_next + 1) % _timestamps.length; } - return last != 0 && (now - last) < 1000L; + if (last == 0) + { + return null; + } + + long rate = (now - last); + if (rate < 1000L) + { + return new Overage(Duration.ofMillis(rate), _maxRequestsPerSecond); + } + return null; } public String getId() @@ -1198,7 +1247,7 @@ public String getId() return _id; } - public int getType() + public RateType getType() { return _type; } @@ -1271,7 +1320,7 @@ public void run() { if (_context == null) { - LOG.warn("Unknkown context for rate tracker {}", this); + LOG.warn("Unknown context for rate tracker {}", this); return; } @@ -1297,17 +1346,66 @@ public String toString() { return "RateTracker/" + _id + "/" + _type; } + + public class Overage implements OverLimit + { + private final Duration duration; + private final long count; + + public Overage(Duration dur, long count) + { + this.duration = dur; + this.count = count; + } + + @Override + public RateType getRateType() + { + return _type; + } + + @Override + public String getRateId() + { + return _id; + } + + @Override + public Duration getDuration() + { + return duration; + } + + @Override + public long getCount() + { + return count; + } + + @Override + public String toString() + { + final StringBuilder sb = new StringBuilder(OverLimit.class.getSimpleName()); + sb.append('@').append(Integer.toHexString(hashCode())); + sb.append("[type=").append(getRateType()); + sb.append(", id=").append(getRateId()); + sb.append(", duration=").append(duration); + sb.append(", count=").append(count); + sb.append(']'); + return sb.toString(); + } + } } private static class FixedRateTracker extends RateTracker { - public FixedRateTracker(ServletContext context, String filterName, String id, int type, int numRecentRequestsTracked) + public FixedRateTracker(ServletContext context, String filterName, String id, RateType type, int numRecentRequestsTracked) { super(context, filterName, id, type, numRecentRequestsTracked); } @Override - public boolean isRateExceeded(long now) + public OverLimit isRateExceeded(long now) { // rate limit is never exceeded, but we keep track of the request timestamps // so that we know whether there was recent activity on this tracker @@ -1318,7 +1416,7 @@ public boolean isRateExceeded(long now) _next = (_next + 1) % _timestamps.length; } - return false; + return null; } @Override @@ -1354,9 +1452,9 @@ public void onError(AsyncEvent event) private class DoSAsyncListener extends DoSTimeoutAsyncListener { - private final int priority; + private final RateType priority; - public DoSAsyncListener(int priority) + public DoSAsyncListener(RateType priority) { this.priority = priority; } @@ -1364,15 +1462,93 @@ public DoSAsyncListener(int priority) @Override public void onTimeout(AsyncEvent event) throws IOException { - _queues[priority].remove(event.getAsyncContext()); + _queues.get(priority).remove(event.getAsyncContext()); super.onTimeout(event); } } - private String createRemotePortId(ServletRequest request) + public enum Action { - String addr = request.getRemoteAddr(); - int port = request.getRemotePort(); - return addr + ":" + port; + /** + * No action is taken against the Request, it is allowed to be processed normally. + */ + NO_ACTION, + /** + * The request and response is aborted, no response is sent. + */ + ABORT, + /** + * The request is rejected by sending an error based on {@link DoSFilter#getTooManyCode()} + */ + REJECT, + /** + * The request is delayed based on {@link DoSFilter#getDelayMs()} + */ + DELAY, + /** + * The request is throttled. + */ + THROTTLE; + + /** + * Obtain the Action based on configured {@link DoSFilter#getDelayMs()} + * + * @param delayMs the delay in milliseconds. + * @return the Action proposed. + */ + public static Action fromDelay(long delayMs) + { + if (delayMs < 0) + return Action.REJECT; + + if (delayMs == 0) + return Action.THROTTLE; + + return Action.DELAY; + } + } + + public interface OverLimit + { + RateType getRateType(); + + String getRateId(); + + Duration getDuration(); + + long getCount(); + } + + /** + * Listener for actions taken against specific requests. + */ + public static class Listener + { + /** + * Process the onRequestOverLimit() behavior. + * + * @param request the request that is over the limit + * @param dosFilter the {@link DoSFilter} that this event occurred on + * @return the action to actually perform. + */ + public Action onRequestOverLimit(HttpServletRequest request, OverLimit overlimit, DoSFilter dosFilter) + { + Action action = Action.fromDelay(dosFilter.getDelayMs()); + + switch (action) + { + case REJECT: + LOG.warn("DOS ALERT: Request rejected ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal()); + break; + case DELAY: + LOG.warn("DOS ALERT: Request delayed={}ms, ip={}, overlimit={}, session={}, user={}", dosFilter.getDelayMs(), request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal()); + break; + case THROTTLE: + LOG.warn("DOS ALERT: Request throttled ip={}, overlimit={}, session={}, user={}", request.getRemoteAddr(), overlimit, request.getRequestedSessionId(), request.getUserPrincipal()); + break; + } + + return action; + } } } diff --git a/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterTest.java b/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterTest.java index 56d2f2aa504e..953e208a0d92 100644 --- a/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterTest.java +++ b/jetty-servlets/src/test/java/org/eclipse/jetty/servlets/DoSFilterTest.java @@ -174,12 +174,12 @@ private boolean hitRateTracker(DoSFilter doSFilter, int sleep) throws Interrupte { boolean exceeded = false; ServletContext context = new ContextHandler.StaticContext(); - RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", 0, 4); + RateTracker rateTracker = new RateTracker(context, doSFilter.getName(), "test2", DoSFilter.RateType.UNKNOWN, 4); for (int i = 0; i < 5; i++) { Thread.sleep(sleep); - if (rateTracker.isRateExceeded(TimeUnit.NANOSECONDS.toMillis(System.nanoTime()))) + if (rateTracker.isRateExceeded(TimeUnit.NANOSECONDS.toMillis(System.nanoTime())) != null) exceeded = true; } return exceeded;