Skip to content

Commit

Permalink
Merge pull request #25 from AIDotNet/feature/embedding_format
Browse files Browse the repository at this point in the history
- 支持根据请求的向量类型自动帮忙转换返回的嵌入模型的类型
  • Loading branch information
239573049 authored Oct 9, 2024
2 parents caca74f + 958ceae commit 9837d73
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -1,25 +1,91 @@
using System.Text.Json.Serialization;
using System.Buffers;
using System.Text.Json;
using System.Text.Json.Serialization;
using Thor.Abstractions.Dtos;

namespace Thor.Abstractions.ObjectModels.ObjectModels.ResponseModels;

public record EmbeddingCreateResponse : ThorBaseResponse
{
[JsonPropertyName("model")]
public string Model { get; set; }
[JsonPropertyName("model")] public string Model { get; set; }

[JsonPropertyName("data")]
public List<EmbeddingResponse> Data { get; set; }
[JsonPropertyName("data")] public List<EmbeddingResponse> Data { get; set; } = [];

[JsonPropertyName("usage")]
public ThorUsageResponse Usage { get; set; }
/// <summary>
/// 类型转换,如果类型是base64,则将float[]转换为base64,如果是空或是float和原始类型一样,则不转换
/// </summary>
public void ConvertEmbeddingData(string? encodingFormat)
{
if (Data.Count == 0)
{
return;
}

switch (encodingFormat)
{
// 判断第一个是否是float[],如果是则不转换
case null or "float" when Data[0].Embedding is float[]:
return;
// 否则转换成float[]
case null or "float":
{
foreach (var embeddingResponse in Data)
{
if (embeddingResponse.Embedding is string base64)
{
embeddingResponse.Embedding = Convert.FromBase64String(base64);
}
}

return;
}
// 判断第一个是否是string,如果是则不转换
case "base64" when Data[0].Embedding is string:
return;
// 否则转换成base64
case "base64":
{
foreach (var embeddingResponse in Data)
{
if (embeddingResponse.Embedding is JsonElement str)
{
if (str.ValueKind == JsonValueKind.Array)
{
var floats = str.EnumerateArray().Select(element => element.GetSingle()).ToArray();

embeddingResponse.Embedding = ConvertFloatArrayToBase64(floats);
}
}
}

break;
}
}
}

public static string ConvertFloatArrayToBase64(float[] floatArray)
{
// 将 float[] 转换成 byte[]
byte[] byteArray = ArrayPool<byte>.Shared.Rent(floatArray.Length * sizeof(float));
try
{
Buffer.BlockCopy(floatArray, 0, byteArray, 0, byteArray.Length);

// 将 byte[] 转换成 base64 字符串
return Convert.ToBase64String(byteArray);
}
finally
{
ArrayPool<byte>.Shared.Return(byteArray);
}
}

[JsonPropertyName("usage")] public ThorUsageResponse Usage { get; set; }
}

public record EmbeddingResponse
{
[JsonPropertyName("index")]
public int? Index { get; set; }
[JsonPropertyName("index")] public int? Index { get; set; }

[JsonPropertyName("embedding")]
public object Embedding { get; set; }
[JsonPropertyName("embedding")] public object Embedding { get; set; }
}
13 changes: 5 additions & 8 deletions src/Thor.Service/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Microsoft.AspNetCore.Mvc;
using Serilog;
using Thor.Abstractions.Chats.Dtos;
using Thor.Abstractions.Embeddings.Dtos;
using Thor.Abstractions.ObjectModels.ObjectModels.RequestModels;
using Thor.AzureOpenAI.Extensions;
using Thor.BuildingBlocks.Data;
Expand Down Expand Up @@ -227,7 +228,6 @@

var loggerDbContext = scope.ServiceProvider.GetRequiredService<LoggerDbContext>();
await loggerDbContext.Database.EnsureCreatedAsync();

}
// 由于没有生成迁移记录,所以使用EnsureCreated
else if (string.Equals(dbType, "postgresql") || string.Equals(dbType, "pgsql") ||
Expand All @@ -241,7 +241,6 @@

var loggerDbContext = scope.ServiceProvider.GetRequiredService<LoggerDbContext>();
await loggerDbContext.Database.EnsureCreatedAsync();

}


Expand Down Expand Up @@ -707,8 +706,9 @@ await service.CompletionsAsync(context))
.WithDescription("Get completions from OpenAI")
.WithOpenApi();

app.MapPost("/v1/embeddings", async (ChatService embeddingService, HttpContext context) =>
await embeddingService.EmbeddingAsync(context))
app.MapPost("/v1/embeddings",
async (ChatService embeddingService, HttpContext context, ThorEmbeddingInput module) =>
await embeddingService.EmbeddingAsync(context, module))
.WithDescription("OpenAI")
.WithDescription("Embedding")
.WithOpenApi();
Expand All @@ -721,10 +721,7 @@ await imageService.CreateImageAsync(context, request))
.WithOpenApi();


app.MapGet("/v1/models", async (HttpContext context) =>
{
return await ModelService.GetAsync(context);
})
app.MapGet("/v1/models", async (HttpContext context) => { return await ModelService.GetAsync(context); })
.WithDescription("获取模型列表")
.RequireAuthorization()
.WithOpenApi();
Expand Down
9 changes: 3 additions & 6 deletions src/Thor.Service/Service/ChatService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,10 @@ await loggerService.CreateConsumeAsync(string.Format(ConsumerTemplate, rate, 0),
}
}

public async ValueTask EmbeddingAsync(HttpContext context)
public async ValueTask EmbeddingAsync(HttpContext context,ThorEmbeddingInput module)
{
try
{
using var body = new MemoryStream();
await context.Request.Body.CopyToAsync(body);

var module = JsonSerializer.Deserialize<ThorEmbeddingInput>(body.ToArray());

if (module == null) throw new Exception("模型校验异常");

await rateLimitModelService.CheckAsync(module!.Model, context);
Expand Down Expand Up @@ -246,6 +241,8 @@ await userService.ConsumeAsync(user!.Id, (long)quota, requestToken, token?.Key,
module.Model);
}

stream.ConvertEmbeddingData(module.EncodingFormat);

await context.Response.WriteAsJsonAsync(stream);
}
catch (RateLimitException)
Expand Down

0 comments on commit 9837d73

Please sign in to comment.