Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add function to Service to allow streaming the results back from a… #73

Merged
merged 8 commits into from
Jan 19, 2023
7 changes: 4 additions & 3 deletions OpenAI.Playground/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@
//await ImageTestHelper.RunSimpleCreateImageEditTest(sdk);
//await ImageTestHelper.RunSimpleCreateImageVariationTest(sdk);
//await ModerationTestHelper.CreateModerationTest(sdk);
await CompletionTestHelper.RunSimpleCompletionTest(sdk);
await CompletionTestHelper.RunSimpleCompletionTest2(sdk);
await CompletionTestHelper.RunSimpleCompletionTest3(sdk);
//await CompletionTestHelper.RunSimpleCompletionTest(sdk);
//await CompletionTestHelper.RunSimpleCompletionTest2(sdk);
//await CompletionTestHelper.RunSimpleCompletionTest3(sdk);
await CompletionTestHelper.RunSimpleCompletionStreamTest(sdk);
//await EmbeddingTestHelper.RunSimpleEmbeddingTest(sdk);
//////await FileTestHelper.RunSimpleFileTest(sdk); //will delete files
//////await FineTuningTestHelper.CleanUpAllFineTunings(sdk); //!!!!! will delete all fine-tunings
Expand Down
40 changes: 40 additions & 0 deletions OpenAI.Playground/TestHelpers/CompletionTestHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,5 +122,45 @@ public static async Task RunSimpleCompletionTest3(IOpenAIService sdk)
throw;
}
}

public static async Task RunSimpleCompletionStreamTest(IOpenAIService sdk)
{
    ConsoleExtensions.WriteLine("Completion Stream Testing is starting:", ConsoleColor.Cyan);

    try
    {
        ConsoleExtensions.WriteLine("Completion Stream Test:", ConsoleColor.DarkCyan);

        // Kick off a streamed completion; tokens arrive incrementally instead of as one response.
        var request = new CompletionCreateRequest()
        {
            Prompt = "Once upon a time",
            MaxTokens = 50
        };
        var completionStream = sdk.Completions.CreateCompletionAsStream(request, Models.Davinci);

        await foreach (var completion in completionStream)
        {
            if (!completion.Successful)
            {
                // A failed block without an error payload gives us nothing to report.
                if (completion.Error == null)
                {
                    throw new Exception("Unknown Error");
                }

                Console.WriteLine($"{completion.Error.Code}: {completion.Error.Message}");
                continue;
            }

            // Print each partial completion as it streams in.
            Console.Write(completion.Choices.FirstOrDefault()?.Text);
        }

        Console.WriteLine("");
        Console.WriteLine("Complete");
    }
    catch (Exception e)
    {
        Console.WriteLine(e);
        throw;
    }
}
}
}
19 changes: 18 additions & 1 deletion OpenAI.SDK/Extensions/HttpclientExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Net.Http.Json;
using System.Net.Http.Headers;
using System.Net.Http.Json;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading;

Expand All @@ -21,6 +22,22 @@ public static async Task<TResponse> PostAndReadAsAsync<TResponse>(this HttpClien
return await response.Content.ReadFromJsonAsync<TResponse>() ?? throw new InvalidOperationException();
}

