diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Sources/ManualResetValueTaskSourceCore.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Sources/ManualResetValueTaskSourceCore.cs index c2bd5ab8f3f90..c29c5e759f2c1 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Sources/ManualResetValueTaskSourceCore.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Sources/ManualResetValueTaskSourceCore.cs @@ -8,7 +8,7 @@ namespace System.Threading.Tasks.Sources { /// Provides the core logic for implementing a manual-reset or . - /// + /// Specifies the type of results of the operation represented by this instance. [StructLayout(LayoutKind.Auto)] public struct ManualResetValueTaskSourceCore { @@ -20,38 +20,44 @@ public struct ManualResetValueTaskSourceCore private Action? _continuation; /// State to pass to . private object? _continuationState; - /// to flow to the callback, or null if no flowing is required. - private ExecutionContext? _executionContext; /// - /// A "captured" or with which to invoke the callback, - /// or null if no special context is required. + /// Null if no special context was found. + /// ExecutionContext if one was captured due to needing to be flowed. + /// A scheduler (TaskScheduler or SynchronizationContext) if one was captured and needs to be used for callback scheduling. + /// Or a CapturedContext if there's both an ExecutionContext and a scheduler. + /// The most common and the fast path case to optimize for is null. /// private object? _capturedContext; - /// Whether the current operation has completed. - private bool _completed; - /// The result with which the operation succeeded, or the default value if it hasn't yet completed or failed. - private TResult? _result; /// The exception with which the operation failed, or null if it hasn't yet completed or completed successfully. private ExceptionDispatchInfo? _error; + /// The result with which the operation succeeded, or the default value if it hasn't yet completed or failed. + private TResult? _result; /// The current version of this value, used to help prevent misuse. private short _version; + /// Whether the current operation has completed. + private bool _completed; + /// Whether to force continuations to run asynchronously. + private bool _runContinuationsAsynchronously; /// Gets or sets whether to force continuations to run asynchronously. /// Continuations may run asynchronously if this is false, but they'll never run synchronously if this is true. - public bool RunContinuationsAsynchronously { get; set; } + public bool RunContinuationsAsynchronously + { + get => _runContinuationsAsynchronously; + set => _runContinuationsAsynchronously = value; + } /// Resets to prepare for the next operation. public void Reset() { // Reset/update state for the next use/await of this instance. _version++; - _completed = false; - _result = default; - _error = null; - _executionContext = null; - _capturedContext = null; _continuation = null; _continuationState = null; + _capturedContext = null; + _error = null; + _result = default; + _completed = false; } /// Completes with a successful result. @@ -79,8 +85,8 @@ public ValueTaskSourceStatus GetStatus(short token) { ValidateToken(token); return - Volatile.Read(ref _continuation) == null || !_completed ? ValueTaskSourceStatus.Pending : - _error == null ? ValueTaskSourceStatus.Succeeded : + Volatile.Read(ref _continuation) is null || !_completed ? ValueTaskSourceStatus.Pending : + _error is null ? ValueTaskSourceStatus.Succeeded : _error.SourceException is OperationCanceledException ? ValueTaskSourceStatus.Canceled : ValueTaskSourceStatus.Faulted; } @@ -92,22 +98,18 @@ public TResult GetResult(short token) { if (token != _version || !_completed || _error is not null) { - ThrowForFailedGetResult(token); + ThrowForFailedGetResult(); } return _result!; } + /// Throws an exception in response to a failed . [StackTraceHidden] - private void ThrowForFailedGetResult(short token) + private void ThrowForFailedGetResult() { - if (token != _version || !_completed) - { - ThrowHelper.ThrowInvalidOperationException(); - } - _error?.Throw(); - Debug.Fail($"{nameof(ThrowForFailedGetResult)} should never get here"); + throw new InvalidOperationException(); // not using ThrowHelper.ThrowInvalidOperationException so that the JIT sees ThrowForFailedGetResult as always throwing } /// Schedules the continuation action for this operation. @@ -117,28 +119,34 @@ private void ThrowForFailedGetResult(short token) /// The flags describing the behavior of the continuation. public void OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) { - ArgumentNullException.ThrowIfNull(continuation); - + if (continuation is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.continuation); + } ValidateToken(token); if ((flags & ValueTaskSourceOnCompletedFlags.FlowExecutionContext) != 0) { - _executionContext = ExecutionContext.Capture(); + _capturedContext = ExecutionContext.Capture(); } if ((flags & ValueTaskSourceOnCompletedFlags.UseSchedulingContext) != 0) { - SynchronizationContext? sc = SynchronizationContext.Current; - if (sc != null && sc.GetType() != typeof(SynchronizationContext)) + if (SynchronizationContext.Current is SynchronizationContext sc && + sc.GetType() != typeof(SynchronizationContext)) { - _capturedContext = sc; + _capturedContext = _capturedContext is null ? + sc : + new CapturedSchedulerAndExecutionContext(sc, (ExecutionContext)_capturedContext); } else { TaskScheduler ts = TaskScheduler.Current; if (ts != TaskScheduler.Default) { - _capturedContext = ts; + _capturedContext = _capturedContext is null ? + ts : + new CapturedSchedulerAndExecutionContext(ts, (ExecutionContext)_capturedContext); } } } @@ -150,47 +158,41 @@ public void OnCompleted(Action continuation, object? state, short token // awaited twice concurrently), _continuationState might get erroneously overwritten. // To minimize the chances of that, we check preemptively whether _continuation // is already set to something other than the completion sentinel. - - object? oldContinuation = _continuation; - if (oldContinuation == null) + object? storedContinuation = _continuation; + if (storedContinuation is null) { _continuationState = state; - oldContinuation = Interlocked.CompareExchange(ref _continuation, continuation, null); + storedContinuation = Interlocked.CompareExchange(ref _continuation, continuation, null); + if (storedContinuation is null) + { + // Operation hadn't already completed, so we're done. The continuation will be + // invoked when SetResult/Exception is called at some later point. + return; + } } - if (oldContinuation != null) + // Operation already completed, so we need to queue the supplied callback. + // At this point the storedContinuation should be the sentinal; if it's not, the instance was misused. + Debug.Assert(storedContinuation is not null, $"{nameof(storedContinuation)} is null"); + if (!ReferenceEquals(storedContinuation, ManualResetValueTaskSourceCoreShared.s_sentinel)) { - // Operation already completed, so we need to queue the supplied callback. - if (!ReferenceEquals(oldContinuation, ManualResetValueTaskSourceCoreShared.s_sentinel)) - { - ThrowHelper.ThrowInvalidOperationException(); - } + ThrowHelper.ThrowInvalidOperationException(); + } - switch (_capturedContext) - { - case null: - if (_executionContext != null) - { - ThreadPool.QueueUserWorkItem(continuation, state, preferLocal: true); - } - else - { - ThreadPool.UnsafeQueueUserWorkItem(continuation, state, preferLocal: true); - } - break; - - case SynchronizationContext sc: - sc.Post(static s => - { - var tuple = (TupleSlim, object?>)s!; - tuple.Item1(tuple.Item2); - }, new TupleSlim, object?>(continuation, state)); - break; - - case TaskScheduler ts: - Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts); - break; - } + object? capturedContext = _capturedContext; + switch (capturedContext) + { + case null: + ThreadPool.UnsafeQueueUserWorkItem(continuation, state, preferLocal: true); + break; + + case ExecutionContext: + ThreadPool.QueueUserWorkItem(continuation, state, preferLocal: true); + break; + + default: + ManualResetValueTaskSourceCoreShared.ScheduleCapturedContext(capturedContext, continuation, state); + break; } } @@ -213,62 +215,125 @@ private void SignalCompletion() } _completed = true; - if (Volatile.Read(ref _continuation) is null && Interlocked.CompareExchange(ref _continuation, ManualResetValueTaskSourceCoreShared.s_sentinel, null) is null) - { - return; - } + Action? continuation = + Volatile.Read(ref _continuation) ?? + Interlocked.CompareExchange(ref _continuation, ManualResetValueTaskSourceCoreShared.s_sentinel, null); - Debug.Assert(_continuation is not null); - - if (_executionContext is null) + if (continuation is not null) { - if (_capturedContext is null) + Debug.Assert(continuation is not null, $"{nameof(continuation)} is null"); + + object? context = _capturedContext; + if (context is null) { - if (RunContinuationsAsynchronously) + if (_runContinuationsAsynchronously) { - ThreadPool.UnsafeQueueUserWorkItem(_continuation, _continuationState, preferLocal: true); + ThreadPool.UnsafeQueueUserWorkItem(continuation, _continuationState, preferLocal: true); } else { - _continuation(_continuationState); + continuation(_continuationState); } } + else if (context is ExecutionContext or CapturedSchedulerAndExecutionContext) + { + ManualResetValueTaskSourceCoreShared.InvokeContinuationWithContext(context, continuation, _continuationState, _runContinuationsAsynchronously); + } else { - InvokeSchedulerContinuation(); + Debug.Assert(context is TaskScheduler or SynchronizationContext, $"context is {context}"); + ManualResetValueTaskSourceCoreShared.ScheduleCapturedContext(context, continuation, _continuationState); } } - else + } + } + + /// A tuple of both a non-null scheduler and a non-null ExecutionContext. + internal sealed class CapturedSchedulerAndExecutionContext + { + internal readonly object _scheduler; + internal readonly ExecutionContext _executionContext; + + public CapturedSchedulerAndExecutionContext(object scheduler, ExecutionContext executionContext) + { + Debug.Assert(scheduler is SynchronizationContext or TaskScheduler, $"{nameof(scheduler)} is {scheduler}"); + Debug.Assert(executionContext is not null, $"{nameof(executionContext)} is null"); + + _scheduler = scheduler; + _executionContext = executionContext; + } + } + + internal static class ManualResetValueTaskSourceCoreShared // separated out of generic to avoid unnecessary duplication + { + internal static readonly Action s_sentinel = CompletionSentinel; + + private static void CompletionSentinel(object? _) // named method to aid debugging + { + Debug.Fail("The sentinel delegate should never be invoked."); + ThrowHelper.ThrowInvalidOperationException(); + } + + internal static void ScheduleCapturedContext(object context, Action continuation, object? state) + { + Debug.Assert( + context is SynchronizationContext or TaskScheduler or CapturedSchedulerAndExecutionContext, + $"{nameof(context)} is {context}"); + + switch (context) { - InvokeContinuationWithContext(); + case SynchronizationContext sc: + ScheduleSynchronizationContext(sc, continuation, state); + break; + + case TaskScheduler ts: + ScheduleTaskScheduler(ts, continuation, state); + break; + + default: + CapturedSchedulerAndExecutionContext cc = (CapturedSchedulerAndExecutionContext)context; + if (cc._scheduler is SynchronizationContext ccsc) + { + ScheduleSynchronizationContext(ccsc, continuation, state); + } + else + { + Debug.Assert(cc._scheduler is TaskScheduler, $"{nameof(cc._scheduler)} is {cc._scheduler}"); + ScheduleTaskScheduler((TaskScheduler)cc._scheduler, continuation, state); + } + break; } + + static void ScheduleSynchronizationContext(SynchronizationContext sc, Action continuation, object? state) => + sc.Post(continuation.Invoke, state); + + static void ScheduleTaskScheduler(TaskScheduler scheduler, Action continuation, object? state) => + Task.Factory.StartNew(continuation, state, CancellationToken.None, TaskCreationOptions.DenyChildAttach, scheduler); } - private void InvokeContinuationWithContext() + internal static void InvokeContinuationWithContext(object capturedContext, Action continuation, object? continuationState, bool runContinuationsAsynchronously) { // This is in a helper as the error handling causes the generated asm // for the surrounding code to become less efficient (stack spills etc) // and it is an uncommon path. + Debug.Assert(continuation is not null, $"{nameof(continuation)} is null"); + Debug.Assert(capturedContext is ExecutionContext or CapturedSchedulerAndExecutionContext, $"{nameof(capturedContext)} is {capturedContext}"); - Debug.Assert(_continuation != null, $"Null {nameof(_continuation)}"); - Debug.Assert(_executionContext != null, $"Null {nameof(_executionContext)}"); - + // Capture the current EC. We'll switch over to the target EC and then restore back to this one. ExecutionContext? currentContext = ExecutionContext.CaptureForRestore(); - // Restore the captured ExecutionContext before executing anything. - ExecutionContext.Restore(_executionContext); - if (_capturedContext is null) + if (capturedContext is ExecutionContext ec) { - if (RunContinuationsAsynchronously) + ExecutionContext.RestoreInternal(ec); // Restore the captured ExecutionContext before executing anything. + if (runContinuationsAsynchronously) { try { - ThreadPool.QueueUserWorkItem(_continuation, _continuationState, preferLocal: true); + ThreadPool.QueueUserWorkItem(continuation, continuationState, preferLocal: true); } finally { - // Restore the current ExecutionContext. - ExecutionContext.RestoreInternal(currentContext); + ExecutionContext.RestoreInternal(currentContext); // Restore the current ExecutionContext. } } else @@ -279,7 +344,7 @@ private void InvokeContinuationWithContext() SynchronizationContext? syncContext = SynchronizationContext.Current; try { - _continuation(_continuationState); + continuation(continuationState); } catch (Exception ex) { @@ -290,64 +355,29 @@ private void InvokeContinuationWithContext() } finally { - // Set sync context back to what it was prior to coming in + // Set sync context back to what it was prior to coming in. + // Then restore the current ExecutionContext. SynchronizationContext.SetSynchronizationContext(syncContext); - // Restore the current ExecutionContext. ExecutionContext.RestoreInternal(currentContext); } // Now rethrow the exception; if there is one. edi?.Throw(); } - - return; - } - - try - { - InvokeSchedulerContinuation(); } - finally - { - // Restore the current ExecutionContext. - ExecutionContext.RestoreInternal(currentContext); - } - } - - /// - /// Invokes the continuation with the appropriate scheduler. - /// This assumes that if is not null we're already - /// running within that . - /// - private void InvokeSchedulerContinuation() - { - Debug.Assert(_capturedContext != null, $"Null {nameof(_capturedContext)}"); - Debug.Assert(_continuation != null, $"Null {nameof(_continuation)}"); - - switch (_capturedContext) + else { - case SynchronizationContext sc: - sc.Post(static s => - { - var state = (TupleSlim, object?>)s!; - state.Item1(state.Item2); - }, new TupleSlim, object?>(_continuation, _continuationState)); - break; - - case TaskScheduler ts: - Task.Factory.StartNew(_continuation, _continuationState, CancellationToken.None, TaskCreationOptions.DenyChildAttach, ts); - break; + CapturedSchedulerAndExecutionContext cc = (CapturedSchedulerAndExecutionContext)capturedContext; + ExecutionContext.Restore(cc._executionContext); // Restore the captured ExecutionContext before executing anything. + try + { + ScheduleCapturedContext(capturedContext, continuation, continuationState); + } + finally + { + ExecutionContext.RestoreInternal(currentContext); // Restore the current ExecutionContext. + } } } } - - internal static class ManualResetValueTaskSourceCoreShared // separated out of generic to avoid unnecessary duplication - { - internal static readonly Action s_sentinel = CompletionSentinel; - private static void CompletionSentinel(object? _) // named method to aid debugging - { - Debug.Fail("The sentinel delegate should never be invoked."); - ThrowHelper.ThrowInvalidOperationException(); - } - } } diff --git a/src/libraries/System.Private.CoreLib/src/System/ThrowHelper.cs b/src/libraries/System.Private.CoreLib/src/System/ThrowHelper.cs index 59539d2eb06eb..3382fa9571d71 100644 --- a/src/libraries/System.Private.CoreLib/src/System/ThrowHelper.cs +++ b/src/libraries/System.Private.CoreLib/src/System/ThrowHelper.cs @@ -818,6 +818,8 @@ private static string GetArgumentName(ExceptionArgument argument) return "function"; case ExceptionArgument.scheduler: return "scheduler"; + case ExceptionArgument.continuation: + return "continuation"; case ExceptionArgument.continuationAction: return "continuationAction"; case ExceptionArgument.continuationFunction: @@ -1142,6 +1144,7 @@ internal enum ExceptionArgument creationOptions, function, scheduler, + continuation, continuationAction, continuationFunction, tasks,