From 958ceae1f974d18125497f98499c1baeaf63919a Mon Sep 17 00:00:00 2001
From: token <239573049@qq.com>
Date: Thu, 10 Oct 2024 02:20:46 +0800
Subject: [PATCH] =?UTF-8?q?-=20=E6=94=AF=E6=8C=81=E6=A0=B9=E6=8D=AE?=
 =?UTF-8?q?=E8=AF=B7=E6=B1=82=E7=9A=84=E5=90=91=E9=87=8F=E7=B1=BB=E5=9E=8B?=
 =?UTF-8?q?=E8=87=AA=E5=8A=A8=E5=B8=AE=E5=BF=99=E8=BD=AC=E6=8D=A2=E8=BF=94?=
 =?UTF-8?q?=E5=9B=9E=E7=9A=84=E5=B5=8C=E5=85=A5=E6=A8=A1=E5=9E=8B=E7=9A=84?=
 =?UTF-8?q?=E7=B1=BB=E5=9E=8B?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../ResponseModels/EmbeddingCreateResponse.cs | 117 +++++++++++++++---
 src/Thor.Service/Program.cs                   |  13 ++-
 src/Thor.Service/Service/ChatService.cs       |   9 +-
 3 files changed, 114 insertions(+), 25 deletions(-)

diff --git a/src/Thor.Abstractions/ObjectModels/ObjectModels/ResponseModels/EmbeddingCreateResponse.cs b/src/Thor.Abstractions/ObjectModels/ObjectModels/ResponseModels/EmbeddingCreateResponse.cs
index 779aac2..0400328 100644
--- a/src/Thor.Abstractions/ObjectModels/ObjectModels/ResponseModels/EmbeddingCreateResponse.cs
+++ b/src/Thor.Abstractions/ObjectModels/ObjectModels/ResponseModels/EmbeddingCreateResponse.cs
@@ -1,25 +1,120 @@
-using System.Text.Json.Serialization;
+using System.Buffers;
+using System.Text.Json;
+using System.Text.Json.Serialization;
 using Thor.Abstractions.Dtos;
 
 namespace Thor.Abstractions.ObjectModels.ObjectModels.ResponseModels;
 
 public record EmbeddingCreateResponse : ThorBaseResponse
 {
-    [JsonPropertyName("model")]
-    public string Model { get; set; }
+    [JsonPropertyName("model")] public string Model { get; set; }
 
-    [JsonPropertyName("data")]
-    public List<EmbeddingResponse> Data { get; set; }
+    [JsonPropertyName("data")] public List<EmbeddingResponse> Data { get; set; } = [];
 
-    [JsonPropertyName("usage")]
-    public ThorUsageResponse Usage { get; set; }
+    /// <summary>
+    /// Converts every embedding in <see cref="Data"/> to the representation requested by the client.
+    /// </summary>
+    /// <param name="encodingFormat">
+    /// <c>"base64"</c> produces base64 strings; <c>null</c> or <c>"float"</c> produces float arrays.
+    /// Any other value leaves the data untouched.
+    /// </param>
+    public void ConvertEmbeddingData(string? encodingFormat)
+    {
+        // An explicit "data": null in the upstream payload overwrites the [] initializer.
+        if (Data is not { Count: > 0 })
+        {
+            return;
+        }
+
+        switch (encodingFormat)
+        {
+            case null or "float":
+            {
+                foreach (var item in Data)
+                {
+                    // System.Text.Json materialises object-typed properties as JsonElement,
+                    // so a base64 payload may arrive either as string or as a JSON string element.
+                    var base64 = item.Embedding switch
+                    {
+                        string text => text,
+                        JsonElement { ValueKind: JsonValueKind.String } element => element.GetString(),
+                        _ => null
+                    };
+
+                    if (base64 is not null)
+                    {
+                        item.Embedding = ConvertBase64ToFloatArray(base64);
+                    }
+                }
+
+                break;
+            }
+            case "base64":
+            {
+                foreach (var item in Data)
+                {
+                    var floats = item.Embedding switch
+                    {
+                        float[] values => values,
+                        JsonElement { ValueKind: JsonValueKind.Array } element =>
+                            element.EnumerateArray().Select(value => value.GetSingle()).ToArray(),
+                        _ => null
+                    };
+
+                    if (floats is not null)
+                    {
+                        item.Embedding = ConvertFloatArrayToBase64(floats);
+                    }
+                }
+
+                break;
+            }
+        }
+    }
+
+    /// <summary>
+    /// Encodes the raw bytes of <paramref name="floatArray"/> (4 bytes per value, host byte order) as base64.
+    /// </summary>
+    public static string ConvertFloatArrayToBase64(float[] floatArray)
+    {
+        ArgumentNullException.ThrowIfNull(floatArray);
+
+        // Rent may return a buffer larger than requested, so only the first byteCount bytes are payload.
+        var byteCount = floatArray.Length * sizeof(float);
+        var buffer = ArrayPool<byte>.Shared.Rent(byteCount);
+        try
+        {
+            Buffer.BlockCopy(floatArray, 0, buffer, 0, byteCount);
+
+            return Convert.ToBase64String(buffer, 0, byteCount);
+        }
+        finally
+        {
+            ArrayPool<byte>.Shared.Return(buffer);
+        }
+    }
+
+    /// <summary>
+    /// Decodes base64 text holding packed 32-bit floats (host byte order) into a float array.
+    /// </summary>
+    public static float[] ConvertBase64ToFloatArray(string base64)
+    {
+        ArgumentNullException.ThrowIfNull(base64);
+
+        var bytes = Convert.FromBase64String(base64);
+        // Trailing bytes that do not make up a whole float are ignored.
+        var floats = new float[bytes.Length / sizeof(float)];
+        Buffer.BlockCopy(bytes, 0, floats, 0, floats.Length * sizeof(float));
+
+        return floats;
+    }
+
+    [JsonPropertyName("usage")] public ThorUsageResponse Usage { get; set; }
 }
 
 public record EmbeddingResponse
 {
-    [JsonPropertyName("index")]
-    public int? Index { get; set; }
+    [JsonPropertyName("index")] public int? Index { get; set; }
 
-    [JsonPropertyName("embedding")]
-    public object Embedding { get; set; }
+    [JsonPropertyName("embedding")] public object Embedding { get; set; }
 }
