Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nunit tests adapter losing async locals #16157

Merged
merged 5 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ internal class AvaloniaTestMethodCommand : TestCommand
.GetField("BeforeTest", BindingFlags.Instance | BindingFlags.NonPublic)!;
private static FieldInfo s_afterTest = typeof(BeforeAndAfterTestCommand)
.GetField("AfterTest", BindingFlags.Instance | BindingFlags.NonPublic)!;

private AvaloniaTestMethodCommand(
HeadlessUnitTestSession session,
TestCommand innerCommand,
Expand All @@ -47,7 +47,7 @@ public static TestCommand ProcessCommand(HeadlessUnitTestSession session, TestCo
{
return ProcessCommand(session, command, new List<Action>(), new List<Action>());
}

private static TestCommand ProcessCommand(HeadlessUnitTestSession session, TestCommand command, List<Action> before, List<Action> after)
{
if (command is BeforeAndAfterTestCommand beforeAndAfterTestCommand)
Expand Down Expand Up @@ -79,7 +79,7 @@ private static TestCommand ProcessCommand(HeadlessUnitTestSession session, TestC

public override TestResult Execute(TestExecutionContext context)
{
return _session.Dispatch(() => ExecuteTestMethod(context), default).GetAwaiter().GetResult();
return _session.DispatchCore(() => ExecuteTestMethod(context), true, default).GetAwaiter().GetResult();
}

// Unfortunately, NUnit has issues with custom synchronization contexts, which means we need to add some hacks to make it work.
Expand Down
1 change: 1 addition & 0 deletions src/Headless/Avalonia.Headless/Avalonia.Headless.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
</ItemGroup>

<ItemGroup Label="InternalsVisibleTo">
<InternalsVisibleTo Include="Avalonia.Headless.NUnit, PublicKey=$(AvaloniaPublicKey)" />
<InternalsVisibleTo Include="Avalonia.Headless.Vnc, PublicKey=$(AvaloniaPublicKey)" />
<InternalsVisibleTo Include="Avalonia.UnitTests, PublicKey=$(AvaloniaPublicKey)" />
<InternalsVisibleTo Include="Avalonia.Base.UnitTests, PublicKey=$(AvaloniaPublicKey)" />
Expand Down
42 changes: 28 additions & 14 deletions src/Headless/Avalonia.Headless/HeadlessUnitTestSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public sealed class HeadlessUnitTestSession : IDisposable

private readonly AppBuilder _appBuilder;
private readonly CancellationTokenSource _cancellationTokenSource;
private readonly BlockingCollection<Action> _queue;
private readonly BlockingCollection<(Action, ExecutionContext?)> _queue;
private readonly Task _dispatchTask;

internal const DynamicallyAccessedMemberTypes DynamicallyAccessed =
Expand All @@ -32,50 +32,58 @@ public sealed class HeadlessUnitTestSession : IDisposable
DynamicallyAccessedMemberTypes.PublicParameterlessConstructor;

private HeadlessUnitTestSession(AppBuilder appBuilder, CancellationTokenSource cancellationTokenSource,
BlockingCollection<Action> queue, Task dispatchTask)
BlockingCollection<(Action, ExecutionContext?)> queue, Task dispatchTask)
{
_appBuilder = appBuilder;
_cancellationTokenSource = cancellationTokenSource;
_queue = queue;
_dispatchTask = dispatchTask;
}

/// <inheritdoc cref="Dispatch{TResult}(Func{Task{TResult}}, CancellationToken)"/>
/// <inheritdoc cref="DispatchCore{TResult}"/>
public Task Dispatch(Action action, CancellationToken cancellationToken)
{
return Dispatch(() =>
return DispatchCore(() =>
{
action();
return Task.FromResult(0);
}, cancellationToken);
}, false ,cancellationToken);
}

/// <inheritdoc cref="Dispatch{TResult}(Func{Task{TResult}}, CancellationToken)"/>
/// <inheritdoc cref="DispatchCore{TResult}"/>
public Task<TResult> Dispatch<TResult>(Func<TResult> action, CancellationToken cancellationToken)
{
return Dispatch(() => Task.FromResult(action()), cancellationToken);
return DispatchCore(() => Task.FromResult(action()), false, cancellationToken);
}

/// <inheritdoc cref="DispatchCore{TResult}"/>
public Task<TResult> Dispatch<TResult>(Func<Task<TResult>> action, CancellationToken cancellationToken)
{
return DispatchCore(action, false, cancellationToken);
}

/// <summary>
/// Dispatch method queues an async operation on the dispatcher thread, creates a new application instance,
/// setting app avalonia services, and runs <paramref name="action"/> parameter.
/// </summary>
/// <param name="action">Action to execute on the dispatcher thread with avalonia services.</param>
/// <param name="captureExecutionContext">Whether dispatch should capture ExecutionContext.</param>
/// <param name="cancellationToken">Cancellation token to cancel execution.</param>
/// <exception cref="ObjectDisposedException">
/// If global session was already cancelled and thread killed, it's not possible to dispatch any actions again
/// </exception>
public Task<TResult> Dispatch<TResult>(Func<Task<TResult>> action, CancellationToken cancellationToken)
internal Task<TResult> DispatchCore<TResult>(Func<Task<TResult>> action, bool captureExecutionContext, CancellationToken cancellationToken)
{
if (_cancellationTokenSource.IsCancellationRequested)
{
throw new ObjectDisposedException("Session was already disposed.");
}

var token = _cancellationTokenSource.Token;
var executionContext = captureExecutionContext ? ExecutionContext.Capture() : null;

var tcs = new TaskCompletionSource<TResult>();
_queue.Add(() =>
_queue.Add((() =>
{
var cts = new CancellationTokenSource();
using var globalCts = token.Register(s => ((CancellationTokenSource)s!).Cancel(), cts, true);
Expand All @@ -84,7 +92,6 @@ public Task<TResult> Dispatch<TResult>(Func<Task<TResult>> action, CancellationT
try
{
using var application = EnsureApplication();

var task = action();
if (task.Status != TaskStatus.RanToCompletion)
{
Expand All @@ -110,7 +117,7 @@ public Task<TResult> Dispatch<TResult>(Func<Task<TResult>> action, CancellationT
{
tcs.TrySetException(ex);
}
});
}, executionContext));
return tcs.Task;
}

Expand Down Expand Up @@ -157,7 +164,7 @@ public static HeadlessUnitTestSession StartNew(
{
var tcs = new TaskCompletionSource<HeadlessUnitTestSession>();
var cancellationTokenSource = new CancellationTokenSource();
var queue = new BlockingCollection<Action>();
var queue = new BlockingCollection<(Action, ExecutionContext?)>();

Task? task = null;
task = Task.Run(() =>
Expand Down Expand Up @@ -185,8 +192,15 @@ public static HeadlessUnitTestSession StartNew(
{
try
{
var action = queue.Take(cancellationTokenSource.Token);
action();
var (action, executionContext) = queue.Take(cancellationTokenSource.Token);
if (executionContext is not null)
{
ExecutionContext.Run(executionContext, a => ((Action)a!).Invoke(), action);
}
else
{
action();
}
}
catch (OperationCanceledException)
{
Expand Down
35 changes: 29 additions & 6 deletions tests/Avalonia.Headless.UnitTests/ThreadingTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Avalonia.Threading;
Expand All @@ -14,6 +16,7 @@ public class ThreadingTests
#endif
public void Should_Be_On_Dispatcher_Thread()
{
ValidateTestContext();
Dispatcher.UIThread.VerifyAccess();
}

Expand All @@ -34,20 +37,40 @@ public void Should_Fail_Test_On_Delayed_Post_When_FlushDispatcher()
#endif
public async Task DispatcherTimer_Works_On_The_Same_Thread(int interval)
{
Assert.NotNull(SynchronizationContext.Current);
ValidateTestContext();
var currentThread = Thread.CurrentThread;

await Task.Delay(100);

var currentThread = Thread.CurrentThread;
ValidateTestContext();
Assert.True(currentThread == Thread.CurrentThread);

var tcs = new TaskCompletionSource();
var hasCompleted = false;

DispatcherTimer.RunOnce(() =>
{
hasCompleted = currentThread == Thread.CurrentThread;

tcs.SetResult();
try
{
ValidateTestContext();
Assert.True(currentThread == Thread.CurrentThread);
tcs.SetResult();
}
catch (Exception ex)
{
tcs.SetException(ex);
}
}, TimeSpan.FromTicks(interval));

await tcs.Task;
Assert.True(hasCompleted);
}

private void ValidateTestContext([CallerMemberName] string runningMethodName = null)
{
#if NUNIT
var testName = TestContext.CurrentContext.Test.Name;
// Test.Name also includes parameters.
Assert.AreEqual(testName.Split('(').First(), runningMethodName);
#endif
}
}
Loading