Improved Assistant streaming methods #600

Merged
merged 1 commit into from Jun 15, 2024
137 changes: 127 additions & 10 deletions OpenAI.Playground/TestHelpers/AssistantHelpers/RunTestHelper.cs
@@ -165,19 +165,58 @@ public static async Task CreateRunAsStreamTest(IOpenAIService openAI)
var result = openAI.Beta.Runs.RunCreateAsStream(CreatedThreadId, new()
{
AssistantId = assistantResult.Id
});
}, justDataMode: false);

await foreach (var run in result)
{
if (run.Successful)
{
if (string.IsNullOrEmpty(run.Status))
Console.WriteLine($"Event:{run.StreamEvent}");
if (run is RunResponse runResponse)
{
Console.Write(".");
if (string.IsNullOrEmpty(runResponse.Status))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Run Id: {runResponse.Id}, Status: {runResponse.Status}");
}
}

else if (run is RunStepResponse runStepResponse)
{
if (string.IsNullOrEmpty(runStepResponse.Status))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Run Step Id: {runStepResponse.Id}, Status: {runStepResponse.Status}");
}
}

else if (run is MessageResponse messageResponse)
{
if (string.IsNullOrEmpty(messageResponse.Id))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Message Id: {messageResponse.Id}, Message: {messageResponse.Content?.FirstOrDefault()?.Text?.Value}");
}
}
else
{
ConsoleExtensions.WriteLine($"Run Id: {run.Id}, Status: {run.Status}");
if (run.StreamEvent != null)
{
Console.WriteLine(run.StreamEvent);
}
else
{
Console.Write(".");
}
}
}
else
@@ -450,13 +489,52 @@ public static async Task SubmitToolOutputsAsStreamToRunTest(IOpenAIService openA
{
if (run.Successful)
{
if (string.IsNullOrEmpty(run.Status))
Console.WriteLine($"Event:{run.StreamEvent}");
if (run is RunResponse runResponse)
{
Console.Write(".");
if (string.IsNullOrEmpty(runResponse.Status))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Run Id: {runResponse.Id}, Status: {runResponse.Status}");
}
}

else if (run is RunStepResponse runStepResponse)
{
if (string.IsNullOrEmpty(runStepResponse.Status))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Run Step Id: {runStepResponse.Id}, Status: {runStepResponse.Status}");
}
}

else if (run is MessageResponse messageResponse)
{
if (string.IsNullOrEmpty(messageResponse.Id))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Message Id: {messageResponse.Id}, Message: {messageResponse.Content?.FirstOrDefault()?.Text?.Value}");
}
}
else
{
ConsoleExtensions.WriteLine($"Run Id: {run.Id}, Status: {run.Status}");
if (run.StreamEvent != null)
{
Console.WriteLine(run.StreamEvent);
}
else
{
Console.Write(".");
}
}
}
else
@@ -642,13 +720,52 @@ public static async Task CreateThreadAndRunAsStream(IOpenAIService sdk)
{
if (run.Successful)
{
if (string.IsNullOrEmpty(run.Status))
Console.WriteLine($"Event:{run.StreamEvent}");
if (run is RunResponse runResponse)
{
Console.Write(".");
if (string.IsNullOrEmpty(runResponse.Status))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Run Id: {runResponse.Id}, Status: {runResponse.Status}");
}
}

else if (run is RunStepResponse runStepResponse)
{
if (string.IsNullOrEmpty(runStepResponse.Status))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Run Step Id: {runStepResponse.Id}, Status: {runStepResponse.Status}");
}
}

else if (run is MessageResponse messageResponse)
{
if (string.IsNullOrEmpty(messageResponse.Id))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Message Id: {messageResponse.Id}, Message: {messageResponse.Content?.FirstOrDefault()?.Text?.Value}");
}
}
else
{
ConsoleExtensions.WriteLine($"Run Id: {run.Id}, Status: {run.Status}");
if (run.StreamEvent != null)
{
Console.WriteLine(run.StreamEvent);
}
else
{
Console.Write(".");
}
}
}
else
22 changes: 22 additions & 0 deletions OpenAI.SDK/Extensions/JsonToObjectRouterExtension.cs
@@ -0,0 +1,22 @@
using System.Text.Json;
using OpenAI.ObjectModels.ResponseModels;
using OpenAI.ObjectModels.SharedModels;

namespace OpenAI.Extensions;

public static class JsonToObjectRouterExtension
{
public static Type Route(string json)
{
var apiResponse = JsonSerializer.Deserialize<ObjectBaseResponse>(json);

return apiResponse?.ObjectTypeName switch
{
"thread.run.step" => typeof(RunStepResponse),
"thread.run" => typeof(RunResponse),
"thread.message" => typeof(MessageResponse),
"thread.message.delta" => typeof(MessageResponse),
_ => typeof(BaseResponse)
};
}
}
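
The router inspects only the object-type discriminator of a payload (ObjectTypeName, mapped from the JSON "object" field) and returns the CLR type the stream handler should deserialize into. A minimal usage sketch, not part of the diff; the JSON literal is invented for illustration and the usings mirror the ones in the file above:

using System.Text.Json;
using OpenAI.Extensions;
using OpenAI.ObjectModels.ResponseModels;
using OpenAI.ObjectModels.SharedModels;

// Invented payload; a real line would come from the "data:" field of the SSE stream.
var json = """{"id":"step_abc","object":"thread.run.step","status":"in_progress"}""";

// Route() picks the concrete response type from the "object" discriminator.
var targetType = JsonToObjectRouterExtension.Route(json);

// The stream handler then re-deserializes the same payload into that type.
var response = JsonSerializer.Deserialize(json, targetType) as BaseResponse;
Console.WriteLine($"{targetType.Name}: {response?.ObjectTypeName}"); // e.g. "RunStepResponse: thread.run.step"
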
39 changes: 36 additions & 3 deletions OpenAI.SDK/Extensions/StreamHandleExtension.cs
@@ -1,4 +1,5 @@
using System.Runtime.CompilerServices;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text.Json;
using OpenAI.ObjectModels;
using OpenAI.ObjectModels.RequestModels;
@@ -8,6 +9,10 @@ namespace OpenAI.Extensions;

public static class StreamHandleExtension
{
public static async IAsyncEnumerable<BaseResponse> AsStream(this HttpResponseMessage response, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
await foreach (var baseResponse in AsStream<BaseResponse>(response, justDataMode, cancellationToken)) yield return baseResponse;
}
public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpResponseMessage response, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default)
where TResponse : BaseResponse, new()
{
@@ -20,13 +25,15 @@ public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpRes

await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken);
using var reader = new StreamReader(stream);

string? tempStreamEvent = null;
bool isEventDelta;
// Continuously read the stream until the end of it
while (true)
{
cancellationToken.ThrowIfCancellationRequested();

var line = await reader.ReadLineAsync();
// Console.WriteLine("---" + line);
// Break the loop if we have reached the end of the stream
if (line == null)
{
@@ -39,11 +46,28 @@ public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpRes
continue;
}

if (line.StartsWith("event: "))
{
line = line.RemoveIfStartWith("event: ");
tempStreamEvent = line;
isEventDelta = true;
}
else
{
isEventDelta = false;
}

if (justDataMode && !line.StartsWith("data: "))
{
continue;
}

if (!justDataMode && isEventDelta)
{
yield return new() { ObjectTypeName = "base.stream.event", StreamEvent = tempStreamEvent };
continue;
}

line = line.RemoveIfStartWith("data: ");

// Exit the loop if the stream is done
@@ -56,7 +80,14 @@ public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpRes
try
{
// When the response is good, each line is a serializable CompletionCreateRequest
block = JsonSerializer.Deserialize<TResponse>(line);
if (typeof(TResponse) == typeof(BaseResponse))
{
block = JsonSerializer.Deserialize(line, JsonToObjectRouterExtension.Route(line), new JsonSerializerOptions()) as TResponse;
}
else
{
block = JsonSerializer.Deserialize<TResponse>(line);
}
}
catch (Exception)
{
@@ -78,6 +109,8 @@ public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpRes
{
block.HttpStatusCode = httpStatusCode;
block.HeaderValues = headerValues;
block.StreamEvent = tempStreamEvent;
tempStreamEvent = null;
yield return block;
}
}
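
With the new non-generic AsStream overload, passing justDataMode: false surfaces the bare "event: ..." lines as sentinel items (ObjectTypeName of "base.stream.event") in between the routed "data: ..." payloads, and each data payload also carries the most recent event name in StreamEvent. A minimal consumption sketch, assuming an HttpResponseMessage named httpResponse whose body is an Assistants SSE stream; not part of the diff:

using OpenAI.Extensions;
using OpenAI.ObjectModels.ResponseModels;

// "httpResponse" is assumed to be obtained elsewhere; error handling is omitted.
await foreach (BaseResponse item in httpResponse.AsStream(justDataMode: false))
{
    if (item.ObjectTypeName == "base.stream.event")
    {
        // An "event: ..." line yielded on its own, ahead of its data payload.
        Console.WriteLine($"event: {item.StreamEvent}");
        continue;
    }

    // "data: ..." payloads are routed by JsonToObjectRouterExtension.Route and
    // arrive as RunResponse, RunStepResponse, MessageResponse or plain BaseResponse.
    Console.WriteLine($"{item.StreamEvent}: {item.GetType().Name}");
}
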
13 changes: 8 additions & 5 deletions OpenAI.SDK/Interfaces/IRunService.cs
@@ -1,5 +1,6 @@
using System.Runtime.CompilerServices;
using OpenAI.ObjectModels.RequestModels;
using OpenAI.ObjectModels.ResponseModels;
using OpenAI.ObjectModels.SharedModels;

namespace OpenAI.Interfaces;
@@ -24,8 +25,8 @@ public interface IRunService
/// <param name="modelId"></param>
/// <param name="justDataMode"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
IAsyncEnumerable<RunResponse> RunCreateAsStream(string threadId, RunCreateRequest request, string? modelId = null, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);
/// <returns>A stream of <see cref="BaseResponse"/> items; concrete elements may be <see cref="RunResponse"/>, <see cref="RunStepResponse"/>, or <see cref="MessageResponse"/>.</returns>
IAsyncEnumerable<BaseResponse> RunCreateAsStream(string threadId, RunCreateRequest request, string? modelId = null, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);

/// <summary>
/// Retrieves a run.
@@ -71,9 +72,10 @@ public interface IRunService
/// <param name="threadId"></param>
/// <param name="runId"></param>
/// <param name="request"></param>
/// <param name="justDataMode"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
IAsyncEnumerable<RunResponse> RunSubmitToolOutputsAsStream(string threadId, string runId, SubmitToolOutputsToRunRequest request, [EnumeratorCancellation] CancellationToken cancellationToken = default);
/// <returns>A stream of <see cref="BaseResponse"/> items; concrete elements may be <see cref="RunResponse"/>, <see cref="RunStepResponse"/>, or <see cref="MessageResponse"/>.</returns>
IAsyncEnumerable<BaseResponse> RunSubmitToolOutputsAsStream(string threadId, string runId, SubmitToolOutputsToRunRequest request, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);

/// <summary>
/// Modifies a run.
@@ -93,7 +95,8 @@ public interface IRunService
/// <summary>
/// Create a thread and run it in one request as Stream.
/// </summary>
IAsyncEnumerable<RunResponse> CreateThreadAndRunAsStream(CreateThreadAndRunRequest createThreadAndRunRequest, string? modelId = null, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);
/// <returns>A stream of <see cref="BaseResponse"/> items; concrete elements may be <see cref="RunResponse"/>, <see cref="RunStepResponse"/>, or <see cref="MessageResponse"/>.</returns>
IAsyncEnumerable<BaseResponse> CreateThreadAndRunAsStream(CreateThreadAndRunRequest createThreadAndRunRequest, string? modelId = null, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);

/// <summary>
/// Returns a list of runs belonging to a thread.
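
Because the streaming members now return IAsyncEnumerable<BaseResponse> rather than IAsyncEnumerable<RunResponse>, call sites pattern match to recover the concrete payloads, as the playground helpers above do. A condensed call-site sketch, assuming an initialized IOpenAIService (openAI) plus existing threadId and assistantId values; not part of the diff:

using System.Linq;
using OpenAI.Interfaces;
using OpenAI.ObjectModels.RequestModels;
using OpenAI.ObjectModels.ResponseModels;
using OpenAI.ObjectModels.SharedModels;

// justDataMode defaults to true (only "data:" payloads are yielded); pass false
// to also receive the bare "event: ..." markers as BaseResponse items.
var stream = openAI.Beta.Runs.RunCreateAsStream(threadId, new RunCreateRequest
{
    AssistantId = assistantId
}, justDataMode: false);

await foreach (var item in stream)
{
    var summary = item switch
    {
        RunResponse run => $"run {run.Id}: {run.Status}",
        RunStepResponse step => $"step {step.Id}: {step.Status}",
        MessageResponse message => message.Content?.FirstOrDefault()?.Text?.Value ?? "(message delta)",
        _ => $"event: {item.StreamEvent}"
    };
    Console.WriteLine(summary);
}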