Skip to content

Commit

Permalink
Merge branch 'main' into update-azure-openai-packages
Browse files Browse the repository at this point in the history
  • Loading branch information
dmytrostruk authored Nov 11, 2024
2 parents 9d5c141 + 83a59d4 commit 0c384a9
Showing 1 changed file with 9 additions and 18 deletions.
27 changes: 9 additions & 18 deletions dotnet/src/Agents/OpenAI/Internal/AssistantThreadActions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public static async IAsyncEnumerable<ChatMessageContent> GetMessagesAsync(Assist
throw new KernelException($"Agent Failure - Run terminated: {run.Status} [{run.Id}]: {run.LastError?.Message ?? "Unknown"}");
}

IReadOnlyList<RunStep> steps = await GetRunStepsAsync(client, run, cancellationToken).ConfigureAwait(false);
RunStep[] steps = await client.GetRunStepsAsync(run.ThreadId, run.Id, cancellationToken: cancellationToken).ToArrayAsync(cancellationToken).ConfigureAwait(false);

// Is tool action required?
if (run.Status == RunStatus.RequiresAction)
Expand Down Expand Up @@ -475,11 +475,14 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin

if (run.Status == RunStatus.RequiresAction)
{
IReadOnlyList<RunStep> steps = await GetRunStepsAsync(client, run, cancellationToken).ConfigureAwait(false);
RunStep[] activeSteps =
await client.GetRunStepsAsync(run.ThreadId, run.Id, cancellationToken: cancellationToken)
.Where(step => step.Status == RunStepStatus.InProgress)
.ToArrayAsync(cancellationToken).ConfigureAwait(false);

// Capture map between the tool call and its associated step
Dictionary<string, string> toolMap = [];
foreach (RunStep step in steps)
foreach (RunStep step in activeSteps)
{
foreach (RunStepToolCall stepDetails in step.Details.ToolCalls)
{
Expand All @@ -488,7 +491,7 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
}

// Execute functions in parallel and post results at once.
FunctionCallContent[] functionCalls = steps.SelectMany(step => ParseFunctionStep(agent, step)).ToArray();
FunctionCallContent[] functionCalls = activeSteps.SelectMany(step => ParseFunctionStep(agent, step)).ToArray();
if (functionCalls.Length > 0)
{
// Emit function-call content
Expand All @@ -504,7 +507,7 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamin
ToolOutput[] toolOutputs = GenerateToolOutputs(functionResults);
asyncUpdates = client.SubmitToolOutputsToRunStreamingAsync(run.ThreadId, run.Id, toolOutputs, cancellationToken);

foreach (RunStep step in steps)
foreach (RunStep step in activeSteps)
{
stepFunctionResults.Add(step.Id, functionResults.Where(result => step.Id == toolMap[result.CallId!]).ToArray());
}
Expand Down Expand Up @@ -560,18 +563,6 @@ await RetrieveMessageAsync(
logger.LogOpenAIAssistantCompletedRun(nameof(InvokeAsync), run?.Id ?? "Failed", threadId);
}

private static async Task<IReadOnlyList<RunStep>> GetRunStepsAsync(AssistantClient client, ThreadRun run, CancellationToken cancellationToken)
{
List<RunStep> steps = [];

await foreach (RunStep step in client.GetRunStepsAsync(run.ThreadId, run.Id, cancellationToken: cancellationToken).ConfigureAwait(false))
{
steps.Add(step);
}

return steps;
}

private static ChatMessageContent GenerateMessageContent(string? assistantName, ThreadMessage message, RunStep? completedStep = null)
{
AuthorRole role = new(message.Role.ToString());
Expand Down Expand Up @@ -788,7 +779,7 @@ private static ChatMessageContent GenerateFunctionCallContent(string agentName,
return functionCallContent;
}

private static ChatMessageContent GenerateFunctionResultContent(string agentName, FunctionResultContent[] functionResults, RunStep completedStep)
private static ChatMessageContent GenerateFunctionResultContent(string agentName, IEnumerable<FunctionResultContent> functionResults, RunStep completedStep)
{
ChatMessageContent functionResultContent = new(AuthorRole.Tool, content: null)
{
Expand Down

0 comments on commit 0c384a9

Please sign in to comment.