Skip to content

Commit

Permalink
增加按次付费校验,如果按次付费则需要在请求前校验
Browse files Browse the repository at this point in the history
  • Loading branch information
239573049 committed Dec 7, 2024
1 parent 2cd4071 commit 2b2026c
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 72 deletions.
131 changes: 69 additions & 62 deletions src/Thor.Service/Service/ChatService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,23 +121,25 @@ public async Task CreateImageAsync(HttpContext context, ImageCreateRequest reque
organizationId = organizationIdHeader.ToString();
}

var (token, user) = await tokenService.CheckTokenAsync(context);

request.Model = TokenService.ModelMap(request.Model);

TokenService.CheckModel(request.Model, token, context);

if (string.IsNullOrEmpty(request?.Model)) request.Model = "dall-e-2";

await rateLimitModelService.CheckAsync(request.Model, context);

var imageCostRatio = GetImageCostRatio(request);

var rate = ModelManagerService.PromptRate[request.Model].PromptRate;
var rate = ModelManagerService.PromptRate[request.Model];

var (token, user) = await tokenService.CheckTokenAsync(context, rate);

TokenService.CheckModel(request.Model, token, context);


request.N ??= 1;

var quota = (int)(rate * imageCostRatio * 1000) * request.N;
var quota = (int)(rate.PromptRate * imageCostRatio * 1000) * request.N;

if (request == null) throw new Exception("模型校验异常");