\ No newline at end of file
diff --git a/src/Thor.Service/Program.cs b/src/Thor.Service/Program.cs
index 770d64f..83ceda6 100644
--- a/src/Thor.Service/Program.cs
+++ b/src/Thor.Service/Program.cs
@@ -4,6 +4,7 @@
 using Microsoft.AspNetCore.Mvc;
 using Serilog;
 using Thor.Abstractions.Chats.Dtos;
+using Thor.Abstractions.Embeddings.Dtos;
 using Thor.Abstractions.ObjectModels.ObjectModels.RequestModels;
 using Thor.AzureOpenAI.Extensions;
 using Thor.BuildingBlocks.Data;
@@ -227,7 +228,6 @@
 
         var loggerDbContext = scope.ServiceProvider.GetRequiredService<LoggerDbContext>();
         await loggerDbContext.Database.EnsureCreatedAsync();
-
     }
     // 由于没有生成迁移记录,所以使用EnsureCreated
     else if (string.Equals(dbType, "postgresql") || string.Equals(dbType, "pgsql") ||
@@ -241,7 +241,6 @@
 
         var loggerDbContext = scope.ServiceProvider.GetRequiredService<LoggerDbContext>();
         await loggerDbContext.Database.EnsureCreatedAsync();
-
     }
 
 
@@ -707,8 +706,9 @@ await service.CompletionsAsync(context))
         .WithDescription("Get completions from OpenAI")
         .WithOpenApi();
 
-    app.MapPost("/v1/embeddings", async (ChatService embeddingService, HttpContext context) =>
-            await embeddingService.EmbeddingAsync(context))
+    app.MapPost("/v1/embeddings",
+            async (ChatService embeddingService, HttpContext context, ThorEmbeddingInput module) =>
+                await embeddingService.EmbeddingAsync(context, module))
         .WithDescription("OpenAI")
         .WithDescription("Embedding")
         .WithOpenApi();
@@ -721,10 +721,7 @@ await imageService.CreateImageAsync(context, request))
         .WithOpenApi();
 
 
-    app.MapGet("/v1/models", async (HttpContext context) =>
-        {
-            return await ModelService.GetAsync(context);
-        })
+    app.MapGet("/v1/models", async (HttpContext context) => { return await ModelService.GetAsync(context); })
         .WithDescription("获取模型列表")
         .RequireAuthorization()
         .WithOpenApi();
diff --git a/src/Thor.Service/Service/ChatService.cs b/src/Thor.Service/Service/ChatService.cs
index 01cdefe..d574b18 100644
--- a/src/Thor.Service/Service/ChatService.cs
+++ b/src/Thor.Service/Service/ChatService.cs
@@ -162,15 +162,10 @@ await loggerService.CreateConsumeAsync(string.Format(ConsumerTemplate, rate, 0),
         }
     }
 
-    public async ValueTask EmbeddingAsync(HttpContext context)
+    public async ValueTask EmbeddingAsync(HttpContext context, ThorEmbeddingInput module)
     {
         try
         {
-            using var body = new MemoryStream();
-            await context.Request.Body.CopyToAsync(body);
-
-            var module = JsonSerializer.Deserialize<ThorEmbeddingInput>(body.ToArray());
-
             if (module == null) throw new Exception("模型校验异常");
 
             await rateLimitModelService.CheckAsync(module!.Model, context);
@@ -246,6 +241,8 @@ await loggerService.CreateConsumeAsync(string.Format(ConsumerTemplate, rate, 0),
                     module.Model);
             }
 
+            stream.ConvertEmbeddingData(module.EncodingFormat);
+
             await context.Response.WriteAsJsonAsync(stream);
         }
         catch (RateLimitException)