/// <summary>
///     Sends a POST request with a JSON body and returns as soon as the response headers have been
///     read, so the caller can consume the response body incrementally (used for
///     "text/event-stream" responses). NOTE(review): despite the "Async" suffix this method sends
///     synchronously via <see cref="HttpClient.Send(HttpRequestMessage, HttpCompletionOption, CancellationToken)" />.
/// </summary>
/// <param name="client">The client used to send the request</param>
/// <param name="uri">The endpoint to post to</param>
/// <param name="requestModel">The object serialized as the JSON request body</param>
/// <param name="cancellationToken">Optional token to cancel the send</param>
/// <returns>The response whose content stream has not been read yet</returns>
public static HttpResponseMessage PostAsStreamAsync(this HttpClient client, string uri, object requestModel, CancellationToken cancellationToken = default)
{
    // Omit default-valued properties so optional request fields are left out of the payload.
    var settings = new JsonSerializerOptions()
    {
        DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingDefault
    };

    var content = JsonContent.Create(requestModel, null, settings);

    using var request = new HttpRequestMessage(HttpMethod.Post, uri);
    // Ask the server for a Server-Sent Events stream rather than a single JSON document.
    request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream"));
    request.Content = content;

    // ResponseHeadersRead makes Send return before the body arrives, enabling streamed reads.
    return client.Send(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
}

public static async Task<TResponse> PostFileAndReadAsAsync<TResponse>(this HttpClient client, string uri, HttpContent content)
{
var response = await client.PostAsync(uri, content);
Expand Down
20 changes: 20 additions & 0 deletions OpenAI.SDK/Extensions/StringExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
namespace OpenAI.GPT3.Extensions
{
    /// <summary>
    ///     Extension methods for string manipulation
    /// </summary>
    public static class StringExtensions
    {
        /// <summary>
        ///     Removes <paramref name="search" /> from the beginning of <paramref name="text" /> if it is present there.
        /// </summary>
        /// <param name="text">The string to trim</param>
        /// <param name="search">The prefix to remove</param>
        /// <returns>
        ///     <paramref name="text" /> without the leading <paramref name="search" />, or the original string when it
        ///     does not start with <paramref name="search" />.
        /// </returns>
        public static string RemoveIfStartWith(this string text, string search)
        {
            // StartsWith is the direct prefix check; the previous IndexOf-based form scanned the
            // whole string when the prefix was absent. Ordinal comparison: this is protocol text
            // (e.g. the SSE "data: " prefix), not natural language.
            return text.StartsWith(search, StringComparison.Ordinal) ? text[search.Length..] : text;
        }
    }
}
18 changes: 13 additions & 5 deletions OpenAI.SDK/Interfaces/ICompletionService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,27 @@ public interface ICompletionService
/// <summary>
/// Creates a new completion for the provided prompt and parameters
/// </summary>
/// <param name="engineId">The ID of the engine to use for this request</param>
/// <param name="modelId">The ID of the engine to use for this request</param>
/// <param name="createCompletionModel"></param>
/// <returns></returns>
Task<CompletionCreateResponse> CreateCompletion(CompletionCreateRequest createCompletionModel, string? engineId = null);
Task<CompletionCreateResponse> CreateCompletion(CompletionCreateRequest createCompletionModel, string? modelId = null);

/// <summary>
/// Creates a new completion for the provided prompt and parameters and returns a stream of CompletionCreateResponses
/// </summary>
/// <param name="modelId">The ID of the model to use for this request</param>
/// <param name="createCompletionModel"></param>
/// <returns></returns>
IAsyncEnumerable<CompletionCreateResponse> CreateCompletionAsStream(CompletionCreateRequest createCompletionModel, string? modelId = null);

/// <summary>
/// Creates a new completion for the provided prompt and parameters
/// </summary>
/// <param name="createCompletionModel"></param>
/// <param name="engineId">The ID of the engine to use for this request</param>
/// <param name="modelId">The ID of the engine to use for this request</param>
/// <returns></returns>
Task<CompletionCreateResponse> Create(CompletionCreateRequest createCompletionModel, Models.Model engineId)
Task<CompletionCreateResponse> Create(CompletionCreateRequest createCompletionModel, Models.Model modelId)
{
return CreateCompletion(createCompletionModel, engineId.EnumToString());
return CreateCompletion(createCompletionModel, modelId.EnumToString());
}
}
48 changes: 46 additions & 2 deletions OpenAI.SDK/Managers/OpenAICompletions.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,60 @@
using OpenAI.GPT3.Extensions;
using System.Text.Json;
using OpenAI.GPT3.Extensions;
using OpenAI.GPT3.Interfaces;
using OpenAI.GPT3.ObjectModels.RequestModels;
using OpenAI.GPT3.ObjectModels.ResponseModels;

namespace OpenAI.GPT3.Managers;


public partial class OpenAIService : ICompletionService
{
    /// <inheritdoc />
    public async Task<CompletionCreateResponse> CreateCompletion(CompletionCreateRequest createCompletionRequest, string? modelId = null)
    {
        // Fall back to the service-wide default model when the caller does not supply one.
        createCompletionRequest.ProcessModelId(modelId, _defaultModelId);

        return await _httpClient.PostAndReadAsAsync<CompletionCreateResponse>(_endpointProvider.CompletionCreate(), createCompletionRequest);
    }

    /// <inheritdoc />
    public async IAsyncEnumerable<CompletionCreateResponse> CreateCompletionAsStream(CompletionCreateRequest createCompletionRequest, string? modelId = null)
    {
        // Mark the request as streaming so the API answers with Server-Sent Events.
        createCompletionRequest.Stream = true;

        // Fall back to the service-wide default model when the caller does not supply one.
        createCompletionRequest.ProcessModelId(modelId, _defaultModelId);

        using var response = _httpClient.PostAsStreamAsync(_endpointProvider.CompletionCreate(), createCompletionRequest);
        await using var stream = await response.Content.ReadAsStreamAsync();
        using var reader = new StreamReader(stream);

        // Read the event stream line by line until it ends.
        while (!reader.EndOfStream)
        {
            var line = await reader.ReadLineAsync();
            // SSE separates events with blank lines; skip them.
            if (string.IsNullOrEmpty(line)) continue;

            line = line.RemoveIfStartWith("data: ");

            // The server signals the end of the stream with a literal "[DONE]" sentinel.
            // Ordinal comparison: protocol text, not natural language (CA1310).
            if (line.StartsWith("[DONE]", StringComparison.Ordinal)) break;

            CompletionCreateResponse? block;
            try
            {
                // When the response is good, each data line is one complete serialized CompletionCreateResponse.
                block = JsonSerializer.Deserialize<CompletionCreateResponse>(line);
            }
            catch (Exception)
            {
                // When the API returns an error, it does not come back as a block, it returns a single character of text ("{").
                // In this instance, read through the rest of the response, which should be a complete object to parse.
                line += await reader.ReadToEndAsync();
                block = JsonSerializer.Deserialize<CompletionCreateResponse>(line);
            }

            if (null != block) yield return block;
        }
    }
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations;
using System.Text.Json.Serialization;
using OpenAI.GPT3.Interfaces;
using OpenAI.GPT3.ObjectModels.SharedModels;

namespace OpenAI.GPT3.ObjectModels.RequestModels
{
//TODO add model validation
//TODO check what is string or array for prompt,..
/// <summary>
/// Create Completion Request Model
/// </summary>
Expand Down