Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reimplement CancellationTokenSourcePool #1192

Merged
merged 4 commits into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ private HedgingExecutionContext Create()
var pool = new ObjectPool<TaskExecution>(
() =>
{
var execution = new TaskExecution(handler);
var execution = new TaskExecution(handler, CancellationTokenSourcePool.Create(_timeProvider));
_createdExecutions.Add(execution);
return execution;
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using Polly.Hedging.Controller;
using Polly.Hedging.Utils;
using Polly.Strategy;
using Polly.Utils;

namespace Polly.Core.Tests.Hedging.Controller;

Expand Down Expand Up @@ -265,5 +266,5 @@ private void CreateSnapshot(CancellationToken? token = null)

private Func<HedgingActionGeneratorArguments<DisposableResult>, Func<Task<DisposableResult>>?> Generator { get; set; } = args => () => Task.FromResult(new DisposableResult { Name = Handled });

private TaskExecution Create() => new(_hedgingHandler.CreateHandler()!);
private TaskExecution Create() => new(_hedgingHandler.CreateHandler()!, CancellationTokenSourcePool.Create(TimeProvider.System));
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ public async Task ExecuteAsync_ShouldReturnAnyPossibleResult()
public async void ExecuteAsync_EnsureHedgedTasksCancelled_Ok()
{
// arrange
_testOutput.WriteLine("ExecuteAsync_EnsureHedgedTasksCancelled_Ok executing...");

_options.MaxHedgedAttempts = 2;
using var cancelled = new ManualResetEvent(false);
ConfigureHedging(async context =>
Expand Down Expand Up @@ -155,6 +157,11 @@ public async void ExecuteAsync_EnsureHedgedTasksCancelled_Ok()
});

// assert
_timeProvider.Advance(_options.HedgingDelay);
await Task.Delay(20);
_timeProvider.Advance(_options.HedgingDelay);
await Task.Delay(20);

_timeProvider.Advance(TimeSpan.FromHours(1));
(await result).Should().Be(Success);
cancelled.WaitOne(AssertTimeout).Should().BeTrue();
Expand Down
103 changes: 93 additions & 10 deletions src/Polly.Core.Tests/Utils/CancellationTokenSourcePoolTests.cs
Original file line number Diff line number Diff line change
@@ -1,33 +1,116 @@
using System;
using System.Threading;
using Moq;
using Polly.Core.Tests.Helpers;
using Polly.Utils;
using Xunit;

namespace Polly.Core.Tests.Utils;

public class CancellationTokenSourcePoolTests
{
public static IEnumerable<object[]> TimeProviders()
{
yield return new object[] { TimeProvider.System };
yield return new object[] { new FakeTimeProvider() };
}

[Fact]
public void RentReturn_Reusable_EnsureProperBehavior()
public void ArgValidation_Ok()
{
var pool = CancellationTokenSourcePool.Create(TimeProvider.System);

Assert.Throws<ArgumentOutOfRangeException>(() => pool.Get(TimeSpan.Zero));
var e = Assert.Throws<ArgumentOutOfRangeException>(() => pool.Get(TimeSpan.FromMilliseconds(-2)));
e.Message.Should().StartWith("Invalid delay specified.");
e.ActualValue.Should().Be(TimeSpan.FromMilliseconds(-2));

pool.Get(System.Threading.Timeout.InfiniteTimeSpan).Should().NotBeNull();
}

[MemberData(nameof(TimeProviders))]
[Theory]
public void RentReturn_Reusable_EnsureProperBehavior(object timeProvider)
{
var cts = CancellationTokenSourcePool.Get();
CancellationTokenSourcePool.Return(cts);
var pool = CancellationTokenSourcePool.Create(GetTimeProvider(timeProvider));
var cts = pool.Get(System.Threading.Timeout.InfiniteTimeSpan);
pool.Return(cts);

var cts2 = CancellationTokenSourcePool.Get();
var cts2 = pool.Get(System.Threading.Timeout.InfiniteTimeSpan);
#if NET6_0_OR_GREATER
cts2.Should().BeSameAs(cts);
if (timeProvider == TimeProvider.System)
{
cts2.Should().BeSameAs(cts);
}
else
{
cts2.Should().NotBeSameAs(cts);
}
#else
cts2.Should().NotBeSameAs(cts);
#endif
}

[Fact]
public void RentReturn_NotReusable_EnsureProperBehavior()
[MemberData(nameof(TimeProviders))]
[Theory]
public void RentReturn_NotReusable_EnsureProperBehavior(object timeProvider)
{
var cts = CancellationTokenSourcePool.Get();
var pool = CancellationTokenSourcePool.Create(GetTimeProvider(timeProvider));
var cts = pool.Get(System.Threading.Timeout.InfiniteTimeSpan);
cts.Cancel();
CancellationTokenSourcePool.Return(cts);
pool.Return(cts);

cts.Invoking(c => c.Token).Should().Throw<ObjectDisposedException>();

var cts2 = CancellationTokenSourcePool.Get();
var cts2 = pool.Get(System.Threading.Timeout.InfiniteTimeSpan);
cts2.Token.Should().NotBeNull();
}

[MemberData(nameof(TimeProviders))]
[Theory]
public async Task Rent_Cancellable_EnsureCancelled(object timeProvider)
{
if (timeProvider is Mock<TimeProvider> fakeTimeProvider)
{
fakeTimeProvider
.Setup(v => v.CancelAfter(It.IsAny<CancellationTokenSource>(), TimeSpan.FromMilliseconds(1)))
.Callback<CancellationTokenSource, TimeSpan>((source, _) => source.Cancel());
}
else
{
fakeTimeProvider = null!;
}

var pool = CancellationTokenSourcePool.Create(GetTimeProvider(timeProvider));
var cts = pool.Get(TimeSpan.FromMilliseconds(1));

await Task.Delay(100);

cts.IsCancellationRequested.Should().BeTrue();
fakeTimeProvider?.VerifyAll();
}

[MemberData(nameof(TimeProviders))]
[Theory]
public async Task Rent_NotCancellable_EnsureNotCancelled(object timeProvider)
{
var pool = CancellationTokenSourcePool.Create(GetTimeProvider(timeProvider));
var cts = pool.Get(System.Threading.Timeout.InfiniteTimeSpan);

await Task.Delay(20);

cts.IsCancellationRequested.Should().BeFalse();

if (timeProvider is Mock<TimeProvider> fakeTimeProvider)
{
fakeTimeProvider
.Verify(v => v.CancelAfter(It.IsAny<CancellationTokenSource>(), It.IsAny<TimeSpan>()), Times.Never());
}
}

private static TimeProvider GetTimeProvider(object timeProvider) => timeProvider switch
{
Mock<TimeProvider> m => m.Object,
_ => (TimeProvider)timeProvider
};
}
5 changes: 4 additions & 1 deletion src/Polly.Core/Hedging/Controller/HedgingController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ internal sealed class HedgingController

public HedgingController(TimeProvider provider, HedgingHandler.Handler handler, int maxAttempts)
{
// retrieve the cancellation pool for this time provider
var pool = CancellationTokenSourcePool.Create(provider);

_executionPool = new ObjectPool<TaskExecution>(() =>
{
Interlocked.Increment(ref _rentedExecutions);
return new TaskExecution(handler);
return new TaskExecution(handler, pool);
},
_ =>
{
Expand Down
11 changes: 8 additions & 3 deletions src/Polly.Core/Hedging/Controller/TaskExecution.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@ internal sealed class TaskExecution
{
private readonly ResilienceContext _cachedContext = ResilienceContext.Get();
private readonly HedgingHandler.Handler _handler;
private readonly CancellationTokenSourcePool _cancellationTokenSourcePool;
private CancellationTokenSource? _cancellationSource;
private CancellationTokenRegistration? _cancellationRegistration;
private ResilienceContext? _activeContext;

public TaskExecution(HedgingHandler.Handler handler) => _handler = handler;
public TaskExecution(HedgingHandler.Handler handler, CancellationTokenSourcePool cancellationTokenSourcePool)
{
_handler = handler;
_cancellationTokenSourcePool = cancellationTokenSourcePool;
}

/// <summary>
/// Gets the task that represents the execution of the hedged task.
Expand Down Expand Up @@ -81,7 +86,7 @@ public async ValueTask<bool> InitializeAsync<TResult, TState>(
int attempt)
{
Type = type;
_cancellationSource = CancellationTokenSourcePool.Get();
_cancellationSource = _cancellationTokenSourcePool.Get(System.Threading.Timeout.InfiniteTimeSpan);
Properties.Replace(snapshot.OriginalProperties);

if (snapshot.OriginalCancellationToken.CanBeCanceled)
Expand Down Expand Up @@ -147,7 +152,7 @@ public async ValueTask ResetAsync()
{
// accepted outcome means that the cancellation source can be be returned to the pool
// since it was most likely not cancelled
CancellationTokenSourcePool.Return(_cancellationSource!);
_cancellationTokenSourcePool.Return(_cancellationSource!);
}

IsAccepted = false;
Expand Down
9 changes: 4 additions & 5 deletions src/Polly.Core/Timeout/TimeoutResilienceStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ namespace Polly.Timeout;

internal sealed class TimeoutResilienceStrategy : ResilienceStrategy
{
private readonly TimeProvider _timeProvider;
private readonly ResilienceStrategyTelemetry _telemetry;
private readonly CancellationTokenSourcePool _cancellationTokenSourcePool;

public TimeoutResilienceStrategy(TimeoutStrategyOptions options, TimeProvider timeProvider, ResilienceStrategyTelemetry telemetry)
{
DefaultTimeout = options.Timeout;
TimeoutGenerator = options.TimeoutGenerator.CreateHandler(DefaultTimeout, TimeoutUtil.IsTimeoutValid);
OnTimeout = options.OnTimeout.CreateHandler();
_timeProvider = timeProvider;
_telemetry = telemetry;
_cancellationTokenSourcePool = CancellationTokenSourcePool.Create(timeProvider);
}

public TimeSpan DefaultTimeout { get; }
Expand All @@ -35,8 +35,7 @@ protected internal override async ValueTask<TResult> ExecuteCoreAsync<TResult, T
}

var previousToken = context.CancellationToken;
var cancellationSource = CancellationTokenSourcePool.Get();
_timeProvider.CancelAfter(cancellationSource, timeout);
var cancellationSource = _cancellationTokenSourcePool.Get(timeout);
context.CancellationToken = cancellationSource.Token;

CancellationTokenRegistration? registration = null;
Expand Down Expand Up @@ -76,7 +75,7 @@ protected internal override async ValueTask<TResult> ExecuteCoreAsync<TResult, T
finally
{
context.CancellationToken = previousToken;
CancellationTokenSourcePool.Return(cancellationSource);
_cancellationTokenSourcePool.Return(cancellationSource);
}
}

Expand Down
25 changes: 25 additions & 0 deletions src/Polly.Core/Utils/CancellationTokenSourcePool.Disposable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
namespace Polly.Utils;

internal abstract partial class CancellationTokenSourcePool
{
private sealed class DisposableCancellationTokenSourcePool : CancellationTokenSourcePool
{
private readonly TimeProvider _timeProvider;

public DisposableCancellationTokenSourcePool(TimeProvider timeProvider) => _timeProvider = timeProvider;

protected override CancellationTokenSource GetCore(TimeSpan delay)
{
var source = new CancellationTokenSource();

if (IsCancellable(delay))
{
_timeProvider.CancelAfter(source, delay);
}

return source;
}

public override void Return(CancellationTokenSource source) => source.Dispose();
}
}
42 changes: 42 additions & 0 deletions src/Polly.Core/Utils/CancellationTokenSourcePool.Pooled.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
namespace Polly.Utils;

internal abstract partial class CancellationTokenSourcePool
{
#if NET6_0_OR_GREATER
private sealed class PooledCancellationTokenSourcePool : CancellationTokenSourcePool
{
public static readonly PooledCancellationTokenSourcePool SystemInstance = new(TimeProvider.System);
martincostello marked this conversation as resolved.
Show resolved Hide resolved

public PooledCancellationTokenSourcePool(TimeProvider timeProvider) => _timeProvider = timeProvider;

private readonly ObjectPool<CancellationTokenSource> _pool = new(
static () => new CancellationTokenSource(),
static cts => true);
private readonly TimeProvider _timeProvider;

protected override CancellationTokenSource GetCore(TimeSpan delay)
{
var source = _pool.Get();

if (IsCancellable(delay))
{
_timeProvider.CancelAfter(source, delay);
}

return source;
}

public override void Return(CancellationTokenSource source)
{
if (source.TryReset())
{
_pool.Return(source);
}
else
{
source.Dispose();
}
}
}
#endif
}
44 changes: 23 additions & 21 deletions src/Polly.Core/Utils/CancellationTokenSourcePool.cs
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
namespace Polly.Utils
namespace Polly.Utils;

#pragma warning disable CA1716 // Identifiers should not match keywords

internal abstract partial class CancellationTokenSourcePool
{
internal static class CancellationTokenSourcePool
public static CancellationTokenSourcePool Create(TimeProvider timeProvider)
{
#if NET6_0_OR_GREATER
private static readonly ObjectPool<CancellationTokenSource> Pool = new(
static () => new CancellationTokenSource(),
static cts => true);
#endif
public static CancellationTokenSource Get()
if (timeProvider == TimeProvider.System)
{
#if NET6_0_OR_GREATER
return Pool.Get();
#else
return new CancellationTokenSource();
#endif
return PooledCancellationTokenSourcePool.SystemInstance;
}
#endif
return new DisposableCancellationTokenSourcePool(timeProvider);
}

public static void Return(CancellationTokenSource source)
public CancellationTokenSource Get(TimeSpan delay)
{
if (delay <= TimeSpan.Zero && delay != System.Threading.Timeout.InfiniteTimeSpan)
{
#if NET6_0_OR_GREATER
if (source.TryReset())
{
Pool.Return(source);
return;
}
#endif
source.Dispose();
throw new ArgumentOutOfRangeException(nameof(delay), delay, "Invalid delay specified.");
}

return GetCore(delay);
}

protected abstract CancellationTokenSource GetCore(TimeSpan delay);

public abstract void Return(CancellationTokenSource source);

protected static bool IsCancellable(TimeSpan delay) => delay != System.Threading.Timeout.InfiniteTimeSpan;
}