Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove defensive Task[] copy from non-generic Task.WhenAll #81065

Merged
merged 1 commit into from
Jan 26, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 181 additions & 116 deletions src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
//
// =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics.Tracing;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Runtime.Versioning;

namespace System.Threading.Tasks
Expand Down Expand Up @@ -5729,36 +5731,36 @@ protected override void Cleanup()
/// </exception>
public static Task WhenAll(IEnumerable<Task> tasks)
{
// Skip a List allocation/copy if tasks is a collection
if (tasks is null)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.tasks);
}

if (tasks is ICollection<Task> taskCollection)
{
// Take a more efficient path if tasks is actually an array
if (tasks is Task[] taskArray)
{
return WhenAll(taskArray);
return WhenAll((ReadOnlySpan<Task>)taskArray);
}

int index = 0;
taskArray = new Task[taskCollection.Count];
foreach (Task task in tasks)
if (tasks is List<Task> taskList)
{
if (task == null) ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks);
taskArray[index++] = task;
return WhenAll(CollectionsMarshal.AsSpan(taskList));
}
return InternalWhenAll(taskArray);
}

// Do some argument checking and convert tasks to a List (and later an array).
if (tasks == null) ThrowHelper.ThrowArgumentNullException(ExceptionArgument.tasks);
List<Task> taskList = new List<Task>();
foreach (Task task in tasks)
taskArray = new Task[taskCollection.Count];
taskCollection.CopyTo(taskArray, 0);
return WhenAll((ReadOnlySpan<Task>)taskArray);
}
else
{
if (task == null) ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks);
taskList.Add(task);
var taskList = new List<Task>();
foreach (Task task in tasks)
{
taskList.Add(task);
}
return WhenAll(CollectionsMarshal.AsSpan(taskList));
}

// Delegate the rest to InternalWhenAll()
return InternalWhenAll(taskList.ToArray());
}

/// <summary>
Expand Down Expand Up @@ -5790,149 +5792,212 @@ public static Task WhenAll(IEnumerable<Task> tasks)
/// </exception>
public static Task WhenAll(params Task[] tasks)
{
// Do some argument checking and make a defensive copy of the tasks array
if (tasks == null) ThrowHelper.ThrowArgumentNullException(ExceptionArgument.tasks);

int taskCount = tasks.Length;
if (taskCount == 0) return InternalWhenAll(tasks); // Small optimization in the case of an empty array.

Task[] tasksCopy = new Task[taskCount];
for (int i = 0; i < taskCount; i++)
if (tasks is null)
{
Task task = tasks[i];
if (task == null) ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks);
tasksCopy[i] = task;
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.tasks);
}

// The rest can be delegated to InternalWhenAll()
return InternalWhenAll(tasksCopy);
return WhenAll((ReadOnlySpan<Task>)tasks);
}

// Some common logic to support WhenAll() methods
// tasks should be a defensive copy.
private static Task InternalWhenAll(Task[] tasks)
{
Debug.Assert(tasks != null, "Expected a non-null tasks array");
return (tasks.Length == 0) ? // take shortcut if there are no tasks upon which to wait
Task.CompletedTask :
new WhenAllPromise(tasks);
}
/// <summary>
/// Creates a task that will complete when all of the supplied tasks have completed.
/// </summary>
/// <param name="tasks">The tasks to wait on for completion.</param>
/// <returns>A task that represents the completion of all of the supplied tasks.</returns>
/// <remarks>
/// <para>
/// If any of the supplied tasks completes in a faulted state, the returned task will also complete in a Faulted state,
/// where its exceptions will contain the aggregation of the set of unwrapped exceptions from each of the supplied tasks.
/// </para>
/// <para>
/// If none of the supplied tasks faulted but at least one of them was canceled, the returned task will end in the Canceled state.
/// </para>
/// <para>
/// If none of the tasks faulted and none of the tasks were canceled, the resulting task will end in the RanToCompletion state.
/// </para>
/// <para>
/// If the supplied span contains no tasks, the returned task will immediately transition to a RanToCompletion
/// state before it's returned to the caller.
/// </para>
/// </remarks>
/// <exception cref="System.ArgumentException">The <paramref name="tasks"/> array contained a null task.</exception>
internal static Task WhenAll(ReadOnlySpan<Task> tasks) => // TODO https://github.com/dotnet/runtime/issues/77873: Make this public.
tasks.Length != 0 ? new WhenAllPromise(tasks) : CompletedTask;

