diff --git a/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs b/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs index 65f483cc9a8a..bce9c3fe827a 100644 --- a/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs +++ b/dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs @@ -1,5 +1,6 @@ // Copyright (c) Microsoft. All rights reserved. using System; +using System.ClientModel; using System.Collections.Generic; using System.Linq; using System.Net; @@ -179,9 +180,11 @@ public static async IAsyncEnumerable GetMessagesAsync(Assist // Evaluate status and process steps and messages, as encountered. HashSet processedStepIds = []; Dictionary functionSteps = []; - do { + // Check for cancellation + cancellationToken.ThrowIfCancellationRequested(); + // Poll run and steps until actionable await PollRunStatusAsync().ConfigureAwait(false); @@ -301,20 +304,49 @@ async Task PollRunStatusAsync() do { - // Reduce polling frequency after a couple attempts - await Task.Delay(agent.PollingOptions.GetPollingInterval(count), cancellationToken).ConfigureAwait(false); + cancellationToken.ThrowIfCancellationRequested(); + + if (count > 0) + { + // Reduce polling frequency after a couple attempts + await Task.Delay(agent.PollingOptions.GetPollingInterval(count), cancellationToken).ConfigureAwait(false); + } + ++count; -#pragma warning disable CA1031 // Do not catch general exception types try { run = await client.GetRunAsync(threadId, run.Id, cancellationToken).ConfigureAwait(false); } - catch + // The presence of a `Status` code means the server responded with error...always fail in that case + catch (ClientResultException clientException) when (clientException.Status <= 0) + { + // Check maximum retry count + if (count >= agent.PollingOptions.MaximumRetryCount) + { + throw; + } + + // Retry for potential transient failure + continue; + } + catch (AggregateException aggregateException) when (aggregateException.InnerException is ClientResultException innerClientException) { - // Retry anyway.. + // The presence of a `Status` code means the server responded with error + if (innerClientException.Status > 0) + { + throw; + } + + // Check maximum retry count + if (count >= agent.PollingOptions.MaximumRetryCount) + { + throw; + } + + // Retry for potential transient failure + continue; } -#pragma warning restore CA1031 // Do not catch general exception types } while (s_pollingStatuses.Contains(run.Status)); @@ -373,6 +405,9 @@ public static async IAsyncEnumerable InvokeStreamin IAsyncEnumerable asyncUpdates = client.CreateRunStreamingAsync(threadId, agent.Id, options, cancellationToken); do { + // Check for cancellation + cancellationToken.ThrowIfCancellationRequested(); + stepsToProcess.Clear(); await foreach (StreamingUpdate update in asyncUpdates.ConfigureAwait(false)) diff --git a/dotnet/src/Agents/OpenAI/RunPollingOptions.cs b/dotnet/src/Agents/OpenAI/RunPollingOptions.cs index 756ba689131c..b108048f32d3 100644 --- a/dotnet/src/Agents/OpenAI/RunPollingOptions.cs +++ b/dotnet/src/Agents/OpenAI/RunPollingOptions.cs @@ -8,6 +8,11 @@ namespace Microsoft.SemanticKernel.Agents.OpenAI; /// public sealed class RunPollingOptions { + /// + /// The default maximum number or retries when monitoring thread-run status. + /// + public static int DefaultMaximumRetryCount { get; } = 3; + /// /// The default polling interval when monitoring thread-run status. /// @@ -28,6 +33,15 @@ public sealed class RunPollingOptions /// public static TimeSpan DefaultMessageSynchronizationDelay { get; } = TimeSpan.FromMilliseconds(500); + /// + /// The maximum retry count when polling thread-run status. + /// + /// + /// Only affects failures that have the potential to be transient. Explicit server error responses + /// will result in immediate failure. + /// + public int MaximumRetryCount { get; set; } = DefaultMaximumRetryCount; + /// /// The polling interval when monitoring thread-run status. ///