diff --git a/jetty-core/jetty-util/src/main/java/org/eclipse/jetty/util/IteratingCallback.java b/jetty-core/jetty-util/src/main/java/org/eclipse/jetty/util/IteratingCallback.java index 155577f58e24..bc0328edb1b1 100644 --- a/jetty-core/jetty-util/src/main/java/org/eclipse/jetty/util/IteratingCallback.java +++ b/jetty-core/jetty-util/src/main/java/org/eclipse/jetty/util/IteratingCallback.java @@ -14,8 +14,12 @@ package org.eclipse.jetty.util; import java.io.IOException; +import java.util.Objects; +import java.util.function.Consumer; import org.eclipse.jetty.util.thread.AutoLock; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * This specialized callback implements a pattern that allows @@ -51,10 +55,12 @@ */ public abstract class IteratingCallback implements Callback { + private static final Logger LOG = LoggerFactory.getLogger(IteratingCallback.class); + /** * The internal states of this callback. */ - private enum State + enum State { /** * This callback is idle, ready to iterate. @@ -64,48 +70,35 @@ private enum State /** * This callback is just about to call {@link #process()}, * or within it, or just exited from it, either normally - * or by throwing. + * or by throwing. Further actions are waiting for the + * {@link #process()} method to return. */ PROCESSING, /** - * Method {@link #process()} returned {@link Action#SCHEDULED} - * and this callback is waiting for the asynchronous sub-task - * to complete. - */ - PENDING, - - /** - * The asynchronous sub-task was completed successfully - * via a call to {@link #succeeded()} while in - * {@link #PROCESSING} state. + * The asynchronous sub-task was completed either with + * a call to {@link #succeeded()} or {@link #failed(Throwable)}, whilst in + * {@link #PROCESSING} state. Further actions are waiting for the + * {@link #process()} method to return. */ - CALLED, + PROCESSING_CALLED, /** - * The iteration terminated successfully as indicated by - * {@link Action#SUCCEEDED} returned from - * {@link IteratingCallback#process()}. - */ - SUCCEEDED, - - /** - * The iteration terminated with a failure via a call - * to {@link IteratingCallback#failed(Throwable)}. + * Method {@link #process()} returned {@link Action#SCHEDULED} + * and this callback is waiting for the asynchronous sub-task + * to complete via a callback to {@link #succeeded()} or {@link #failed(Throwable)} */ - FAILED, + PENDING, /** - * This callback has been {@link #close() closed} and - * cannot be {@link #reset() reset}. + * This callback is complete. */ - CLOSED, + COMPLETE, /** - * This callback has been {@link #abort(Throwable) aborted}, - * and cannot be {@link #reset() reset}. + * Complete and can't be reset. */ - ABORTED + CLOSED } /** @@ -120,6 +113,7 @@ protected enum Action * for additional events to trigger more work. */ IDLE, + /** * Indicates that {@link #process()} has initiated an asynchronous * sub-task, where the execution has started but the callback @@ -127,6 +121,7 @@ protected enum Action * may have not yet been invoked. */ SCHEDULED, + /** * Indicates that {@link #process()} has completed the whole * iteration successfully. @@ -135,9 +130,13 @@ protected enum Action } private final AutoLock _lock = new AutoLock(); + private final Runnable _onSuccess = this::onSuccess; + private final Runnable _processing = this::processing; + private final Consumer _onCompleted = this::onCompleted; private State _state; private Throwable _failure; - private boolean _iterate; + private boolean _reprocess; + private boolean _aborted; protected IteratingCallback() { @@ -146,7 +145,7 @@ protected IteratingCallback() protected IteratingCallback(boolean needReset) { - _state = needReset ? State.SUCCEEDED : State.IDLE; + _state = needReset ? State.COMPLETE : State.IDLE; } /** @@ -179,8 +178,24 @@ protected void onSuccess() { } + /** + * Invoked when the overall task has been {@link #abort(Throwable) aborted} or {@link #failed(Throwable) failed}. + *

+ * Calls to this method are serialized with respect to {@link #onAborted(Throwable)}, {@link #process()}, + * {@link #onCompleteFailure(Throwable)} and {@link #onCompleted(Throwable)}. + * + * @param cause The cause of the failure or abort + */ + protected void onFailure(Throwable cause) + { + } + /** * Invoked when the overall task has completed successfully. + *

+ * Calls to this method are serialized with respect to {@link #process()}, {@link #onAborted(Throwable)} + * and {@link #onCompleted(Throwable)}. + * If this method is called, then {@link #onCompleteFailure(Throwable)} ()} will never be called. * * @see #onCompleteFailure(Throwable) */ @@ -190,6 +205,10 @@ protected void onCompleteSuccess() /** * Invoked when the overall task has completed with a failure. + *

+ * Calls to this method are serialized with respect to {@link #process()}, {@link #onAborted(Throwable)} + * and {@link #onCompleted(Throwable)}. + * If this method is called, then {@link #onCompleteSuccess()} will never be called. * * @param cause the throwable to indicate cause of failure * @see #onCompleteSuccess() @@ -198,6 +217,93 @@ protected void onCompleteFailure(Throwable cause) { } + /** + * Invoked when the overall task has been aborted. + *

+ * Calls to this method are serialized with respect to {@link #process()}, {@link #onCompleteFailure(Throwable)} + * and {@link #onCompleted(Throwable)}. + * If this method is called, then {@link #onCompleteSuccess()} will never be called. + *

+ * The default implementation of this method calls {@link #failed(Throwable)}. Overridden implementations of + * this method SHOULD NOT call {@code super.onAborted(Throwable)}. + * + * @param cause The cause of the abort + */ + protected void onAborted(Throwable cause) + { + } + + /** + * Invoked when the overall task has completed. + *

+ * Calls to this method are serialized with respect to {@link #process()} and {@link #onAborted(Throwable)}. + * The default implementation of this method will call either {@link #onCompleteSuccess()} or {@link #onCompleteFailure(Throwable)} + * thus implementations of this method should always call {@code super.onCompleted(Throwable)}. + * + * @param causeOrNull the cause of any {@link #abort(Throwable) abort} or {@link #failed(Throwable) failure}, + * else {@code null} for {@link #succeeded() success}. + */ + protected void onCompleted(Throwable causeOrNull) + { + if (causeOrNull == null) + onCompleteSuccess(); + else + onCompleteFailure(causeOrNull); + } + + private void doOnSuccessProcessing() + { + ExceptionUtil.callAndThen(_onSuccess, _processing); + } + + private void doCompleteSuccess() + { + onCompleted(null); + } + + private void doOnCompleted(Throwable cause) + { + ExceptionUtil.call(cause, _onCompleted); + } + + private void doOnFailureOnCompleted(Throwable cause) + { + ExceptionUtil.callAndThen(cause, this::onFailure, _onCompleted); + } + + private void doOnAbortedOnFailure(Throwable cause) + { + ExceptionUtil.callAndThen(cause, this::onAborted, this::onFailure); + } + + private void doOnAbortedOnFailureOnCompleted(Throwable cause) + { + ExceptionUtil.callAndThen(cause, this::doOnAbortedOnFailure, _onCompleted); + } + + private void doOnAbortedOnFailureIfNotPendingDoCompleted(Throwable cause) + { + ExceptionUtil.callAndThen(cause, this::doOnAbortedOnFailure, this::ifNotPendingDoCompleted); + } + + private void ifNotPendingDoCompleted() + { + Throwable completeFailure = null; + try (AutoLock ignored = _lock.lock()) + { + _failure = _failure.getCause(); + + if (Objects.requireNonNull(_state) != State.PENDING) + { + // the callback completed, one way or another, so it is up to us to do the completion + completeFailure = _failure; + } + } + + if (completeFailure != null) + doOnCompleted(completeFailure); + } + /** * This method must be invoked by applications to start the processing * of asynchronous sub-tasks. @@ -215,28 +321,18 @@ public void iterate() { switch (_state) { - case PENDING: - case CALLED: - // process will be called when callback is handled - break; - case IDLE: _state = State.PROCESSING; process = true; break; case PROCESSING: - _iterate = true; + case PROCESSING_CALLED: + _reprocess = true; break; - case FAILED: - case SUCCEEDED: - break; - - case CLOSED: - case ABORTED: default: - throw new IllegalStateException(toString()); + break; } } if (process) @@ -248,21 +344,24 @@ private void processing() // This should only ever be called when in processing state, however a failed or close call // may happen concurrently, so state is not assumed. - boolean notifyCompleteSuccess = false; - Throwable notifyCompleteFailure = null; + boolean completeSuccess = false; + Throwable onAbortedOnFailureOnCompleted = null; + Throwable onFailureOnCompleted = null; + Throwable onAbortedOnFailureIfNotPendingDoCompleted = null; // While we are processing processing: while (true) { // Call process to get the action that we have to take. - Action action = null; + Action action; try { action = process(); } catch (Throwable x) { + action = null; failed(x); // Fall through to possibly invoke onCompleteFailure(). } @@ -271,72 +370,104 @@ private void processing() // acted on the action we have just received try (AutoLock ignored = _lock.lock()) { + if (LOG.isDebugEnabled()) + LOG.debug("processing {} {}", action, this); + switch (_state) { case PROCESSING: { - if (action != null) + if (action == null) + break processing; + switch (action) { - switch (action) + case IDLE: { - case IDLE: + if (_aborted) { - // Has iterate been called while we were processing? - if (_iterate) - { - // yes, so skip idle and keep processing - _iterate = false; - continue; - } - - // No, so we can go idle - _state = State.IDLE; + _state = _failure instanceof ClosedException ? State.CLOSED : State.COMPLETE; + onAbortedOnFailureOnCompleted = _failure; break processing; } - case SCHEDULED: + + // Has iterate been called while we were processing? + if (_reprocess) { - // we won the race against the callback, so the callback has to process and we can break processing - _state = State.PENDING; - break processing; + // yes, so skip idle and keep processing + _reprocess = false; + continue; } - case SUCCEEDED: + + // No, so we can go idle + _state = State.IDLE; + break processing; + } + case SCHEDULED: + { + // we won the race against the callback, so the callback has to process and we can break processing + _state = State.PENDING; + if (_aborted) { - // we lost the race against the callback, - _iterate = false; - _state = State.SUCCEEDED; - notifyCompleteSuccess = true; - break processing; + onAbortedOnFailureIfNotPendingDoCompleted = _failure; + _failure = new AbortingException(onAbortedOnFailureIfNotPendingDoCompleted); + } + break processing; + } + case SUCCEEDED: + { + // we lost the race against the callback, + _reprocess = false; + if (_aborted) + { + _state = _failure instanceof ClosedException ? State.CLOSED : State.COMPLETE; + onAbortedOnFailureOnCompleted = _failure; } - default: + else { - break; + _state = State.COMPLETE; + completeSuccess = true; } + break processing; + } + default: + { + break; } } throw new IllegalStateException(String.format("%s[action=%s]", this, action)); } - case CALLED: + case PROCESSING_CALLED: { + if (action != Action.SCHEDULED && action != null) + { + _state = State.CLOSED; + onAbortedOnFailureOnCompleted = new IllegalStateException("Action not scheduled"); + if (_failure == null) + { + _failure = onAbortedOnFailureOnCompleted; + } + else + { + ExceptionUtil.addSuppressedIfNotAssociated(_failure, onAbortedOnFailureIfNotPendingDoCompleted); + onAbortedOnFailureOnCompleted = _failure; + } + break processing; + } + if (_failure != null) + { + if (_aborted) + onAbortedOnFailureOnCompleted = _failure; + else + onFailureOnCompleted = _failure; + _state = _failure instanceof ClosedException ? State.CLOSED : State.COMPLETE; + break processing; + } callOnSuccess = true; - if (action != Action.SCHEDULED) - throw new IllegalStateException(String.format("%s[action=%s]", this, action)); - // we lost the race, so we have to keep processing _state = State.PROCESSING; - continue; + break; } - case FAILED: - case CLOSED: - case ABORTED: - notifyCompleteFailure = _failure; - break processing; - - case SUCCEEDED: - break processing; - - case IDLE: - case PENDING: default: throw new IllegalStateException(String.format("%s[action=%s]", this, action)); } @@ -347,47 +478,74 @@ private void processing() onSuccess(); } } - - if (notifyCompleteSuccess) - onCompleteSuccess(); - else if (notifyCompleteFailure != null) - onCompleteFailure(notifyCompleteFailure); + if (onAbortedOnFailureOnCompleted != null) + doOnAbortedOnFailureOnCompleted(onAbortedOnFailureOnCompleted); + else if (completeSuccess) + doCompleteSuccess(); + else if (onFailureOnCompleted != null) + doOnFailureOnCompleted(onFailureOnCompleted); + else if (onAbortedOnFailureIfNotPendingDoCompleted != null) + doOnAbortedOnFailureIfNotPendingDoCompleted(onAbortedOnFailureIfNotPendingDoCompleted); } /** * Method to invoke when the asynchronous sub-task succeeds. *

- * This method should be considered final for all practical purposes. - *

+ * For most purposes, this method should be considered {@code final} and should only be + * overridden in extraordinary circumstances. + * Subclasses that override this method must always call {@code super.succeeded()}. + * Such overridden methods are not serialized with respect to {@link #process()}, {@link #onCompleteSuccess()}, + * {@link #onCompleteFailure(Throwable)}, nor {@link #onAborted(Throwable)}. They should not act on nor change any + * fields that may be used by those methods. * Eventually, {@link #onSuccess()} is * called, either by the caller thread or by the processing * thread. */ @Override - public void succeeded() + public final void succeeded() { - boolean process = false; + boolean onSuccessProcessing = false; + Throwable onCompleted = null; try (AutoLock ignored = _lock.lock()) { + if (LOG.isDebugEnabled()) + LOG.debug("succeeded {}", this); switch (_state) { case PROCESSING: { - _state = State.CALLED; + // Another thread is processing, so we just tell it the state + _state = State.PROCESSING_CALLED; break; } case PENDING: { - _state = State.PROCESSING; - process = true; + if (_aborted) + { + if (_failure instanceof AbortingException) + { + // Another thread is still calling onAborted, so we will let it do the completion + _state = _failure.getCause() instanceof ClosedException ? State.CLOSED : State.COMPLETE; + } + else + { + // The onAborted call is complete, so we must do the completion + _state = _failure instanceof ClosedException ? State.CLOSED : State.COMPLETE; + onCompleted = _failure; + } + } + else + { + // No other thread is processing, so we will do the processing + _state = State.PROCESSING; + onSuccessProcessing = true; + } break; } - case FAILED: - case CLOSED: - case ABORTED: + case COMPLETE, CLOSED: { - // Too late! - break; + // Too late + return; } default: { @@ -395,10 +553,13 @@ public void succeeded() } } } - if (process) + if (onSuccessProcessing) { - onSuccess(); - processing(); + doOnSuccessProcessing(); + } + else if (onCompleted != null) + { + doOnCompleted(onCompleted); } } @@ -407,47 +568,84 @@ public void succeeded() * or to fail the overall asynchronous task and therefore * terminate the iteration. *

- * This method should be considered final for all practical purposes. - *

* Eventually, {@link #onCompleteFailure(Throwable)} is * called, either by the caller thread or by the processing * thread. - * + *

+ * For most purposes, this method should be considered {@code final} and should only be + * overridden in extraordinary circumstances. + * Subclasses that override this method must always call {@code super.succeeded()}. + * Such overridden methods are not serialized with respect to {@link #process()}, {@link #onCompleteSuccess()}, + * {@link #onCompleteFailure(Throwable)}, nor {@link #onAborted(Throwable)}. They should not act on nor change any + * fields that may be used by those methods. * @see #isFailed() */ @Override - public void failed(Throwable x) + public final void failed(Throwable cause) { - boolean failure = false; + cause = Objects.requireNonNullElseGet(cause, IOException::new); + + Throwable onFailureOnCompleted = null; + Throwable onCompleted = null; try (AutoLock ignored = _lock.lock()) { + if (LOG.isDebugEnabled()) + LOG.debug("failed {}", this, cause); switch (_state) { - case CALLED: - case SUCCEEDED: - case FAILED: - case CLOSED: - case ABORTED: - // Too late! + case PROCESSING: + { + // Another thread is processing, so we just tell it the state + _state = State.PROCESSING_CALLED; + if (_failure == null) + _failure = cause; + else + ExceptionUtil.addSuppressedIfNotAssociated(_failure, cause); break; + } case PENDING: { - _state = State.FAILED; - failure = true; + if (_aborted) + { + if (_failure instanceof AbortingException) + { + // Another thread is still calling onAborted, so we will let it do the completion + ExceptionUtil.addSuppressedIfNotAssociated(_failure.getCause(), cause); + _state = _failure.getCause() instanceof ClosedException ? State.CLOSED : State.COMPLETE; + } + else + { + // The onAborted call is complete, so we must do the completion + ExceptionUtil.addSuppressedIfNotAssociated(_failure, cause); + _state = _failure instanceof ClosedException ? State.CLOSED : State.COMPLETE; + onCompleted = _failure; + } + } + else + { + // No other thread is processing, so we will do the processing + _state = State.COMPLETE; + _failure = cause; + onFailureOnCompleted = _failure; + } break; } - case PROCESSING: + case COMPLETE, CLOSED: { - _state = State.FAILED; - _failure = x; - break; + // Too late + ExceptionUtil.addSuppressedIfNotAssociated(_failure, cause); + return; } default: + { throw new IllegalStateException(toString()); + } } } - if (failure) - onCompleteFailure(x); + if (onFailureOnCompleted != null) + doOnFailureOnCompleted(onFailureOnCompleted); + else if (onCompleted != null) + doOnCompleted(onCompleted); } /** @@ -459,37 +657,63 @@ public void failed(Throwable x) * * @see #isClosed() */ - public void close() + public final void close() { - String failure = null; + Throwable onAbortedOnFailureIfNotPendingDoCompleted = null; + Throwable onAbortOnFailureOnCompleted = null; + try (AutoLock ignored = _lock.lock()) { + if (LOG.isDebugEnabled()) + LOG.debug("close {}", this); switch (_state) { - case IDLE: - case SUCCEEDED: - case FAILED: - _state = State.CLOSED; - break; - - case PROCESSING: - _failure = new IOException(String.format("Close %s in state %s", this, _state)); + case IDLE -> + { + // Nothing happening so we can abort and complete _state = State.CLOSED; - break; + _failure = new ClosedException(); + onAbortOnFailureOnCompleted = _failure; + } + case PROCESSING, PROCESSING_CALLED -> + { + // Another thread is processing, so we just tell it the state and let it handle it + if (_aborted) + { + ExceptionUtil.addSuppressedIfNotAssociated(_failure, new ClosedException()); + } + else + { + _aborted = true; + _failure = new ClosedException(); + } + } - case CLOSED: - case ABORTED: - break; + case PENDING -> + { + // We are waiting for the callback, so we can only call onAbort and then keep waiting + onAbortedOnFailureIfNotPendingDoCompleted = new ClosedException(); + _failure = new AbortingException(onAbortedOnFailureIfNotPendingDoCompleted); + _aborted = true; + } - default: - failure = String.format("Close %s in state %s", this, _state); + case COMPLETE -> + { _state = State.CLOSED; - break; + } + + case CLOSED -> + { + // too late + return; + } } } - if (failure != null) - onCompleteFailure(new IOException(failure)); + if (onAbortedOnFailureIfNotPendingDoCompleted != null) + doOnAbortedOnFailureIfNotPendingDoCompleted(onAbortedOnFailureIfNotPendingDoCompleted); + else if (onAbortOnFailureOnCompleted != null) + doOnAbortedOnFailureOnCompleted(onAbortOnFailureOnCompleted); } /** @@ -498,49 +722,83 @@ public void close() * ultimately be invoked, either during this call or later after * any call to {@link #process()} has returned.

* - * @param failure the cause of the abort + * @param cause the cause of the abort + * @return {@code true} if abort was called before the callback was complete. * @see #isAborted() */ - public void abort(Throwable failure) + public final boolean abort(Throwable cause) { - boolean abort = false; + cause = Objects.requireNonNullElseGet(cause, Throwable::new); + + boolean onAbort = false; + boolean onAbortDoCompleteFailure = false; try (AutoLock ignored = _lock.lock()) { + if (LOG.isDebugEnabled()) + LOG.debug("abort {}", this, cause); + + // Are we already aborted? + if (_aborted) + { + ExceptionUtil.addSuppressedIfNotAssociated(_failure, cause); + return false; + } + switch (_state) { - case SUCCEEDED: - case FAILED: - case CLOSED: - case ABORTED: + case IDLE: { - // Too late. + // Nothing happening so we can abort and complete + _state = State.COMPLETE; + _failure = cause; + _aborted = true; + onAbortDoCompleteFailure = true; break; } - case IDLE: - case PENDING: + case PROCESSING: { - _failure = failure; - _state = State.ABORTED; - abort = true; + // Another thread is processing, so we just tell it the state and let it handle everything + _failure = cause; + _aborted = true; break; } - case PROCESSING: - case CALLED: + case PROCESSING_CALLED: { - _failure = failure; - _state = State.ABORTED; + // Another thread is processing, but we have already succeeded or failed. + if (_failure == null) + _failure = cause; + else + ExceptionUtil.addSuppressedIfNotAssociated(_failure, cause); + _aborted = true; break; } - default: - throw new IllegalStateException(toString()); + case PENDING: + { + // We are waiting for the callback, so we can only call onAbort and then keep waiting + onAbort = true; + _failure = new AbortingException(cause); + _aborted = true; + break; + } + + case COMPLETE, CLOSED: + { + // too late + ExceptionUtil.addSuppressedIfNotAssociated(_failure, cause); + return false; + } } } - if (abort) - onCompleteFailure(failure); + if (onAbortDoCompleteFailure) + doOnAbortedOnFailureOnCompleted(cause); + else if (onAbort) + doOnAbortedOnFailureIfNotPendingDoCompleted(cause); + + return true; } /** @@ -561,7 +819,7 @@ public boolean isClosed() { try (AutoLock ignored = _lock.lock()) { - return _state == State.CLOSED; + return _state == State.CLOSED || _failure instanceof ClosedException; } } @@ -572,7 +830,7 @@ public boolean isFailed() { try (AutoLock ignored = _lock.lock()) { - return _state == State.FAILED; + return _failure != null; } } @@ -585,7 +843,7 @@ public boolean isSucceeded() { try (AutoLock ignored = _lock.lock()) { - return _state == State.SUCCEEDED; + return _state == State.COMPLETE && _failure == null; } } @@ -596,7 +854,7 @@ public boolean isAborted() { try (AutoLock ignored = _lock.lock()) { - return _state == State.ABORTED; + return _aborted; } } @@ -618,11 +876,10 @@ public boolean reset() case IDLE: return true; - case SUCCEEDED: - case FAILED: + case COMPLETE: _state = State.IDLE; _failure = null; - _iterate = false; + _reprocess = false; return true; default: @@ -634,6 +891,31 @@ public boolean reset() @Override public String toString() { - return String.format("%s@%x[%s]", getClass().getSimpleName(), hashCode(), _state); + try (AutoLock ignored = _lock.lock()) + { + return String.format("%s@%x[%s, %b, %s]", getClass().getSimpleName(), hashCode(), _state, _aborted, _failure); + } + } + + private static class ClosedException extends Exception + { + ClosedException() + { + super("Closed"); + } + + ClosedException(Throwable suppressed) + { + this(); + ExceptionUtil.addSuppressedIfNotAssociated(this, suppressed); + } + } + + private static class AbortingException extends Exception + { + AbortingException(Throwable cause) + { + super(cause.getMessage(), cause); + } } } diff --git a/jetty-core/jetty-util/src/test/java/org/eclipse/jetty/util/IteratingCallbackTest.java b/jetty-core/jetty-util/src/test/java/org/eclipse/jetty/util/IteratingCallbackTest.java index 00f6b007b7dd..cd04a884f46f 100644 --- a/jetty-core/jetty-util/src/test/java/org/eclipse/jetty/util/IteratingCallbackTest.java +++ b/jetty-core/jetty-util/src/test/java/org/eclipse/jetty/util/IteratingCallbackTest.java @@ -13,19 +13,37 @@ package org.eclipse.jetty.util; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicMarkableReference; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; +import org.awaitility.Awaitility; import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler; import org.eclipse.jetty.util.thread.Scheduler; +import org.hamcrest.Matchers; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; - +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class IteratingCallbackTest @@ -202,46 +220,31 @@ protected Action process() { processed++; - switch (i--) + return switch (i--) { - case 5: + case 5, 2 -> + { succeeded(); - return Action.SCHEDULED; - - case 4: + yield Action.SCHEDULED; + } + case 4, 1 -> + { scheduler.schedule(successTask, 5, TimeUnit.MILLISECONDS); - return Action.SCHEDULED; - - case 3: - scheduler.schedule(new Runnable() - { - @Override - public void run() - { - idle.countDown(); - } - }, 5, TimeUnit.MILLISECONDS); - return Action.IDLE; - - case 2: - succeeded(); - return Action.SCHEDULED; - - case 1: - scheduler.schedule(successTask, 5, TimeUnit.MILLISECONDS); - return Action.SCHEDULED; - - case 0: - return Action.SUCCEEDED; - - default: - throw new IllegalStateException(); - } + yield Action.SCHEDULED; + } + case 3 -> + { + scheduler.schedule(idle::countDown, 5, TimeUnit.MILLISECONDS); + yield Action.IDLE; + } + case 0 -> Action.SUCCEEDED; + default -> throw new IllegalStateException(); + }; } }; cb.iterate(); - idle.await(10, TimeUnit.SECONDS); + assertTrue(idle.await(10, TimeUnit.SECONDS)); assertTrue(cb.isIdle()); cb.iterate(); @@ -252,25 +255,21 @@ public void run() @Test public void testCloseDuringProcessingReturningScheduled() throws Exception { - testCloseDuringProcessing(IteratingCallback.Action.SCHEDULED); - } - - @Test - public void testCloseDuringProcessingReturningSucceeded() throws Exception - { - testCloseDuringProcessing(IteratingCallback.Action.SUCCEEDED); - } - - private void testCloseDuringProcessing(final IteratingCallback.Action action) throws Exception - { + final CountDownLatch abortLatch = new CountDownLatch(1); final CountDownLatch failureLatch = new CountDownLatch(1); IteratingCallback callback = new IteratingCallback() { @Override - protected Action process() throws Exception + protected Action process() { close(); - return action; + return Action.SCHEDULED; + } + + @Override + protected void onAborted(Throwable cause) + { + abortLatch.countDown(); } @Override @@ -282,27 +281,45 @@ protected void onCompleteFailure(Throwable cause) callback.iterate(); - assertTrue(failureLatch.await(5, TimeUnit.SECONDS)); + assertFalse(failureLatch.await(100, TimeUnit.MILLISECONDS)); + assertTrue(abortLatch.await(1000000000, TimeUnit.SECONDS)); + assertTrue(callback.isClosed()); + + callback.succeeded(); + assertTrue(failureLatch.await(1, TimeUnit.SECONDS)); + assertTrue(callback.isFailed()); + assertTrue(callback.isClosed()); } - private abstract static class TestCB extends IteratingCallback + @Test + public void testCloseDuringProcessingReturningSucceeded() throws Exception { - protected Runnable successTask = new Runnable() + final CountDownLatch failureLatch = new CountDownLatch(1); + IteratingCallback callback = new IteratingCallback() { @Override - public void run() + protected Action process() { - succeeded(); + close(); + return Action.SUCCEEDED; } - }; - protected Runnable failTask = new Runnable() - { + @Override - public void run() + protected void onCompleteFailure(Throwable cause) { - failed(new Exception("testing failure")); + failureLatch.countDown(); } }; + + callback.iterate(); + + assertTrue(failureLatch.await(5, TimeUnit.SECONDS)); + } + + private abstract static class TestCB extends IteratingCallback + { + protected Runnable successTask = this::succeeded; + protected Runnable failTask = () -> failed(new Exception("testing failure")); protected CountDownLatch completed = new CountDownLatch(1); protected int processed = 0; @@ -320,8 +337,7 @@ public void onCompleteFailure(Throwable x) boolean waitForComplete() throws InterruptedException { - completed.await(10, TimeUnit.SECONDS); - return isSucceeded(); + return completed.await(10, TimeUnit.SECONDS) && isSucceeded(); } } @@ -390,57 +406,541 @@ protected void onCompleteFailure(Throwable cause) assertEquals(1, count.get()); - // Aborting should not iterate. icb.abort(new Exception()); assertTrue(ocfLatch.await(5, TimeUnit.SECONDS)); + assertTrue(icb.isFailed()); assertTrue(icb.isAborted()); assertEquals(1, count.get()); } @Test - public void testWhenProcessingAbortSerializesOnCompleteFailure() throws Exception + public void testWhenPendingAbortSerializesOnCompleteFailure() throws Exception { - AtomicInteger count = new AtomicInteger(); - CountDownLatch ocfLatch = new CountDownLatch(1); + AtomicReference aborted = new AtomicReference<>(); + CountDownLatch abortLatch = new CountDownLatch(1); + AtomicReference failure = new AtomicReference<>(); + AtomicMarkableReference completed = new AtomicMarkableReference<>(null, false); + IteratingCallback icb = new IteratingCallback() { @Override protected Action process() throws Throwable { - count.incrementAndGet(); - abort(new Exception()); - - // After calling abort, onCompleteFailure() must not be called yet. - assertFalse(ocfLatch.await(1, TimeUnit.SECONDS)); - return Action.SCHEDULED; } + @Override + protected void onAborted(Throwable cause) + { + aborted.set(cause); + ExceptionUtil.call(abortLatch::await, Throwable::printStackTrace); + } + @Override protected void onCompleteFailure(Throwable cause) { - ocfLatch.countDown(); + failure.set(cause); + } + + @Override + protected void onCompleted(Throwable causeOrNull) + { + completed.set(causeOrNull, true); + super.onCompleted(causeOrNull); } }; icb.iterate(); - assertEquals(1, count.get()); + assertThat(icb.toString(), containsString("[PENDING, false,")); - assertTrue(ocfLatch.await(5, TimeUnit.SECONDS)); - assertTrue(icb.isAborted()); + Throwable cause = new Throwable("test abort"); + new Thread(() -> icb.abort(cause)).start(); + + Awaitility.waitAtMost(5, TimeUnit.SECONDS).until(() -> icb.toString().contains("[PENDING, true,")); + Awaitility.waitAtMost(5, TimeUnit.SECONDS).until(() -> aborted.get() != null); - // Calling succeeded() won't cause further iterations. icb.succeeded(); - assertEquals(1, count.get()); + // We are now complete, but callbacks have not yet been done + assertThat(icb.toString(), containsString("[COMPLETE, true,")); + assertThat(failure.get(), nullValue()); + assertFalse(completed.isMarked()); + + abortLatch.countDown(); + Awaitility.waitAtMost(5, TimeUnit.SECONDS).until(completed::isMarked); + assertThat(failure.get(), sameInstance(cause)); + assertThat(completed.getReference(), sameInstance(cause)); + } + + public enum Event + { + PROCESSED, + ABORTED, + SUCCEEDED, + FAILED + } + + public static Stream> serializedEvents() + { + return Stream.of( + List.of(Event.PROCESSED, Event.ABORTED, Event.SUCCEEDED), + List.of(Event.PROCESSED, Event.SUCCEEDED, Event.ABORTED), + + List.of(Event.SUCCEEDED, Event.PROCESSED, Event.ABORTED), + List.of(Event.SUCCEEDED, Event.ABORTED, Event.PROCESSED), + + List.of(Event.ABORTED, Event.SUCCEEDED, Event.PROCESSED), + List.of(Event.ABORTED, Event.PROCESSED, Event.SUCCEEDED), + + List.of(Event.PROCESSED, Event.ABORTED, Event.FAILED), + List.of(Event.PROCESSED, Event.FAILED, Event.ABORTED), + + List.of(Event.FAILED, Event.PROCESSED, Event.ABORTED), + List.of(Event.FAILED, Event.ABORTED, Event.PROCESSED), + + List.of(Event.ABORTED, Event.FAILED, Event.PROCESSED), + List.of(Event.ABORTED, Event.PROCESSED, Event.FAILED) + ); + } + + @ParameterizedTest + @MethodSource("serializedEvents") + public void testSerializesProcessAbortCompletion(List events) throws Exception + { + AtomicReference aborted = new AtomicReference<>(); + CountDownLatch processingLatch = new CountDownLatch(1); + CountDownLatch abortLatch = new CountDownLatch(1); + AtomicReference failure = new AtomicReference<>(); + AtomicMarkableReference completed = new AtomicMarkableReference<>(null, false); + + + Throwable cause = new Throwable("test abort"); + + IteratingCallback icb = new IteratingCallback() + { + @Override + protected Action process() throws Throwable + { + abort(cause); + ExceptionUtil.call(processingLatch::await, Throwable::printStackTrace); + return Action.SCHEDULED; + } + + @Override + protected void onAborted(Throwable cause) + { + aborted.set(cause); + ExceptionUtil.call(abortLatch::await, Throwable::printStackTrace); + } + + @Override + protected void onCompleteFailure(Throwable cause) + { + failure.set(cause); + } + + @Override + protected void onCompleted(Throwable causeOrNull) + { + completed.set(causeOrNull, true); + super.onCompleted(causeOrNull); + } + }; + + new Thread(icb::iterate).start(); + + Awaitility.waitAtMost(5, TimeUnit.SECONDS).until(() -> icb.toString().contains("[PROCESSING, true,")); + + // we have aborted, but onAborted not yet called + assertThat(aborted.get(), nullValue()); + + int count = 0; + for (Event event : events) + { + switch (event) + { + case PROCESSED -> + { + processingLatch.countDown(); + // We can call aborted + Awaitility.waitAtMost(5, TimeUnit.SECONDS).pollInterval(10, TimeUnit.MILLISECONDS).until(() -> aborted.get() != null); + } + case ABORTED -> + { + abortLatch.countDown(); + Awaitility.waitAtMost(5, TimeUnit.SECONDS).pollInterval(10, TimeUnit.MILLISECONDS).until(() -> !icb.toString().contains("AbortingException")); + } + case SUCCEEDED -> icb.succeeded(); + + case FAILED -> icb.failed(new Throwable("failure")); + } + + if (++count < 3) + { + // Not complete yet + assertThat(failure.get(), nullValue()); + assertFalse(completed.isMarked()); + } + + // Extra aborts ignored + assertFalse(icb.abort(new Throwable("ignored"))); + } + + // When the callback is succeeded, the completion events can be called + Awaitility.waitAtMost(5, TimeUnit.SECONDS).pollInterval(10, TimeUnit.MILLISECONDS).until(completed::isMarked); + assertThat(failure.get(), sameInstance(cause)); + assertThat(completed.getReference(), sameInstance(cause)); + } + + @Test + public void testICBSuccess() throws Exception + { + TestIteratingCB callback = new TestIteratingCB(); + callback.iterate(); + callback.succeeded(); + assertTrue(callback._completed.await(1, TimeUnit.SECONDS)); + assertThat(callback._onFailure.get(), nullValue()); + assertThat(callback._completion.getReference(), Matchers.nullValue()); + assertTrue(callback._completion.isMarked()); + + // Everything now a noop + assertFalse(callback.abort(new Throwable())); + callback.failed(new Throwable()); + assertThat(callback._completion.getReference(), Matchers.nullValue()); + assertThat(callback._completed.getCount(), is(0L)); + + callback.checkNoBadCalls(); + } + + @Test + public void testICBFailure() throws Exception + { + Throwable failure = new Throwable(); + TestIteratingCB callback = new TestIteratingCB(); + callback.iterate(); + callback.failed(failure); + assertTrue(callback._completed.await(1, TimeUnit.SECONDS)); + assertThat(callback._onFailure.get(), sameInstance(failure)); + assertThat(callback._completion.getReference(), Matchers.sameInstance(failure)); + assertTrue(callback._completion.isMarked()); + + // Everything now a noop, other than suppression + callback.succeeded(); + Throwable late = new Throwable(); + assertFalse(callback.abort(late)); + assertFalse(ExceptionUtil.areNotAssociated(failure, late)); + assertThat(callback._completion.getReference(), Matchers.sameInstance(failure)); + assertThat(callback._completed.getCount(), is(0L)); + + callback.checkNoBadCalls(); + } + + @Test + public void testICBAbortSuccess() throws Exception + { + TestIteratingCB callback = new TestIteratingCB(); + callback.iterate(); + + Throwable abort = new Throwable(); + callback.abort(abort); + assertFalse(callback._completed.await(100, TimeUnit.MILLISECONDS)); + assertThat(callback._onFailure.get(), sameInstance(abort)); + assertThat(callback._completion.getReference(), Matchers.sameInstance(abort)); + assertFalse(callback._completion.isMarked()); + + callback.succeeded(); + assertThat(callback._completion.getReference(), Matchers.sameInstance(abort)); + assertThat(callback._completed.getCount(), is(0L)); + + Throwable late = new Throwable(); + callback.failed(late); + assertFalse(callback.abort(late)); + assertTrue(ExceptionUtil.areAssociated(abort, late)); + assertTrue(ExceptionUtil.areAssociated(callback._onFailure.get(), late)); + assertThat(callback._completion.getReference(), Matchers.sameInstance(abort)); + assertThat(callback._completed.getCount(), is(0L)); + + callback.checkNoBadCalls(); + } + + public static Stream abortTests() + { + List tests = new ArrayList<>(); + + for (IteratingCallback.State state : IteratingCallback.State.values()) + { + String name = state.name(); + + if (name.contains("PROCESSING")) + { + for (IteratingCallback.Action action : IteratingCallback.Action.values()) + { + if (name.contains("CALLED")) + { + if (action == IteratingCallback.Action.SCHEDULED) + { + tests.add(Arguments.of(name, action.toString(), Boolean.TRUE)); + tests.add(Arguments.of(name, action.toString(), Boolean.FALSE)); + } + } + else if (action == IteratingCallback.Action.SCHEDULED) + { + tests.add(Arguments.of(name, action.toString(), Boolean.TRUE)); + tests.add(Arguments.of(name, action.toString(), Boolean.FALSE)); + } + else + { + tests.add(Arguments.of(name, action.toString(), null)); + } + } + } + else if (name.equals("COMPLETE") || name.contains("PENDING")) + { + tests.add(Arguments.of(name, null, Boolean.TRUE)); + tests.add(Arguments.of(name, null, Boolean.FALSE)); + } + else + { + tests.add(Arguments.of(name, null, null)); + } + } + + return tests.stream(); + } + + @ParameterizedTest + @MethodSource("abortTests") + public void testAbortInEveryState(String state, String action, Boolean success) throws Exception + { + CountDownLatch processLatch = new CountDownLatch(1); + + AtomicReference onAbort = new AtomicReference<>(); + AtomicReference onFailure = new AtomicReference<>(null); + AtomicMarkableReference onCompleted = new AtomicMarkableReference<>(null, false); + + Throwable cause = new Throwable("abort"); + Throwable failure = new Throwable("failure"); + AtomicInteger badCalls = new AtomicInteger(0); + + IteratingCallback callback = new IteratingCallback() + { + @Override + protected Action process() throws Throwable + { + if (state.contains("CALLED")) + { + if (success) + succeeded(); + else + failed(failure); + } + + if (state.contains("PENDING")) + return Action.SCHEDULED; + + if (state.equals("COMPLETE")) + { + if (success) + return Action.SUCCEEDED; + failed(new Throwable("Complete Failure")); + return Action.SCHEDULED; + } + + if (state.equals("CLOSED")) + { + close(); + return Action.SUCCEEDED; + } + + processLatch.await(); + return IteratingCallback.Action.valueOf(action); + } + + @Override + protected void onFailure(Throwable cause) + { + if (!onFailure.compareAndSet(null, cause)) + badCalls.incrementAndGet(); + } + + @Override + protected void onAborted(Throwable cause) + { + if (!onAbort.compareAndSet(null, cause)) + badCalls.incrementAndGet(); + } + + @Override + protected void onCompleted(Throwable causeOrNull) + { + onCompleted.set(causeOrNull, true); + super.onCompleted(causeOrNull); + } + }; + + if (!state.equals("IDLE")) + { + new Thread(callback::iterate).start(); + } + + Awaitility.waitAtMost(5, TimeUnit.SECONDS).pollInterval(10, TimeUnit.MILLISECONDS).until(() -> callback.toString().contains(state)); + assertThat(callback.toString(), containsString("[" + state + ",")); + onAbort.set(null); + + if (success == Boolean.FALSE && (state.equals("COMPLETE") || state.equals("CLOSED"))) + { + // We must be failed already + assertThat(onFailure.get(), notNullValue()); + } + + boolean aborted = callback.abort(cause); + + // Check abort in completed state + if (state.equals("COMPLETE") || state.equals("CLOSED")) + { + assertThat(aborted, is(false)); + assertThat(onAbort.get(), nullValue()); + assertTrue(onCompleted.isMarked()); + if (success == Boolean.TRUE) + assertThat(onCompleted.getReference(), nullValue()); + else + assertThat(onCompleted.getReference(), notNullValue()); + return; + } + + // Check abort in non completed state + assertThat(aborted, is(true)); + + if (state.contains("PROCESSING")) + { + processLatch.countDown(); + + Awaitility.waitAtMost(5, TimeUnit.SECONDS).pollInterval(10, TimeUnit.MILLISECONDS).until(() -> !callback.toString().contains("PROCESSING")); + + if (action.equals("SCHEDULED")) + { + if (success) + { + callback.succeeded(); + } + else + { + Throwable failureAfterAbort = new Throwable("failure after abort"); + callback.failed(failureAfterAbort); + assertThat(onFailure.get(), not(sameInstance(failureAfterAbort))); + assertTrue(ExceptionUtil.areAssociated(onFailure.get(), failureAfterAbort)); + } + } + } + else if (state.contains("PENDING")) + { + if (success) + callback.succeeded(); + else + callback.failed(new Throwable("failure after abort")); + } + + assertTrue(onCompleted.isMarked()); + + if (state.contains("CALLED") && !success) + { + assertThat(onCompleted.getReference(), sameInstance(failure)); + assertThat(onAbort.get(), sameInstance(failure)); + } + else + { + assertThat(onCompleted.getReference(), sameInstance(cause)); + assertThat(onAbort.get(), sameInstance(cause)); + } + + assertThat(badCalls.get(), is(0)); + } + + private static class TestIteratingCB extends IteratingCallback + { + final AtomicInteger _count; + final AtomicInteger _badCalls = new AtomicInteger(0); + final AtomicBoolean _onSuccess = new AtomicBoolean(); + final AtomicReference _onFailure = new AtomicReference<>(); + final AtomicMarkableReference _completion = new AtomicMarkableReference<>(null, false); + final CountDownLatch _completed = new CountDownLatch(2); + + private TestIteratingCB() + { + this(1); + } + + private TestIteratingCB(int count) + { + _count = new AtomicInteger(count); + } + + @Override + protected Action process() + { + return _count.getAndDecrement() == 0 ? Action.SUCCEEDED : Action.SCHEDULED; + } + + @Override + protected void onAborted(Throwable cause) + { + _completion.compareAndSet(null, cause, false, false); + } + + @Override + protected void onSuccess() + { + if (!_onSuccess.compareAndSet(false, true)) + _badCalls.incrementAndGet(); + } + + @Override + protected void onFailure(Throwable cause) + { + if (!_onFailure.compareAndSet(null, cause)) + _badCalls.incrementAndGet(); + } + + @Override + protected void onCompleteFailure(Throwable cause) + { + if (_completion.compareAndSet(null, cause, false, true)) + _completed.countDown(); + + Throwable failure = _completion.getReference(); + if (failure != null && _completion.compareAndSet(failure, failure, false, true)) + _completed.countDown(); + } + + @Override + protected void onCompleteSuccess() + { + if (_completion.compareAndSet(null, null, false, true)) + _completed.countDown(); + } + + @Override + protected void onCompleted(Throwable causeOrNull) + { + if (_completion.isMarked()) + _badCalls.incrementAndGet(); + super.onCompleted(causeOrNull); + _completed.countDown(); + } + + public void checkNoBadCalls() + { + assertThat(_badCalls.get(), is(0)); + } } @Test public void testOnSuccessCalledDespiteISE() throws Exception { CountDownLatch latch = new CountDownLatch(1); + AtomicReference aborted = new AtomicReference<>(); IteratingCallback icb = new IteratingCallback() { @Override @@ -451,13 +951,22 @@ protected Action process() } @Override - protected void onSuccess() + protected void onAborted(Throwable cause) + { + aborted.set(cause); + super.onAborted(cause); + } + + @Override + protected void onCompleted(Throwable causeOrNull) { + super.onCompleted(causeOrNull); latch.countDown(); } }; - assertThrows(IllegalStateException.class, icb::iterate); + icb.iterate(); assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertThat(aborted.get(), instanceOf(IllegalStateException.class)); } } diff --git a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/util/TransformingFlusher.java b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/util/TransformingFlusher.java index edd5ee805ecd..c9c70ee16e1d 100644 --- a/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/util/TransformingFlusher.java +++ b/jetty-core/jetty-websocket/jetty-websocket-core-common/src/main/java/org/eclipse/jetty/websocket/core/util/TransformingFlusher.java @@ -180,7 +180,7 @@ protected void onCompleteFailure(Throwable t) notifyCallbackFailure(current.callback, t); current = null; } - onFailure(t); + this.onCompleteFailure(t); } } diff --git a/jetty-ee11/jetty-ee11-proxy/src/main/java/org/eclipse/jetty/ee11/proxy/AsyncProxyServlet.java b/jetty-ee11/jetty-ee11-proxy/src/main/java/org/eclipse/jetty/ee11/proxy/AsyncProxyServlet.java index 00350bc2491a..17a7dfe82207 100644 --- a/jetty-ee11/jetty-ee11-proxy/src/main/java/org/eclipse/jetty/ee11/proxy/AsyncProxyServlet.java +++ b/jetty-ee11/jetty-ee11-proxy/src/main/java/org/eclipse/jetty/ee11/proxy/AsyncProxyServlet.java @@ -192,9 +192,9 @@ protected void onRequestContent(HttpServletRequest request, Request proxyRequest } @Override - public void failed(Throwable x) + public void onCompleteFailure(Throwable x) { - super.failed(x); + super.onCompleteFailure(x); onError(x); } }