// A Task that gets completed when all of its constituent tasks complete.
// Completion logic will analyze the antecedents in order to choose completion status.
// This type allows us to replace this logic:
// Task promise = new Task(...);
// Action<Task> completionAction = delegate { <completion logic>};
// TaskFactory.CommonCWAllLogic(tasksCopy).AddCompletionAction(completionAction);
// return promise;
// which involves several allocations, with this logic:
// return new WhenAllPromise(tasksCopy);
// which saves a couple of allocations and enables debugger notification specialization.
//
// Used in InternalWhenAll(Task[])
/// <summary>A Task that gets completed when all of its constituent tasks complete.</summary>
private sealed class WhenAllPromise : Task, ITaskCompletionAction
{
/// <summary>
/// Stores all of the constituent tasks. Tasks clear themselves out of this
/// array as they complete, but only if they don't have their wait notification bit set.
/// </summary>
private readonly Task?[] m_tasks;
/// <summary>Either a single faulted/canceled task, or a list of faulted/canceled tasks.</summary>
private object? _failedOrCanceled;
/// <summary>The number of tasks remaining to complete.</summary>
private int m_count;
private int _remainingToComplete;

internal WhenAllPromise(Task[] tasks)
internal WhenAllPromise(ReadOnlySpan<Task> tasks)
{
Debug.Assert(tasks != null, "Expected a non-null task array");
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
Debug.Assert(tasks.Length > 0, "Expected a non-zero length task array");
Debug.Assert(tasks.Length != 0, "Expected a non-zero length task array");

// Throw if any of the provided tasks is null. This is best effort to inform the caller
// they've made a mistake. If between the time we check for nulls and the time we hook
// up callbacks one of the entries is changed from non-null to null, we'll just ignore
// the null at that point; any such use (e.g. calling WhenAll with an array that's mutated
// concurrently with the synchronous call to WhenAll) is erroneous.
foreach (Task task in tasks)
{
if (task is null)
{
ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks);
}
}

if (TplEventSource.Log.IsEnabled())
TplEventSource.Log.TraceOperationBegin(this.Id, "Task.WhenAll", 0);
{
TplEventSource.Log.TraceOperationBegin(Id, "Task.WhenAll", 0);
}

if (s_asyncDebuggingEnabled)
{
AddToActiveTasks(this);
}

m_tasks = tasks;
m_count = tasks.Length;
_remainingToComplete = tasks.Length;

foreach (Task task in tasks)
{
if (task.IsCompleted) this.Invoke(task); // short-circuit the completion action, if possible
else task.AddCompletionAction(this); // simple completion action
if (task is null || task.IsCompleted)
{
Invoke(task); // short-circuit the completion action, if possible
}
else
{
task.AddCompletionAction(this); // simple completion action
}
}
}

public void Invoke(Task completedTask)
public void Invoke(Task? completedTask)
{
if (TplEventSource.Log.IsEnabled())
TplEventSource.Log.TraceOperationRelation(this.Id, CausalityRelation.Join);

// Decrement the count, and only continue to complete the promise if we're the last one.
if (Interlocked.Decrement(ref m_count) == 0)
{
// Set up some accounting variables
List<ExceptionDispatchInfo>? observedExceptions = null;
Task? canceledTask = null;
TplEventSource.Log.TraceOperationRelation(Id, CausalityRelation.Join);
}

// Loop through antecedents:
// If any one of them faults, the result will be faulted
// If none fault, but at least one is canceled, the result will be canceled
// If none fault or are canceled, then result will be RanToCompletion
for (int i = 0; i < m_tasks.Length; i++)
if (completedTask is not null)
{
if (completedTask.IsWaitNotificationEnabled)
{
Task? task = m_tasks[i];
Debug.Assert(task != null, "Constituent task in WhenAll should never be null");
SetNotificationForWaitCompletion(enabled: true);
}

if (task.IsFaulted)
{
observedExceptions ??= new List<ExceptionDispatchInfo>();
observedExceptions.AddRange(task.GetExceptionDispatchInfos());
}
else if (task.IsCanceled)
if (!completedTask.IsCompletedSuccessfully)
{
// Try to store the completed task as the first that's failed or faulted.
if (Interlocked.CompareExchange(ref _failedOrCanceled, completedTask, null) != null)
{
canceledTask ??= task; // use the first task that's canceled
}
// There was already something there.
while (true)
{
object? failedOrCanceled = _failedOrCanceled;
Debug.Assert(failedOrCanceled is not null);

// Regardless of completion state, if the task has its debug bit set, transfer it to the
// WhenAll task. We must do this before we complete the task.
if (task.IsWaitNotificationEnabled) this.SetNotificationForWaitCompletion(enabled: true);
else m_tasks[i] = null; // avoid holding onto tasks unnecessarily
// If it was a list, add it to the list.
if (_failedOrCanceled is List<Task> list)
{
lock (list)
{
list.Add(completedTask);
}
break;
}

// Otherwise, it was a Task. Create a new list containing that task and this one, and store it in.
Debug.Assert(failedOrCanceled is Task, $"Expected Task, got {failedOrCanceled}");
if (Interlocked.CompareExchange(ref _failedOrCanceled, new List<Task> { (Task)failedOrCanceled, completedTask }, failedOrCanceled) == failedOrCanceled)
{
break;
}

// We lost the race, which means we should loop around one more time and it'll be a list.
Debug.Assert(_failedOrCanceled is List<Task>);
}
}
}
}

