Skip to content
This repository has been archived by the owner on Jan 23, 2023. It is now read-only.
/ corefx Public archive

Expose\test IAsyncEnumerable.ConfigureAwait #33337

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/System.Threading.Tasks/ref/System.Threading.Tasks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,18 @@ public void SetResult() { }
public void SetStateMachine(System.Runtime.CompilerServices.IAsyncStateMachine stateMachine) { }
public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : System.Runtime.CompilerServices.IAsyncStateMachine { }
}
public readonly struct ConfiguredAsyncEnumerable<T>
{
private readonly object _dummy;
public Enumerator GetAsyncEnumerator() { throw null; }
public readonly struct Enumerator
{
private readonly object _dummy;
public ConfiguredValueTaskAwaitable<bool> MoveNextAsync() { throw null; }
public T Current { get { throw null; } }
public ConfiguredValueTaskAwaitable DisposeAsync() { throw null; }
}
}
}
namespace System.Threading
{
Expand Down Expand Up @@ -149,6 +161,7 @@ public void SetResult(TResult result) { }
}
public static partial class TaskExtensions
{
public static System.Runtime.CompilerServices.ConfiguredAsyncEnumerable<T> ConfigureAwait<T>(this System.Collections.Generic.IAsyncEnumerable<T> source, bool continueOnCapturedContext) { throw null; }
public static System.Threading.Tasks.Task Unwrap(this System.Threading.Tasks.Task<System.Threading.Tasks.Task> task) { throw null; }
public static System.Threading.Tasks.Task<TResult> Unwrap<TResult>(this System.Threading.Tasks.Task<System.Threading.Tasks.Task<TResult>> task) { throw null; }
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.Threading.Tasks.Sources;
using Xunit;

namespace System.Runtime.CompilerServices.Tests
{
public class ConfiguredAsyncEnumerableTests
{
[Fact]
public void ConfigureAwait_GetAsyncEnumerator_Default_Throws()
{
ConfiguredAsyncEnumerable<int> e = default;
Assert.Throws<NullReferenceException>(() => e.GetAsyncEnumerator());

e = ((IAsyncEnumerable<int>)null).ConfigureAwait(false);
Assert.Throws<NullReferenceException>(() => e.GetAsyncEnumerator());
}

[Fact]
public void ConfigureAwait_EnumeratorMembers_Default_Throws()
{
ConfiguredAsyncEnumerable<int>.Enumerator e = default;
Assert.Throws<NullReferenceException>(() => e.MoveNextAsync());
Assert.Throws<NullReferenceException>(() => e.Current);
Assert.Throws<NullReferenceException>(() => e.DisposeAsync());
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public void ConfigureAwait_AwaitMoveNextAsync_FlagsSetAppropriately(bool continueOnCapturedContext)
{
var enumerable = new TrackFlagsAsyncEnumerable() { Flags = 0 };
ConfiguredAsyncEnumerable<int>.Enumerator enumerator = enumerable.ConfigureAwait(continueOnCapturedContext).GetAsyncEnumerator();
ConfiguredValueTaskAwaitable<bool>.ConfiguredValueTaskAwaiter moveNextAwaiter = enumerator.MoveNextAsync().GetAwaiter();
moveNextAwaiter.UnsafeOnCompleted(() => { });
Assert.Equal(
continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None,
enumerable.Flags);
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public void ConfigureAwait_AwaitDisposeAsync_FlagsSetAppropriately(bool continueOnCapturedContext)
{
var enumerable = new TrackFlagsAsyncEnumerable() { Flags = 0 };
ConfiguredAsyncEnumerable<int>.Enumerator enumerator = enumerable.ConfigureAwait(continueOnCapturedContext).GetAsyncEnumerator();
ConfiguredValueTaskAwaitable.ConfiguredValueTaskAwaiter disposeAwaiter = enumerator.DisposeAsync().GetAwaiter();
disposeAwaiter.UnsafeOnCompleted(() => { });
Assert.Equal(
continueOnCapturedContext ? ValueTaskSourceOnCompletedFlags.UseSchedulingContext : ValueTaskSourceOnCompletedFlags.None,
enumerable.Flags);
}

[Fact]
public async Task ConfigureAwait_CanBeEnumeratedWithStandardPattern()
{
IAsyncEnumerable<int> asyncEnumerable = new EnumerableWithDelayToAsyncEnumerable<int>(Enumerable.Range(1, 10), 1);
int sum = 0;

ConfiguredAsyncEnumerable<int>.Enumerator e = asyncEnumerable.ConfigureAwait(false).GetAsyncEnumerator();
try
{
while (await e.MoveNextAsync())
{
sum += e.Current;
}
}
finally
{
await e.DisposeAsync();
}

Assert.Equal(55, sum);
}

private sealed class TrackFlagsAsyncEnumerable : IAsyncEnumerable<int>, IAsyncEnumerator<int>, IValueTaskSource<bool>, IValueTaskSource
{
public ValueTaskSourceOnCompletedFlags Flags;

public IAsyncEnumerator<int> GetAsyncEnumerator() => this;
public ValueTask<bool> MoveNextAsync() => new ValueTask<bool>(this, 0);
public int Current => throw new NotImplementedException();
public ValueTask DisposeAsync() => new ValueTask(this, 0);

public void OnCompleted(Action<object> continuation, object state, short token, ValueTaskSourceOnCompletedFlags flags) => Flags = flags;
public ValueTaskSourceStatus GetStatus(short token) => ValueTaskSourceStatus.Pending;
public bool GetResult(short token) => throw new NotImplementedException();
void IValueTaskSource.GetResult(short token) => throw new NotImplementedException();
}

private sealed class EnumerableWithDelayToAsyncEnumerable<T> : IAsyncEnumerable<T>, IAsyncEnumerator<T>
{
private readonly int _delayMs;
private readonly IEnumerable<T> _enumerable;
private IEnumerator<T> _enumerator;

public EnumerableWithDelayToAsyncEnumerable(IEnumerable<T> enumerable, int delayMs)
{
_enumerable = enumerable;
_delayMs = delayMs;
}

public IAsyncEnumerator<T> GetAsyncEnumerator()
{
_enumerator = _enumerable.GetEnumerator();
return this;
}

public async ValueTask<bool> MoveNextAsync()
{
await Task.Delay(_delayMs);
return _enumerator.MoveNext();
}

public T Current => _enumerator.Current;

public async ValueTask DisposeAsync()
{
await Task.Delay(_delayMs);
_enumerator.Dispose();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
<Compile Include="Task\TaskCanceledExceptionTests.netcoreapp.cs" />
<Compile Include="Task\TaskStatusTest.netcoreapp.cs" />
<Compile Include="System.Runtime.CompilerServices\AsyncTaskMethodBuilderTests.netcoreapp.cs" />
<Compile Include="System.Runtime.CompilerServices\ConfiguredAsyncEnumerableTests.netcoreapp.cs" />
<Compile Include="$(CommonTestPath)\System\Diagnostics\Tracing\TestEventListener.cs">
<Link>Common\System\Diagnostics\Tracing\TestEventListener.cs</Link>
</Compile>
Expand Down