diff --git a/CHANGELOG.md b/CHANGELOG.md index 08113ab1..2697eb9f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Release History +## 2.0.0-beta.11 (Unreleased) + +### Features Added + +### Breaking Changes + +- Updated fine-tuning pagination methods `GetJobs`, `GetEvents`, and `GetJobCheckpoints` to return `IEnumerable` instead of `ClientResult`. (commit_hash) +- Updated the batching pagination method `GetBatches` to return `IEnumerable` instead of `ClientResult`. (commit_hash) + +### Bugs Fixed + +### Other Changes + ## 2.0.0-beta.10 (2024-08-26) ### Breaking Changes diff --git a/api/OpenAI.netstandard2.0.cs b/api/OpenAI.netstandard2.0.cs index 1a843e58..9db98764 100644 --- a/api/OpenAI.netstandard2.0.cs +++ b/api/OpenAI.netstandard2.0.cs @@ -1170,8 +1170,8 @@ public class BatchClient { public virtual Task CreateBatchAsync(BinaryContent content, RequestOptions options = null); public virtual ClientResult GetBatch(string batchId, RequestOptions options); public virtual Task GetBatchAsync(string batchId, RequestOptions options); - public virtual ClientResult GetBatches(string after, int? limit, RequestOptions options); - public virtual Task GetBatchesAsync(string after, int? limit, RequestOptions options); + public virtual IEnumerable GetBatches(string after, int? limit, RequestOptions options); + public virtual IAsyncEnumerable GetBatchesAsync(string after, int? limit, RequestOptions options); } } namespace OpenAI.Chat { @@ -1792,12 +1792,12 @@ public class FineTuningClient { public virtual Task CreateJobAsync(BinaryContent content, RequestOptions options = null); public virtual ClientResult GetJob(string jobId, RequestOptions options); public virtual Task GetJobAsync(string jobId, RequestOptions options); - public virtual ClientResult GetJobCheckpoints(string fineTuningJobId, string after, int? limit, RequestOptions options); - public virtual Task GetJobCheckpointsAsync(string fineTuningJobId, string after, int? limit, RequestOptions options); - public virtual ClientResult GetJobEvents(string jobId, string after, int? limit, RequestOptions options); - public virtual Task GetJobEventsAsync(string jobId, string after, int? limit, RequestOptions options); - public virtual ClientResult GetJobs(string after, int? limit, RequestOptions options); - public virtual Task GetJobsAsync(string after, int? limit, RequestOptions options); + public virtual IEnumerable GetJobCheckpoints(string jobId, string after, int? limit, RequestOptions options); + public virtual IAsyncEnumerable GetJobCheckpointsAsync(string jobId, string after, int? limit, RequestOptions options); + public virtual IEnumerable GetJobEvents(string jobId, string after, int? limit, RequestOptions options); + public virtual IAsyncEnumerable GetJobEventsAsync(string jobId, string after, int? limit, RequestOptions options); + public virtual IEnumerable GetJobs(string after, int? limit, RequestOptions options); + public virtual IAsyncEnumerable GetJobsAsync(string after, int? limit, RequestOptions options); } } namespace OpenAI.Images { diff --git a/src/Custom/Batch/BatchClient.Protocol.cs b/src/Custom/Batch/BatchClient.Protocol.cs index 525dce1a..1505718a 100644 --- a/src/Custom/Batch/BatchClient.Protocol.cs +++ b/src/Custom/Batch/BatchClient.Protocol.cs @@ -1,6 +1,7 @@ using System; using System.ClientModel; using System.ClientModel.Primitives; +using System.Collections.Generic; using System.Threading.Tasks; namespace OpenAI.Batch; @@ -49,10 +50,10 @@ public virtual ClientResult CreateBatch(BinaryContent content, RequestOptions op /// The request options, which can override default behaviors of the client pipeline on a per-call basis. /// Service returned a non-success status code. /// The response returned from the service. - public virtual async Task GetBatchesAsync(string after, int? limit, RequestOptions options) + public virtual IAsyncEnumerable GetBatchesAsync(string after, int? limit, RequestOptions options) { - using PipelineMessage message = CreateGetBatchesRequest(after, limit, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + BatchesPageEnumerator enumerator = new BatchesPageEnumerator(_pipeline, _endpoint, after, limit, options); + return PageCollectionHelpers.CreateAsync(enumerator); } /// @@ -63,10 +64,10 @@ public virtual async Task GetBatchesAsync(string after, int? limit /// The request options, which can override default behaviors of the client pipeline on a per-call basis. /// Service returned a non-success status code. /// The response returned from the service. - public virtual ClientResult GetBatches(string after, int? limit, RequestOptions options) + public virtual IEnumerable GetBatches(string after, int? limit, RequestOptions options) { - using PipelineMessage message = CreateGetBatchesRequest(after, limit, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + BatchesPageEnumerator enumerator = new BatchesPageEnumerator(_pipeline, _endpoint, after, limit, options); + return PageCollectionHelpers.Create(enumerator); } /// diff --git a/src/Custom/Batch/Internal/Pagination/BatchesPageEnumerator.cs b/src/Custom/Batch/Internal/Pagination/BatchesPageEnumerator.cs new file mode 100644 index 00000000..f8055d3a --- /dev/null +++ b/src/Custom/Batch/Internal/Pagination/BatchesPageEnumerator.cs @@ -0,0 +1,108 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Text.Json; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.Batch; + +internal partial class BatchesPageEnumerator : PageResultEnumerator +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + + private readonly int? _limit; + private readonly RequestOptions _options; + + private string _after; + + public BatchesPageEnumerator( + ClientPipeline pipeline, + Uri endpoint, + string after, int? limit, + RequestOptions options) + { + _pipeline = pipeline; + _endpoint = endpoint; + + _after = after; + _limit = limit; + _options = options; + } + + public override async Task GetFirstAsync() + => await GetBatchesAsync(_after, _limit, _options).ConfigureAwait(false); + + public override ClientResult GetFirst() + => GetBatches(_after, _limit, _options); + + public override async Task GetNextAsync(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + _after = doc.RootElement.GetProperty("last_id"u8).GetString()!; + + return await GetBatchesAsync(_after, _limit, _options).ConfigureAwait(false); + } + + public override ClientResult GetNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + _after = doc.RootElement.GetProperty("last_id"u8).GetString()!; + + return GetBatches(_after, _limit, _options); + } + + public override bool HasNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + bool hasMore = doc.RootElement.GetProperty("has_more"u8).GetBoolean(); + + return hasMore; + } + + internal virtual async Task GetBatchesAsync(string after, int? limit, RequestOptions options) + { + using PipelineMessage message = CreateGetBatchesRequest(after, limit, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + internal virtual ClientResult GetBatches(string after, int? limit, RequestOptions options) + { + using PipelineMessage message = CreateGetBatchesRequest(after, limit, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + internal PipelineMessage CreateGetBatchesRequest(string after, int? limit, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/v1/batches", false); + if (after != null) + { + uri.AppendQuery("after", after, true); + } + if (limit != null) + { + uri.AppendQuery("limit", limit.Value, true); + } + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + private static PipelineMessageClassifier? _pipelineMessageClassifier200; + private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); +} diff --git a/src/Custom/FineTuning/FineTuningClient.Protocol.cs b/src/Custom/FineTuning/FineTuningClient.Protocol.cs index 2483478d..bccc9516 100644 --- a/src/Custom/FineTuning/FineTuningClient.Protocol.cs +++ b/src/Custom/FineTuning/FineTuningClient.Protocol.cs @@ -1,6 +1,7 @@ using System; using System.ClientModel; using System.ClientModel.Primitives; +using System.Collections.Generic; using System.Threading.Tasks; namespace OpenAI.FineTuning; @@ -76,10 +77,10 @@ public virtual ClientResult CreateJob(BinaryContent content, RequestOptions opti /// The request options, which can override default behaviors of the client pipeline on a per-call basis. /// Service returned a non-success status code. /// The response returned from the service. - public virtual async Task GetJobsAsync(string after, int? limit, RequestOptions options) + public virtual IAsyncEnumerable GetJobsAsync(string after, int? limit, RequestOptions options) { - using PipelineMessage message = CreateGetPaginatedFineTuningJobsRequest(after, limit, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + FineTuningJobsPageEnumerator enumerator = new FineTuningJobsPageEnumerator(_pipeline, _endpoint, after, limit, options); + return PageCollectionHelpers.CreateAsync(enumerator); } // CUSTOM: @@ -93,10 +94,10 @@ public virtual async Task GetJobsAsync(string after, int? limit, R /// The request options, which can override default behaviors of the client pipeline on a per-call basis. /// Service returned a non-success status code. /// The response returned from the service. - public virtual ClientResult GetJobs(string after, int? limit, RequestOptions options) + public virtual IEnumerable GetJobs(string after, int? limit, RequestOptions options) { - using PipelineMessage message = CreateGetPaginatedFineTuningJobsRequest(after, limit, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + FineTuningJobsPageEnumerator enumerator = new FineTuningJobsPageEnumerator(_pipeline, _endpoint, after, limit, options); + return PageCollectionHelpers.Create(enumerator); } // CUSTOM: @@ -197,12 +198,12 @@ public virtual ClientResult CancelJob(string jobId, RequestOptions options) /// is an empty string, and was expected to be non-empty. /// Service returned a non-success status code. /// The response returned from the service. - public virtual async Task GetJobEventsAsync(string jobId, string after, int? limit, RequestOptions options) + public virtual IAsyncEnumerable GetJobEventsAsync(string jobId, string after, int? limit, RequestOptions options) { Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); - using PipelineMessage message = CreateGetFineTuningEventsRequest(jobId, after, limit, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + FineTuningJobEventsPageEnumerator enumerator = new FineTuningJobEventsPageEnumerator(_pipeline, _endpoint, jobId, after, limit, options); + return PageCollectionHelpers.CreateAsync(enumerator); } // CUSTOM: @@ -219,49 +220,49 @@ public virtual async Task GetJobEventsAsync(string jobId, string a /// is an empty string, and was expected to be non-empty. /// Service returned a non-success status code. /// The response returned from the service. - public virtual ClientResult GetJobEvents(string jobId, string after, int? limit, RequestOptions options) + public virtual IEnumerable GetJobEvents(string jobId, string after, int? limit, RequestOptions options) { Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); - using PipelineMessage message = CreateGetFineTuningEventsRequest(jobId, after, limit, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + FineTuningJobEventsPageEnumerator enumerator = new FineTuningJobEventsPageEnumerator(_pipeline, _endpoint, jobId, after, limit, options); + return PageCollectionHelpers.Create(enumerator); } /// /// [Protocol Method] List the checkpoints for a fine-tuning job. /// - /// The ID of the fine-tuning job to get checkpoints for. + /// The ID of the fine-tuning job to get checkpoints for. /// Identifier for the last checkpoint ID from the previous pagination request. /// Number of checkpoints to retrieve. /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// is null. - /// is an empty string, and was expected to be non-empty. + /// is null. + /// is an empty string, and was expected to be non-empty. /// Service returned a non-success status code. /// The response returned from the service. - public virtual async Task GetJobCheckpointsAsync(string fineTuningJobId, string after, int? limit, RequestOptions options) + public virtual IAsyncEnumerable GetJobCheckpointsAsync(string jobId, string after, int? limit, RequestOptions options) { - Argument.AssertNotNullOrEmpty(fineTuningJobId, nameof(fineTuningJobId)); + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); - using PipelineMessage message = CreateGetFineTuningJobCheckpointsRequest(fineTuningJobId, after, limit, options); - return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + FineTuningJobCheckpointsPageEnumerator enumerator = new FineTuningJobCheckpointsPageEnumerator(_pipeline, _endpoint, jobId, after, limit, options); + return PageCollectionHelpers.CreateAsync(enumerator); } /// /// [Protocol Method] List the checkpoints for a fine-tuning job. /// - /// The ID of the fine-tuning job to get checkpoints for. + /// The ID of the fine-tuning job to get checkpoints for. /// Identifier for the last checkpoint ID from the previous pagination request. /// Number of checkpoints to retrieve. /// The request options, which can override default behaviors of the client pipeline on a per-call basis. - /// is null. - /// is an empty string, and was expected to be non-empty. + /// is null. + /// is an empty string, and was expected to be non-empty. /// Service returned a non-success status code. /// The response returned from the service. - public virtual ClientResult GetJobCheckpoints(string fineTuningJobId, string after, int? limit, RequestOptions options) + public virtual IEnumerable GetJobCheckpoints(string jobId, string after, int? limit, RequestOptions options) { - Argument.AssertNotNullOrEmpty(fineTuningJobId, nameof(fineTuningJobId)); + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); - using PipelineMessage message = CreateGetFineTuningJobCheckpointsRequest(fineTuningJobId, after, limit, options); - return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + FineTuningJobCheckpointsPageEnumerator enumerator = new FineTuningJobCheckpointsPageEnumerator(_pipeline, _endpoint, jobId, after, limit, options); + return PageCollectionHelpers.Create(enumerator); } } diff --git a/src/Custom/FineTuning/Internal/Pagination/FineTuningJobCheckpointsPageEnumerator.cs b/src/Custom/FineTuning/Internal/Pagination/FineTuningJobCheckpointsPageEnumerator.cs new file mode 100644 index 00000000..f404409d --- /dev/null +++ b/src/Custom/FineTuning/Internal/Pagination/FineTuningJobCheckpointsPageEnumerator.cs @@ -0,0 +1,116 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Text.Json; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.FineTuning; + +internal partial class FineTuningJobCheckpointsPageEnumerator : PageResultEnumerator +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + + private readonly string _jobId; + private readonly int? _limit; + private readonly RequestOptions _options; + + private string _after; + + public FineTuningJobCheckpointsPageEnumerator( + ClientPipeline pipeline, + Uri endpoint, + string jobId, string after, int? limit, + RequestOptions options) + { + _pipeline = pipeline; + _endpoint = endpoint; + + _jobId = jobId; + _after = after; + _limit = limit; + _options = options; + } + + public override async Task GetFirstAsync() + => await GetJobCheckpointsAsync(_jobId, _after, _limit, _options).ConfigureAwait(false); + + public override ClientResult GetFirst() + => GetJobCheckpoints(_jobId, _after, _limit, _options); + + public override async Task GetNextAsync(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + _after = doc.RootElement.GetProperty("last_id"u8).GetString()!; + + return await GetJobCheckpointsAsync(_jobId, _after, _limit, _options).ConfigureAwait(false); + } + + public override ClientResult GetNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + _after = doc.RootElement.GetProperty("last_id"u8).GetString()!; + + return GetJobCheckpoints(_jobId, _after, _limit, _options); + } + + public override bool HasNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + bool hasMore = doc.RootElement.GetProperty("has_more"u8).GetBoolean(); + + return hasMore; + } + + internal virtual async Task GetJobCheckpointsAsync(string jobId, string after, int? limit, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateGetFineTuningJobCheckpointsRequest(jobId, after, limit, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + internal virtual ClientResult GetJobCheckpoints(string jobId, string after, int? limit, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateGetFineTuningJobCheckpointsRequest(jobId, after, limit, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + internal PipelineMessage CreateGetFineTuningJobCheckpointsRequest(string fineTuningJobId, string after, int? limit, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/v1/fine_tuning/jobs/", false); + uri.AppendPath(fineTuningJobId, true); + uri.AppendPath("/checkpoints", false); + if (after != null) + { + uri.AppendQuery("after", after, true); + } + if (limit != null) + { + uri.AppendQuery("limit", limit.Value, true); + } + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + private static PipelineMessageClassifier? _pipelineMessageClassifier200; + private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); +} diff --git a/src/Custom/FineTuning/Internal/Pagination/FineTuningJobEventsPageEnumerator.cs b/src/Custom/FineTuning/Internal/Pagination/FineTuningJobEventsPageEnumerator.cs new file mode 100644 index 00000000..095bedfa --- /dev/null +++ b/src/Custom/FineTuning/Internal/Pagination/FineTuningJobEventsPageEnumerator.cs @@ -0,0 +1,116 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Text.Json; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.FineTuning; + +internal partial class FineTuningJobEventsPageEnumerator : PageResultEnumerator +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + + private readonly string _jobId; + private readonly int? _limit; + private readonly RequestOptions _options; + + private string _after; + + public FineTuningJobEventsPageEnumerator( + ClientPipeline pipeline, + Uri endpoint, + string jobId, string after, int? limit, + RequestOptions options) + { + _pipeline = pipeline; + _endpoint = endpoint; + + _jobId = jobId; + _after = after; + _limit = limit; + _options = options; + } + + public override async Task GetFirstAsync() + => await GetJobEventsAsync(_jobId, _after, _limit, _options).ConfigureAwait(false); + + public override ClientResult GetFirst() + => GetJobEvents(_jobId, _after, _limit, _options); + + public override async Task GetNextAsync(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + _after = doc.RootElement.GetProperty("last_id"u8).GetString()!; + + return await GetJobEventsAsync(_jobId, _after, _limit, _options).ConfigureAwait(false); + } + + public override ClientResult GetNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + _after = doc.RootElement.GetProperty("last_id"u8).GetString()!; + + return GetJobEvents(_jobId, _after, _limit, _options); + } + + public override bool HasNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + bool hasMore = doc.RootElement.GetProperty("has_more"u8).GetBoolean(); + + return hasMore; + } + + internal virtual async Task GetJobEventsAsync(string jobId, string after, int? limit, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateGetFineTuningEventsRequest(jobId, after, limit, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + internal virtual ClientResult GetJobEvents(string jobId, string after, int? limit, RequestOptions options) + { + Argument.AssertNotNullOrEmpty(jobId, nameof(jobId)); + + using PipelineMessage message = CreateGetFineTuningEventsRequest(jobId, after, limit, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + internal PipelineMessage CreateGetFineTuningEventsRequest(string fineTuningJobId, string after, int? limit, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/v1/fine_tuning/jobs/", false); + uri.AppendPath(fineTuningJobId, true); + uri.AppendPath("/events", false); + if (after != null) + { + uri.AppendQuery("after", after, true); + } + if (limit != null) + { + uri.AppendQuery("limit", limit.Value, true); + } + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + private static PipelineMessageClassifier? _pipelineMessageClassifier200; + private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); +} diff --git a/src/Custom/FineTuning/Internal/Pagination/FineTuningJobsPageEnumerator.cs b/src/Custom/FineTuning/Internal/Pagination/FineTuningJobsPageEnumerator.cs new file mode 100644 index 00000000..d9ad21b9 --- /dev/null +++ b/src/Custom/FineTuning/Internal/Pagination/FineTuningJobsPageEnumerator.cs @@ -0,0 +1,108 @@ +using System; +using System.ClientModel; +using System.ClientModel.Primitives; +using System.Text.Json; +using System.Threading.Tasks; + +#nullable enable + +namespace OpenAI.FineTuning; + +internal partial class FineTuningJobsPageEnumerator : PageResultEnumerator +{ + private readonly ClientPipeline _pipeline; + private readonly Uri _endpoint; + + private readonly int? _limit; + private readonly RequestOptions _options; + + private string _after; + + public FineTuningJobsPageEnumerator( + ClientPipeline pipeline, + Uri endpoint, + string after, int? limit, + RequestOptions options) + { + _pipeline = pipeline; + _endpoint = endpoint; + + _after = after; + _limit = limit; + _options = options; + } + + public override async Task GetFirstAsync() + => await GetJobsAsync(_after, _limit, _options).ConfigureAwait(false); + + public override ClientResult GetFirst() + => GetJobs(_after, _limit, _options); + + public override async Task GetNextAsync(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + _after = doc.RootElement.GetProperty("last_id"u8).GetString()!; + + return await GetJobsAsync(_after, _limit, _options).ConfigureAwait(false); + } + + public override ClientResult GetNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + _after = doc.RootElement.GetProperty("last_id"u8).GetString()!; + + return GetJobs(_after, _limit, _options); + } + + public override bool HasNext(ClientResult result) + { + PipelineResponse response = result.GetRawResponse(); + + using JsonDocument doc = JsonDocument.Parse(response.Content); + bool hasMore = doc.RootElement.GetProperty("has_more"u8).GetBoolean(); + + return hasMore; + } + + internal virtual async Task GetJobsAsync(string after, int? limit, RequestOptions options) + { + using PipelineMessage message = CreateGetFineTuningJobsRequest(after, limit, options); + return ClientResult.FromResponse(await _pipeline.ProcessMessageAsync(message, options).ConfigureAwait(false)); + } + + internal virtual ClientResult GetJobs(string after, int? limit, RequestOptions options) + { + using PipelineMessage message = CreateGetFineTuningJobsRequest(after, limit, options); + return ClientResult.FromResponse(_pipeline.ProcessMessage(message, options)); + } + + internal PipelineMessage CreateGetFineTuningJobsRequest(string after, int? limit, RequestOptions options) + { + var message = _pipeline.CreateMessage(); + message.ResponseClassifier = PipelineMessageClassifier200; + var request = message.Request; + request.Method = "GET"; + var uri = new ClientUriBuilder(); + uri.Reset(_endpoint); + uri.AppendPath("/v1/fine_tuning/jobs", false); + if (after != null) + { + uri.AppendQuery("after", after, true); + } + if (limit != null) + { + uri.AppendQuery("limit", limit.Value, true); + } + request.Uri = uri.ToUri(); + request.Headers.Set("Accept", "application/json"); + message.Apply(options); + return message; + } + + private static PipelineMessageClassifier? _pipelineMessageClassifier200; + private static PipelineMessageClassifier PipelineMessageClassifier200 => _pipelineMessageClassifier200 ??= PipelineMessageClassifier.Create(stackalloc ushort[] { 200 }); +} diff --git a/src/Utility/PageCollectionHelpers.cs b/src/Utility/PageCollectionHelpers.cs index f7863ce2..b5be39a6 100644 --- a/src/Utility/PageCollectionHelpers.cs +++ b/src/Utility/PageCollectionHelpers.cs @@ -15,6 +15,22 @@ public static PageCollection Create(PageEnumerator enumerator) public static AsyncPageCollection CreateAsync(PageEnumerator enumerator) => new AsyncEnumeratorPageCollection(enumerator); + public static IEnumerable Create(PageResultEnumerator enumerator) + { + while (enumerator.MoveNext()) + { + yield return enumerator.Current; + } + } + + public static async IAsyncEnumerable CreateAsync(PageResultEnumerator enumerator) + { + while (await enumerator.MoveNextAsync().ConfigureAwait(false)) + { + yield return enumerator.Current; + } + } + private class EnumeratorPageCollection : PageCollection { private readonly PageEnumerator _enumerator; diff --git a/tests/Batch/BatchTests.cs b/tests/Batch/BatchTests.cs index 65be9eaa..8e51201a 100644 --- a/tests/Batch/BatchTests.cs +++ b/tests/Batch/BatchTests.cs @@ -4,6 +4,7 @@ using OpenAI.Tests.Utility; using System; using System.ClientModel; +using System.Collections.Generic; using System.IO; using System.Text.Json; using System.Threading.Tasks; @@ -22,32 +23,71 @@ public BatchTests(bool isAsync) : base(isAsync) } [Test] - public async Task ListBatchesProtocol() + public void ListBatchesProtocol() { BatchClient client = GetTestClient(); - ClientResult result = IsAsync - ? await client.GetBatchesAsync(after: null, limit: null, options: null) - : client.GetBatches(after: null, limit: null, options: null); + IEnumerable pageResults = client.GetBatches(after: null, limit: null, options: null); - BinaryData response = result.GetRawResponse().Content; - JsonDocument jsonDocument = JsonDocument.Parse(response); - JsonElement dataElement = jsonDocument.RootElement.GetProperty("data"); + int pageCount = 0; + foreach (ClientResult pageResult in pageResults) + { + BinaryData response = pageResult.GetRawResponse().Content; + using JsonDocument jsonDocument = JsonDocument.Parse(response); + JsonElement dataElement = jsonDocument.RootElement.GetProperty("data"); - Assert.That(dataElement.GetArrayLength(), Is.GreaterThan(0)); + Assert.That(dataElement.GetArrayLength(), Is.GreaterThan(0)); - long unixTime2024 = (new DateTimeOffset(2024, 01, 01, 0, 0, 0, TimeSpan.Zero)).ToUnixTimeSeconds(); + long unixTime2024 = (new DateTimeOffset(2024, 01, 01, 0, 0, 0, TimeSpan.Zero)).ToUnixTimeSeconds(); + + foreach (JsonElement batchElement in dataElement.EnumerateArray()) + { + JsonElement createdAtElement = batchElement.GetProperty("created_at"); + long createdAt = createdAtElement.GetInt64(); + + Assert.That(createdAt, Is.GreaterThan(unixTime2024)); + } + pageCount++; + + //var dynamicResult = result.GetRawResponse().Content.ToDynamicFromJson(); + //Assert.That(dynamicResult.data.Count, Is.GreaterThan(0)); + //Assert.That(dynamicResult.data[0].createdAt, Is.GreaterThan(new DateTimeOffset(2024, 01, 01, 0, 0, 0, TimeSpan.Zero))); + } + + Assert.GreaterOrEqual(pageCount, 1); + } - foreach (JsonElement batchElement in dataElement.EnumerateArray()) + [Test] + public async Task ListBatchesProtocolAsync() + { + BatchClient client = GetTestClient(); + IAsyncEnumerable pageResults = client.GetBatchesAsync(after: null, limit: null, options: null); + + int pageCount = 0; + await foreach (ClientResult pageResult in pageResults) { - JsonElement createdAtElement = batchElement.GetProperty("created_at"); - long createdAt = createdAtElement.GetInt64(); + BinaryData response = pageResult.GetRawResponse().Content; + using JsonDocument jsonDocument = JsonDocument.Parse(response); + JsonElement dataElement = jsonDocument.RootElement.GetProperty("data"); + + Assert.That(dataElement.GetArrayLength(), Is.GreaterThan(0)); + + long unixTime2024 = (new DateTimeOffset(2024, 01, 01, 0, 0, 0, TimeSpan.Zero)).ToUnixTimeSeconds(); + + foreach (JsonElement batchElement in dataElement.EnumerateArray()) + { + JsonElement createdAtElement = batchElement.GetProperty("created_at"); + long createdAt = createdAtElement.GetInt64(); + + Assert.That(createdAt, Is.GreaterThan(unixTime2024)); + } + pageCount++; - Assert.That(createdAt, Is.GreaterThan(unixTime2024)); + //var dynamicResult = result.GetRawResponse().Content.ToDynamicFromJson(); + //Assert.That(dynamicResult.data.Count, Is.GreaterThan(0)); + //Assert.That(dynamicResult.data[0].createdAt, Is.GreaterThan(new DateTimeOffset(2024, 01, 01, 0, 0, 0, TimeSpan.Zero))); } - //var dynamicResult = result.GetRawResponse().Content.ToDynamicFromJson(); - //Assert.That(dynamicResult.data.Count, Is.GreaterThan(0)); - //Assert.That(dynamicResult.data[0].createdAt, Is.GreaterThan(new DateTimeOffset(2024, 01, 01, 0, 0, 0, TimeSpan.Zero))); + Assert.GreaterOrEqual(pageCount, 1); } [Test]