if (observedExceptions != null)
// Decrement the count, and only continue to complete the promise if we're the last one.
if (Interlocked.Decrement(ref _remainingToComplete) == 0)
{
object? failedOrCanceled = _failedOrCanceled;
if (failedOrCanceled is null)
{
Debug.Assert(observedExceptions.Count > 0, "Expected at least one exception");
if (TplEventSource.Log.IsEnabled())
{
TplEventSource.Log.TraceOperationEnd(Id, AsyncCausalityStatus.Completed);
}

// We don't need to TraceOperationCompleted here because TrySetException will call Finish and we'll log it there
if (s_asyncDebuggingEnabled)
{
RemoveFromActiveTasks(this);
}

TrySetException(observedExceptions);
}
else if (canceledTask != null)
{
TrySetCanceled(canceledTask.CancellationToken, canceledTask.GetCancellationExceptionDispatchInfo());
bool completed = TrySetResult();
Debug.Assert(completed);
}
else
{
if (TplEventSource.Log.IsEnabled())
TplEventSource.Log.TraceOperationEnd(this.Id, AsyncCausalityStatus.Completed);
// Set up some accounting variables
List<ExceptionDispatchInfo>? observedExceptions = null;
Task? canceledTask = null;

if (s_asyncDebuggingEnabled)
RemoveFromActiveTasks(this);
void HandleTask(Task task)
{
if (task.IsFaulted)
{
(observedExceptions ??= new()).AddRange(task.GetExceptionDispatchInfos());
}
else if (task.IsCanceled)
{
canceledTask ??= task; // use the first task that's canceled
}
}

// Loop through the completed or faulted tasks:
// If any one of them faults, the result will be faulted
// If none fault, but at least one is canceled, the result will be canceled
if (failedOrCanceled is List<Task> list)
{
foreach (Task task in list)
{
HandleTask(task);
}
}
else
{
Debug.Assert(failedOrCanceled is Task);
HandleTask((Task)failedOrCanceled);
}

if (observedExceptions != null)
{
Debug.Assert(observedExceptions.Count > 0, "Expected at least one exception");

TrySetResult();
// We don't need to TraceOperationCompleted here because TrySetException will call Finish and we'll log it there

TrySetException(observedExceptions);
}
else if (canceledTask != null)
{
TrySetCanceled(canceledTask.CancellationToken, canceledTask.GetCancellationExceptionDispatchInfo());
}
}

Debug.Assert(IsCompleted);
}
Debug.Assert(m_count >= 0, "Count should never go below 0");

Debug.Assert(_remainingToComplete >= 0, "Count should never go below 0");
}

public bool InvokeMayRunArbitraryCode => true;

/// <summary>
/// Returns whether we should notify the debugger of a wait completion. This returns
/// true iff at least one constituent task has its bit set.
/// </summary>
private protected override bool ShouldNotifyDebuggerOfWaitCompletion =>
base.ShouldNotifyDebuggerOfWaitCompletion &&
AnyTaskRequiresNotifyDebuggerOfWaitCompletion(m_tasks);
}

/// <summary>
Expand Down