Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

- 支持根据请求的向量类型自动帮忙转换返回的嵌入模型的类型 #25

Merged
merged 1 commit into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,25 +1,91 @@
using System.Text.Json.Serialization;
using System.Buffers;
using System.Text.Json;
using System.Text.Json.Serialization;
using Thor.Abstractions.Dtos;

namespace Thor.Abstractions.ObjectModels.ObjectModels.ResponseModels;

public record EmbeddingCreateResponse : ThorBaseResponse
{
[JsonPropertyName("model")]
public string Model { get; set; }
[JsonPropertyName("model")] public string Model { get; set; }

[JsonPropertyName("data")]
public List<EmbeddingResponse> Data { get; set; }
[JsonPropertyName("data")] public List<EmbeddingResponse> Data { get; set; } = [];

[JsonPropertyName("usage")]
public ThorUsageResponse Usage { get; set; }
/// <summary>
/// 类型转换,如果类型是base64,则将float[]转换为base64,如果是空或是float和原始类型一样,则不转换
/// </summary>
public void ConvertEmbeddingData(string? encodingFormat)
{
if (Data.Count == 0)
{
return;
}

switch (encodingFormat)
{
// 判断第一个是否是float[],如果是则不转换
case null or "float" when Data[0].Embedding is float[]:
return;
// 否则转换成float[]
case null or "float":
{
foreach (var embeddingResponse in Data)
{
if (embeddingResponse.Embedding is string base64)
{
embeddingResponse.Embedding = Convert.FromBase64String(base64);
}
}

return;
}
// 判断第一个是否是string,如果是则不转换
case "base64" when Data[0].Embedding is string:
return;
// 否则转换成base64
case "base64":
{
foreach (var embeddingResponse in Data)
{
if (embeddingResponse.Embedding is JsonElement str)
{
if (str.ValueKind == JsonValueKind.Array)
{
var floats = str.EnumerateArray().Select(element => element.GetSingle()).ToArray();

embeddingResponse.Embedding = ConvertFloatArrayToBase64(floats);
}
}
}

break;
}
}
}

public static string ConvertFloatArrayToBase64(float[] floatArray)
{
// 将 float[] 转换成 byte[]
byte[] byteArray = ArrayPool<byte>.Shared.Rent(floatArray.Length * sizeof(float));
try
{
Buffer.BlockCopy(floatArray, 0, byteArray, 0, byteArray.Length);

// 将 byte[] 转换成 base64 字符串
return Convert.ToBase64String(byteArray);
}
finally
{
ArrayPool<byte>.Shared.Return(byteArray);
}
}

[JsonPropertyName("usage")] public ThorUsageResponse Usage { get; set; }
}

public record EmbeddingResponse
{
[JsonPropertyName("index")]
public int? Index { get; set; }
[JsonPropertyName("index")] public int? Index { get; set; }

[JsonPropertyName("embedding")]
public object Embedding { get; set; }
[JsonPropertyName("embedding")] public object Embedding { get; set; }
}
13 changes: 5 additions & 8 deletions src/Thor.Service/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Microsoft.AspNetCore.Mvc;
using Serilog;
using Thor.Abstractions.Chats.Dtos;
using Thor.Abstractions.Embeddings.Dtos;
using Thor.Abstractions.ObjectModels.ObjectModels.RequestModels;
using Thor.AzureOpenAI.Extensions;
using Thor.BuildingBlocks.Data;
Expand Down Expand Up @@ -227,7 +228,6 @@

var loggerDbContext = scope.ServiceProvider.GetRequiredService<LoggerDbContext>();
await loggerDbContext.Database.EnsureCreatedAsync();

}
// 由于没有生成迁移记录,所以使用EnsureCreated
else if (string.Equals(dbType, "postgresql") || string.Equals(dbType, "pgsql") ||
Expand All @@ -241,7 +241,6 @@

var loggerDbContext = scope.ServiceProvider.GetRequiredService<LoggerDbContext>();
await loggerDbContext.Database.EnsureCreatedAsync();

}


Expand Down Expand Up @@ -707,8 +706,9 @@ await service.CompletionsAsync(context))
.WithDescription("Get completions from OpenAI")
.WithOpenApi();

app.MapPost("/v1/embeddings", async (ChatService embeddingService, HttpContext context) =>
await embeddingService.EmbeddingAsync(context))
app.MapPost("/v1/embeddings",
async (ChatService embeddingService, HttpContext context, ThorEmbeddingInput module) =>
await embeddingService.EmbeddingAsync(context, module))
.WithDescription("OpenAI")
.WithDescription("Embedding")
.WithOpenApi();
Expand All @@ -721,10 +721,7 @@ await imageService.CreateImageAsync(context, request))
.WithOpenApi();


app.MapGet("/v1/models", async (HttpContext context) =>
{
return await ModelService.GetAsync(context);
})
app.MapGet("/v1/models", async (HttpContext context) => { return await ModelService.GetAsync(context); })
.WithDescription("获取模型列表")
.RequireAuthorization()
.WithOpenApi();
Expand Down
9 changes: 3 additions & 6 deletions src/Thor.Service/Service/ChatService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,10 @@ await loggerService.CreateConsumeAsync(string.Format(ConsumerTemplate, rate, 0),
}
}

public async ValueTask EmbeddingAsync(HttpContext context)
public async ValueTask EmbeddingAsync(HttpContext context,ThorEmbeddingInput module)
{
try
{
using var body = new MemoryStream();
await context.Request.Body.CopyToAsync(body);

var module = JsonSerializer.Deserialize<ThorEmbeddingInput>(body.ToArray());

if (module == null) throw new Exception("模型校验异常");

await rateLimitModelService.CheckAsync(module!.Model, context);
Expand Down Expand Up @@ -246,6 +241,8 @@ await userService.ConsumeAsync(user!.Id, (long)quota, requestToken, token?.Key,
module.Model);
}

stream.ConvertEmbeddingData(module.EncodingFormat);

await context.Response.WriteAsJsonAsync(stream);
}
catch (RateLimitException)
Expand Down
Loading