diff --git a/src/Thor.Abstractions/ObjectModels/ObjectModels/ResponseModels/EmbeddingCreateResponse.cs b/src/Thor.Abstractions/ObjectModels/ObjectModels/ResponseModels/EmbeddingCreateResponse.cs index 779aac2..0400328 100644 --- a/src/Thor.Abstractions/ObjectModels/ObjectModels/ResponseModels/EmbeddingCreateResponse.cs +++ b/src/Thor.Abstractions/ObjectModels/ObjectModels/ResponseModels/EmbeddingCreateResponse.cs @@ -1,25 +1,91 @@ -using System.Text.Json.Serialization; +using System.Buffers; +using System.Text.Json; +using System.Text.Json.Serialization; using Thor.Abstractions.Dtos; namespace Thor.Abstractions.ObjectModels.ObjectModels.ResponseModels; public record EmbeddingCreateResponse : ThorBaseResponse { - [JsonPropertyName("model")] - public string Model { get; set; } + [JsonPropertyName("model")] public string Model { get; set; } - [JsonPropertyName("data")] - public List Data { get; set; } + [JsonPropertyName("data")] public List Data { get; set; } = []; - [JsonPropertyName("usage")] - public ThorUsageResponse Usage { get; set; } + /// + /// 类型转换,如果类型是base64,则将float[]转换为base64,如果是空或是float和原始类型一样,则不转换 + /// + public void ConvertEmbeddingData(string? encodingFormat) + { + if (Data.Count == 0) + { + return; + } + + switch (encodingFormat) + { + // 判断第一个是否是float[],如果是则不转换 + case null or "float" when Data[0].Embedding is float[]: + return; + // 否则转换成float[] + case null or "float": + { + foreach (var embeddingResponse in Data) + { + if (embeddingResponse.Embedding is string base64) + { + embeddingResponse.Embedding = Convert.FromBase64String(base64); + } + } + + return; + } + // 判断第一个是否是string,如果是则不转换 + case "base64" when Data[0].Embedding is string: + return; + // 否则转换成base64 + case "base64": + { + foreach (var embeddingResponse in Data) + { + if (embeddingResponse.Embedding is JsonElement str) + { + if (str.ValueKind == JsonValueKind.Array) + { + var floats = str.EnumerateArray().Select(element => element.GetSingle()).ToArray(); + + embeddingResponse.Embedding = ConvertFloatArrayToBase64(floats); + } + } + } + + break; + } + } + } + + public static string ConvertFloatArrayToBase64(float[] floatArray) + { + // 将 float[] 转换成 byte[] + byte[] byteArray = ArrayPool.Shared.Rent(floatArray.Length * sizeof(float)); + try + { + Buffer.BlockCopy(floatArray, 0, byteArray, 0, byteArray.Length); + + // 将 byte[] 转换成 base64 字符串 + return Convert.ToBase64String(byteArray); + } + finally + { + ArrayPool.Shared.Return(byteArray); + } + } + + [JsonPropertyName("usage")] public ThorUsageResponse Usage { get; set; } } public record EmbeddingResponse { - [JsonPropertyName("index")] - public int? Index { get; set; } + [JsonPropertyName("index")] public int? Index { get; set; } - [JsonPropertyName("embedding")] - public object Embedding { get; set; } + [JsonPropertyName("embedding")] public object Embedding { get; set; } } \ No newline at end of file diff --git a/src/Thor.Service/Program.cs b/src/Thor.Service/Program.cs index 770d64f..83ceda6 100644 --- a/src/Thor.Service/Program.cs +++ b/src/Thor.Service/Program.cs @@ -4,6 +4,7 @@ using Microsoft.AspNetCore.Mvc; using Serilog; using Thor.Abstractions.Chats.Dtos; +using Thor.Abstractions.Embeddings.Dtos; using Thor.Abstractions.ObjectModels.ObjectModels.RequestModels; using Thor.AzureOpenAI.Extensions; using Thor.BuildingBlocks.Data; @@ -227,7 +228,6 @@ var loggerDbContext = scope.ServiceProvider.GetRequiredService(); await loggerDbContext.Database.EnsureCreatedAsync(); - } // 由于没有生成迁移记录,所以使用EnsureCreated else if (string.Equals(dbType, "postgresql") || string.Equals(dbType, "pgsql") || @@ -241,7 +241,6 @@ var loggerDbContext = scope.ServiceProvider.GetRequiredService(); await loggerDbContext.Database.EnsureCreatedAsync(); - } @@ -707,8 +706,9 @@ await service.CompletionsAsync(context)) .WithDescription("Get completions from OpenAI") .WithOpenApi(); - app.MapPost("/v1/embeddings", async (ChatService embeddingService, HttpContext context) => - await embeddingService.EmbeddingAsync(context)) + app.MapPost("/v1/embeddings", + async (ChatService embeddingService, HttpContext context, ThorEmbeddingInput module) => + await embeddingService.EmbeddingAsync(context, module)) .WithDescription("OpenAI") .WithDescription("Embedding") .WithOpenApi(); @@ -721,10 +721,7 @@ await imageService.CreateImageAsync(context, request)) .WithOpenApi(); - app.MapGet("/v1/models", async (HttpContext context) => - { - return await ModelService.GetAsync(context); - }) + app.MapGet("/v1/models", async (HttpContext context) => { return await ModelService.GetAsync(context); }) .WithDescription("获取模型列表") .RequireAuthorization() .WithOpenApi(); diff --git a/src/Thor.Service/Service/ChatService.cs b/src/Thor.Service/Service/ChatService.cs index 01cdefe..d574b18 100644 --- a/src/Thor.Service/Service/ChatService.cs +++ b/src/Thor.Service/Service/ChatService.cs @@ -162,15 +162,10 @@ await loggerService.CreateConsumeAsync(string.Format(ConsumerTemplate, rate, 0), } } - public async ValueTask EmbeddingAsync(HttpContext context) + public async ValueTask EmbeddingAsync(HttpContext context,ThorEmbeddingInput module) { try { - using var body = new MemoryStream(); - await context.Request.Body.CopyToAsync(body); - - var module = JsonSerializer.Deserialize(body.ToArray()); - if (module == null) throw new Exception("模型校验异常"); await rateLimitModelService.CheckAsync(module!.Model, context); @@ -246,6 +241,8 @@ await userService.ConsumeAsync(user!.Id, (long)quota, requestToken, token?.Key, module.Model); } + stream.ConvertEmbeddingData(module.EncodingFormat); + await context.Response.WriteAsJsonAsync(stream); } catch (RateLimitException)