Skip to content

Commit

Permalink
Reimplement CancellationTokenSourcePool (#1192)
Browse files Browse the repository at this point in the history
  • Loading branch information
martintmk authored May 17, 2023
1 parent eca8ac2 commit a4343b7
Show file tree
Hide file tree
Showing 10 changed files with 209 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ private HedgingExecutionContext Create()
var pool = new ObjectPool<TaskExecution>(
() =>
{
var execution = new TaskExecution(handler);
var execution = new TaskExecution(handler, CancellationTokenSourcePool.Create(_timeProvider));
_createdExecutions.Add(execution);
return execution;
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using Polly.Hedging.Controller;
using Polly.Hedging.Utils;
using Polly.Strategy;
using Polly.Utils;

namespace Polly.Core.Tests.Hedging.Controller;

Expand Down Expand Up @@ -265,5 +266,5 @@ private void CreateSnapshot(CancellationToken? token = null)

private Func<HedgingActionGeneratorArguments<DisposableResult>, Func<Task<DisposableResult>>?> Generator { get; set; } = args => () => Task.FromResult(new DisposableResult { Name = Handled });

private TaskExecution Create() => new(_hedgingHandler.CreateHandler()!);
private TaskExecution Create() => new(_hedgingHandler.CreateHandler()!, CancellationTokenSourcePool.Create(TimeProvider.System));
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ public async Task ExecuteAsync_ShouldReturnAnyPossibleResult()
public async void ExecuteAsync_EnsureHedgedTasksCancelled_Ok()
{
// arrange
_testOutput.WriteLine("ExecuteAsync_EnsureHedgedTasksCancelled_Ok executing...");

_options.MaxHedgedAttempts = 2;
using var cancelled = new ManualResetEvent(false);
ConfigureHedging(async context =>
Expand Down Expand Up @@ -155,6 +157,11 @@ public async void ExecuteAsync_EnsureHedgedTasksCancelled_Ok()
});

// assert
_timeProvider.Advance(_options.HedgingDelay);
await Task.Delay(20);
_timeProvider.Advance(_options.HedgingDelay);
await Task.Delay(20);

_timeProvider.Advance(TimeSpan.FromHours(1));
(await result).Should().Be(Success);
cancelled.WaitOne(AssertTimeout).Should().BeTrue();
Expand Down
103 changes: 93 additions & 10 deletions src/Polly.Core.Tests/Utils/CancellationTokenSourcePoolTests.cs
Original file line number Diff line number Diff line change
@@ -1,33 +1,116 @@
using System;
using System.Threading;
using Moq;
using Polly.Core.Tests.Helpers;
using Polly.Utils;
using Xunit;

namespace Polly.Core.Tests.Utils;

public class CancellationTokenSourcePoolTests
{
public static IEnumerable<object[]> TimeProviders()
{
yield return new object[] { TimeProvider.System };
yield return new object[] { new FakeTimeProvider() };
}

[Fact]
public void RentReturn_Reusable_EnsureProperBehavior()
public void ArgValidation_Ok()
{
var pool = CancellationTokenSourcePool.Create(TimeProvider.System);

Assert.Throws<ArgumentOutOfRangeException>(() => pool.Get(TimeSpan.Zero));
var e = Assert.Throws<ArgumentOutOfRangeException>(() => pool.Get(TimeSpan.FromMilliseconds(-2)));
e.Message.Should().StartWith("Invalid delay specified.");
e.ActualValue.Should().Be(TimeSpan.FromMilliseconds(-2));

pool.Get(System.Threading.Timeout.InfiniteTimeSpan).Should().NotBeNull();
}

[MemberData(nameof(TimeProviders))]
[Theory]
public void RentReturn_Reusable_EnsureProperBehavior(object timeProvider)
{
var cts = CancellationTokenSourcePool.Get();
CancellationTokenSourcePool.Return(cts);
var pool = CancellationTokenSourcePool.Create(GetTimeProvider(timeProvider));
var cts = pool.Get(System.Threading.Timeout.InfiniteTimeSpan);
pool.Return(cts);

var cts2 = CancellationTokenSourcePool.Get();
var cts2 = pool.Get(System.Threading.Timeout.InfiniteTimeSpan);
#if NET6_0_OR_GREATER
cts2.Should().BeSameAs(cts);
if (timeProvider == TimeProvider.System)
{
cts2.Should().BeSameAs(cts);
}
else
{
cts2.Should().NotBeSameAs(cts);
}
#else
cts2.Should().NotBeSameAs(cts);
#endif
}

[Fact]
public void RentReturn_NotReusable_EnsureProperBehavior()
[MemberData(nameof(TimeProviders))]
[Theory]
public void RentReturn_NotReusable_EnsureProperBehavior(object timeProvider)
{
var cts = CancellationTokenSourcePool.Get();
var pool = CancellationTokenSourcePool.Create(GetTimeProvider(timeProvider));
var cts = pool.Get(System.Threading.Timeout.InfiniteTimeSpan);
cts.Cancel();
CancellationTokenSourcePool.Return(cts);
pool.Return(cts);

cts.Invoking(c => c.Token).Should().Throw<ObjectDisposedException>();

var cts2 = CancellationTokenSourcePool.Get();
var cts2 = pool.Get(System.Threading.Timeout.InfiniteTimeSpan);
cts2.Token.Should().NotBeNull();
}

[MemberData(nameof(TimeProviders))]
[Theory]
public async Task Rent_Cancellable_EnsureCancelled(object timeProvider)
{
if (timeProvider is Mock<TimeProvider> fakeTimeProvider)
{
fakeTimeProvider
.Setup(v => v.CancelAfter(It.IsAny<CancellationTokenSource>(), TimeSpan.FromMilliseconds(1)))
.Callback<CancellationTokenSource, TimeSpan>((source, _) => source.Cancel());
}
else
{
fakeTimeProvider = null!;
}

var pool = CancellationTokenSourcePool.Create(GetTimeProvider(timeProvider));
var cts = pool.Get(TimeSpan.FromMilliseconds(1));

await Task.Delay(100);

cts.IsCancellationRequested.Should().BeTrue();
fakeTimeProvider?.VerifyAll();
}

[MemberData(nameof(TimeProviders))]
[Theory]
public async Task Rent_NotCancellable_EnsureNotCancelled(object timeProvider)
{
var pool = CancellationTokenSourcePool.Create(GetTimeProvider(timeProvider));
var cts = pool.Get(System.Threading.Timeout.InfiniteTimeSpan);

await Task.Delay(20);

cts.IsCancellationRequested.Should().BeFalse();

if (timeProvider is Mock<TimeProvider> fakeTimeProvider)
{
fakeTimeProvider
.Verify(v => v.CancelAfter(It.IsAny<CancellationTokenSource>(), It.IsAny<TimeSpan>()), Times.Never());
}
}

private static TimeProvider GetTimeProvider(object timeProvider) => timeProvider switch
{
Mock<TimeProvider> m => m.Object,
_ => (TimeProvider)timeProvider
};
}
5 changes: 4 additions & 1 deletion src/Polly.Core/Hedging/Controller/HedgingController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ internal sealed class HedgingController

public HedgingController(TimeProvider provider, HedgingHandler.Handler handler, int maxAttempts)
{
// retrieve the cancellation pool for this time provider
var pool = CancellationTokenSourcePool.Create(provider);

_executionPool = new ObjectPool<TaskExecution>(() =>
{
Interlocked.Increment(ref _rentedExecutions);
return new TaskExecution(handler);
return new TaskExecution(handler, pool);
},
_ =>
{
Expand Down
11 changes: 8 additions & 3 deletions src/Polly.Core/Hedging/Controller/TaskExecution.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@ internal sealed class TaskExecution
{
private readonly ResilienceContext _cachedContext = ResilienceContext.Get();
private readonly HedgingHandler.Handler _handler;
private readonly CancellationTokenSourcePool _cancellationTokenSourcePool;
private CancellationTokenSource? _cancellationSource;
private CancellationTokenRegistration? _cancellationRegistration;
private ResilienceContext? _activeContext;

public TaskExecution(HedgingHandler.Handler handler) => _handler = handler;
public TaskExecution(HedgingHandler.Handler handler, CancellationTokenSourcePool cancellationTokenSourcePool)
{
_handler = handler;
_cancellationTokenSourcePool = cancellationTokenSourcePool;
}

/// <summary>
/// Gets the task that represents the execution of the hedged task.
Expand Down Expand Up @@ -81,7 +86,7 @@ public async ValueTask<bool> InitializeAsync<TResult, TState>(
int attempt)
{
Type = type;
_cancellationSource = CancellationTokenSourcePool.Get();
_cancellationSource = _cancellationTokenSourcePool.Get(System.Threading.Timeout.InfiniteTimeSpan);
Properties.Replace(snapshot.OriginalProperties);

if (snapshot.OriginalCancellationToken.CanBeCanceled)
Expand Down Expand Up @@ -147,7 +152,7 @@ public async ValueTask ResetAsync()
{
// accepted outcome means that the cancellation source can be be returned to the pool
// since it was most likely not cancelled
CancellationTokenSourcePool.Return(_cancellationSource!);
_cancellationTokenSourcePool.Return(_cancellationSource!);
}

IsAccepted = false;
Expand Down
9 changes: 4 additions & 5 deletions src/Polly.Core/Timeout/TimeoutResilienceStrategy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ namespace Polly.Timeout;

internal sealed class TimeoutResilienceStrategy : ResilienceStrategy
{
private readonly TimeProvider _timeProvider;
private readonly ResilienceStrategyTelemetry _telemetry;
private readonly CancellationTokenSourcePool _cancellationTokenSourcePool;

public TimeoutResilienceStrategy(TimeoutStrategyOptions options, TimeProvider timeProvider, ResilienceStrategyTelemetry telemetry)
{
DefaultTimeout = options.Timeout;
TimeoutGenerator = options.TimeoutGenerator.CreateHandler(DefaultTimeout, TimeoutUtil.IsTimeoutValid);
OnTimeout = options.OnTimeout.CreateHandler();
_timeProvider = timeProvider;
_telemetry = telemetry;
_cancellationTokenSourcePool = CancellationTokenSourcePool.Create(timeProvider);
}

public TimeSpan DefaultTimeout { get; }
Expand All @@ -35,8 +35,7 @@ protected internal override async ValueTask<TResult> ExecuteCoreAsync<TResult, T
}

var previousToken = context.CancellationToken;
var cancellationSource = CancellationTokenSourcePool.Get();
_timeProvider.CancelAfter(cancellationSource, timeout);
var cancellationSource = _cancellationTokenSourcePool.Get(timeout);
context.CancellationToken = cancellationSource.Token;

CancellationTokenRegistration? registration = null;
Expand Down Expand Up @@ -76,7 +75,7 @@ protected internal override async ValueTask<TResult> ExecuteCoreAsync<TResult, T
finally
{
context.CancellationToken = previousToken;
CancellationTokenSourcePool.Return(cancellationSource);
_cancellationTokenSourcePool.Return(cancellationSource);
}
}

Expand Down
25 changes: 25 additions & 0 deletions src/Polly.Core/Utils/CancellationTokenSourcePool.Disposable.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
namespace Polly.Utils;

internal abstract partial class CancellationTokenSourcePool
{
private sealed class DisposableCancellationTokenSourcePool : CancellationTokenSourcePool
{
private readonly TimeProvider _timeProvider;

public DisposableCancellationTokenSourcePool(TimeProvider timeProvider) => _timeProvider = timeProvider;

protected override CancellationTokenSource GetCore(TimeSpan delay)
{
var source = new CancellationTokenSource();

if (IsCancellable(delay))
{
_timeProvider.CancelAfter(source, delay);
}

return source;
}

public override void Return(CancellationTokenSource source) => source.Dispose();
}
}
42 changes: 42 additions & 0 deletions src/Polly.Core/Utils/CancellationTokenSourcePool.Pooled.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
namespace Polly.Utils;

internal abstract partial class CancellationTokenSourcePool
{
#if NET6_0_OR_GREATER
private sealed class PooledCancellationTokenSourcePool : CancellationTokenSourcePool
{
public static readonly PooledCancellationTokenSourcePool SystemInstance = new(TimeProvider.System);

public PooledCancellationTokenSourcePool(TimeProvider timeProvider) => _timeProvider = timeProvider;

private readonly ObjectPool<CancellationTokenSource> _pool = new(
static () => new CancellationTokenSource(),
static cts => true);
private readonly TimeProvider _timeProvider;

protected override CancellationTokenSource GetCore(TimeSpan delay)
{
var source = _pool.Get();

if (IsCancellable(delay))
{
_timeProvider.CancelAfter(source, delay);
}

return source;
}

public override void Return(CancellationTokenSource source)
{
if (source.TryReset())
{
_pool.Return(source);
}
else
{
source.Dispose();
}
}
}
#endif
}
44 changes: 23 additions & 21 deletions src/Polly.Core/Utils/CancellationTokenSourcePool.cs
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
namespace Polly.Utils
namespace Polly.Utils;

#pragma warning disable CA1716 // Identifiers should not match keywords

internal abstract partial class CancellationTokenSourcePool
{
internal static class CancellationTokenSourcePool
public static CancellationTokenSourcePool Create(TimeProvider timeProvider)
{
#if NET6_0_OR_GREATER
private static readonly ObjectPool<CancellationTokenSource> Pool = new(
static () => new CancellationTokenSource(),
static cts => true);
#endif
public static CancellationTokenSource Get()
if (timeProvider == TimeProvider.System)
{
#if NET6_0_OR_GREATER
return Pool.Get();
#else
return new CancellationTokenSource();
#endif
return PooledCancellationTokenSourcePool.SystemInstance;
}
#endif
return new DisposableCancellationTokenSourcePool(timeProvider);
}

public static void Return(CancellationTokenSource source)
public CancellationTokenSource Get(TimeSpan delay)
{
if (delay <= TimeSpan.Zero && delay != System.Threading.Timeout.InfiniteTimeSpan)
{
#if NET6_0_OR_GREATER
if (source.TryReset())
{
Pool.Return(source);
return;
}
#endif
source.Dispose();
throw new ArgumentOutOfRangeException(nameof(delay), delay, "Invalid delay specified.");
}

return GetCore(delay);
}

protected abstract CancellationTokenSource GetCore(TimeSpan delay);

public abstract void Return(CancellationTokenSource source);

protected static bool IsCancellable(TimeSpan delay) => delay != System.Threading.Timeout.InfiniteTimeSpan;
}

0 comments on commit a4343b7

Please sign in to comment.