Expand Down Expand Up @@ -213,61 +215,63 @@ public async ValueTask EmbeddingAsync(HttpContext context, ThorEmbeddingInput in

await rateLimitModelService.CheckAsync(input!.Model, context);

var (token, user) = await tokenService.CheckTokenAsync(context);
TokenService.CheckModel(input.Model, token, context);

// 获取渠道 通过算法计算权重
var channel = CalculateWeight(await channelService.GetChannelsContainsModelAsync(input.Model), input.Model);
if (ModelManagerService.PromptRate.TryGetValue(input.Model, out var rate))
{
var (token, user) = await tokenService.CheckTokenAsync(context, rate);
TokenService.CheckModel(input.Model, token, context);

if (channel == null) throw new NotModelException(input.Model);
// 获取渠道 通过算法计算权重
var channel = CalculateWeight(await channelService.GetChannelsContainsModelAsync(input.Model),
input.Model);

// 获取渠道指定的实现类型的服务
var embeddingService = GetKeyedService<IThorTextEmbeddingService>(channel.Type);
if (channel == null) throw new NotModelException(input.Model);

if (embeddingService == null) throw new Exception($"并未实现:{channel.Type} 的服务");
// 获取渠道指定的实现类型的服务
var embeddingService = GetKeyedService<IThorTextEmbeddingService>(channel.Type);

var embeddingCreateRequest = new EmbeddingCreateRequest
{
Model = input.Model,
EncodingFormat = input.EncodingFormat
};
if (embeddingService == null) throw new Exception($"并未实现:{channel.Type} 的服务");

int requestToken;
if (input.Input is JsonElement str)
{
if (str.ValueKind == JsonValueKind.String)
var embeddingCreateRequest = new EmbeddingCreateRequest
{
embeddingCreateRequest.Input = str.ToString();
requestToken = TokenHelper.GetTotalTokens(str.ToString());
}
else if (str.ValueKind == JsonValueKind.Array)
Model = input.Model,
EncodingFormat = input.EncodingFormat
};

int requestToken;
if (input.Input is JsonElement str)
{
var inputString = str.EnumerateArray().Select(x => x.ToString()).ToArray();
embeddingCreateRequest.InputAsList = inputString.ToList();
requestToken = TokenHelper.GetTotalTokens(inputString);
if (str.ValueKind == JsonValueKind.String)
{
embeddingCreateRequest.Input = str.ToString();
requestToken = TokenHelper.GetTotalTokens(str.ToString());
}
else if (str.ValueKind == JsonValueKind.Array)
{
var inputString = str.EnumerateArray().Select(x => x.ToString()).ToArray();
embeddingCreateRequest.InputAsList = inputString.ToList();
requestToken = TokenHelper.GetTotalTokens(inputString);
}
else
{
throw new Exception("输入格式错误");
}
}
else
{
throw new Exception("输入格式错误");
}
}
else
{
throw new Exception("输入格式错误");
}

var sw = Stopwatch.StartNew();
var sw = Stopwatch.StartNew();

var stream = await embeddingService.EmbeddingAsync(embeddingCreateRequest, new ThorPlatformOptions
{
ApiKey = channel.Key,
Address = channel.Address,
Other = channel.Other
}, context.RequestAborted);
sw.Stop();

var stream = await embeddingService.EmbeddingAsync(embeddingCreateRequest, new ThorPlatformOptions
{
ApiKey = channel.Key,
Address = channel.Address,
Other = channel.Other
}, context.RequestAborted);
sw.Stop();

if (ModelManagerService.PromptRate.TryGetValue(input.Model, out var rate))
{
var quota = requestToken * rate.PromptRate;

var completionRatio = GetCompletionRatio(input.Model);
Expand All @@ -284,11 +288,10 @@ await loggerService.CreateConsumeAsync(string.Format(ConsumerTemplate, rate, com

await userService.ConsumeAsync(user!.Id, (long)quota, requestToken, token?.Key, channel.Id,
input.Model);
}

stream.ConvertEmbeddingData(input.EncodingFormat);
stream.ConvertEmbeddingData(input.EncodingFormat);

await context.Response.WriteAsJsonAsync(stream);
await context.Response.WriteAsJsonAsync(stream);
}
}
catch (RateLimitException)
{
Expand Down Expand Up @@ -321,8 +324,6 @@ public async ValueTask CompletionsAsync(HttpContext context, CompletionCreateReq

await rateLimitModelService.CheckAsync(input!.Model, context);

var (token, user) = await tokenService.CheckTokenAsync(context);
TokenService.CheckModel(input.Model, token, context);

// 获取渠道 通过算法计算权重
var channel = CalculateWeight(await channelService.GetChannelsContainsModelAsync(input.Model), input.Model);
Expand All @@ -335,6 +336,9 @@ public async ValueTask CompletionsAsync(HttpContext context, CompletionCreateReq

if (ModelManagerService.PromptRate.TryGetValue(input.Model, out var rate))
{
var (token, user) = await tokenService.CheckTokenAsync(context, rate);
TokenService.CheckModel(input.Model, token, context);

if (input.Stream == false)
{
var sw = Stopwatch.StartNew();
Expand Down Expand Up @@ -430,10 +434,6 @@ public async ValueTask ChatCompletionsAsync(HttpContext context, ThorChatComplet
var model = request.Model;
await rateLimitModelService.CheckAsync(model, context);

var (token, user) = await tokenService.CheckTokenAsync(context);

TokenService.CheckModel(request.Model, token, context);

// 获取渠道通过算法计算权重
var channel = CalculateWeight(await channelService.GetChannelsContainsModelAsync(model), model);

Expand All @@ -442,8 +442,6 @@ public async ValueTask ChatCompletionsAsync(HttpContext context, ThorChatComplet
throw new NotModelException(model);
}

// 记录请求模型 / 请求用户
logger.LogInformation("请求模型:{model} 请求用户:{user}", model, user?.UserName);

// 获取渠道指定的实现类型的服务
var chatCompletionsService = GetKeyedService<IThorChatCompletionsService>(channel.Type);
Expand All @@ -455,6 +453,13 @@ public async ValueTask ChatCompletionsAsync(HttpContext context, ThorChatComplet

if (ModelManagerService.PromptRate.TryGetValue(model, out var rate))
{
var (token, user) = await tokenService.CheckTokenAsync(context, rate);

TokenService.CheckModel(request.Model, token, context);

// 记录请求模型 / 请求用户
logger.LogInformation("请求模型:{model} 请求用户:{user}", model, user?.UserName);

int requestToken;
var responseToken = 0;

Expand Down Expand Up @@ -550,6 +555,7 @@ await userService.ConsumeAsync(user!.Id, (long)rate.PromptRate, requestToken, to
{
context.Response.StatusCode = 402;
}

await context.WriteErrorAsync(insufficientQuotaException.Message, "402");
}
catch (RateLimitException)
Expand Down Expand Up @@ -606,9 +612,6 @@ public async ValueTask RealtimeAsync(HttpContext context)

await rateLimitModelService.CheckAsync(model, context);

var (token, user) = await tokenService.CheckTokenAsync(context);

TokenService.CheckModel(model, token, context);

// 获取渠道通过算法计算权重
var channel = CalculateWeight(await channelService.GetChannelsContainsModelAsync(model), model);
Expand All @@ -618,8 +621,6 @@ public async ValueTask RealtimeAsync(HttpContext context)
throw new NotModelException(model);
}

// 记录请求模型 / 请求用户
logger.LogInformation("请求模型:{model} 请求用户:{user}", model, user?.UserName);

// 获取渠道指定的实现类型的服务
var realtimeService = GetKeyedService<IThorRealtimeService>(channel.Type);
Expand All @@ -631,6 +632,12 @@ public async ValueTask RealtimeAsync(HttpContext context)

if (ModelManagerService.PromptRate.TryGetValue(model, out var rate))
{
var (token, user) = await tokenService.CheckTokenAsync(context, rate);

TokenService.CheckModel(model, token, context);
// 记录请求模型 / 请求用户
logger.LogInformation("请求模型:{model} 请求用户:{user}", model, user?.UserName);

decimal requestToken = 0;
decimal audioRequestToken = 0;
decimal responseToken = 0;
Expand Down
25 changes: 15 additions & 10 deletions src/Thor.Service/Service/TokenService.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Diagnostics;
using System.Runtime.CompilerServices;
using Thor.Service.Domain.Core;
using Thor.Service.Extensions;
using Thor.Service.Options;

Expand Down Expand Up @@ -91,7 +92,7 @@ public static void CheckModel(string model, Token? token, HttpContext context)
{
if (token == null) return;

if (token.LimitModels.Count > 0 && !token.LimitModels.Contains(model))
if (token.LimitModels.Count(x => !string.IsNullOrEmpty(x)) > 0 && !token.LimitModels.Contains(model))
{
throw new Exception("当前 Token 无权访问该模型");
}
Expand All @@ -111,22 +112,25 @@ public static void CheckModel(string model, Token? token, HttpContext context)
/// 检验账号额度是否足够
/// </summary>
/// <param name="context"></param>
/// <param name="value"></param>
/// <returns></returns>
/// <exception cref="UnauthorizedAccessException"></exception>
/// <exception cref="InsufficientQuotaException"></exception>
public async ValueTask<(Token?, User)> CheckTokenAsync(HttpContext context)
public async ValueTask<(Token?, User)> CheckTokenAsync(HttpContext context, ModelManager value)
{
var key = context.Request.Headers.Authorization.ToString().Replace("Bearer ", "").Trim();

if (string.IsNullOrEmpty(key))
{
var protocol = context.Request.Headers.SecWebSocketProtocol.ToString().Split(",").Select(x=>x.Trim());

var apiKey = protocol.FirstOrDefault(x => x.StartsWith("openai-insecure-api-key.", StringComparison.OrdinalIgnoreCase))?.Replace("openai-insecure-api-key.","");
if(!string.IsNullOrEmpty(apiKey))
var protocol = context.Request.Headers.SecWebSocketProtocol.ToString().Split(",").Select(x => x.Trim());

var apiKey = protocol
.FirstOrDefault(x => x.StartsWith("openai-insecure-api-key.", StringComparison.OrdinalIgnoreCase))
?.Replace("openai-insecure-api-key.", "");
if (!string.IsNullOrEmpty(apiKey))
key = apiKey;
}

var requestQuota = SettingService.GetIntSetting(SettingExtensions.GeneralSetting.RequestQuota);

if (requestQuota <= 0) requestQuota = 5000;
Expand Down Expand Up @@ -200,9 +204,10 @@ public static void CheckModel(string model, Token? token, HttpContext context)
throw new UnauthorizedAccessException("账号已禁用");
}

// 判断额度是否足够
if (user.ResidualCredit >= requestQuota)
if ((value.QuotaType == ModelQuotaType.ByCount && user.ResidualCredit >= value.PromptRate) || user.ResidualCredit >= requestQuota)
{
return (token, user);
}

logger.LogWarning("用户额度不足");
context.Response.StatusCode = 402;
Expand Down Expand Up @@ -235,7 +240,7 @@ public static string ModelMap(string model)
return chatChannel.Model;
}
}

modelMap?.SetTag("Model", models.LastOrDefault()?.Model ?? model);

return models.LastOrDefault()?.Model ?? model;
Expand Down

0 comments on commit 2b2026c

Please sign in to comment.