diff --git a/src/Services/LLmModelService.cs b/src/Services/LLmModelService.cs
index 6582aaa..887e197 100644
--- a/src/Services/LLmModelService.cs
+++ b/src/Services/LLmModelService.cs
@@ -21,7 +21,6 @@ public class LLmModelService : ILLmModelService
     private readonly List<LLmModelSettings> _settings;
     private LLmModelSettings _usedset;
    private LLamaWeights _model;
-    private LLamaContext _context;
     private LLamaEmbedder? _embedder;
 
     // ID of the loaded model; -1 means no model is loaded
@@ -75,7 +74,6 @@ public void InitModelIndex()
 
         DisposeModel();
         _model = LLamaWeights.LoadFromFile(usedset.ModelParams);
-        _context = new LLamaContext(_model, usedset.ModelParams);
         if (usedset.ModelParams.Embeddings)
         {
             _embedder = new LLamaEmbedder(_model, usedset.ModelParams);
@@ -97,7 +95,7 @@ public async Task<ChatCompletionResponse> CreateChatCompletionAsync(ChatCompleti
         var genParams = GetInferenceParams(request);
         var chatHistory = GetChatHistory(request.messages);
         var lastMessage = chatHistory.Messages.LastOrDefault();
-
+        var context = new LLamaContext(_model, _usedset.ModelParams);
         // No messages
         if (lastMessage is null)
         {
@@ -135,7 +133,7 @@ public async Task<ChatCompletionResponse> CreateChatCompletionAsync(ChatCompleti
 
         // Remove the last message
         chatHistory.Messages.RemoveAt(chatHistory.Messages.Count - 1);
-        var session = GetChatSession(chatHistory);
+        var session = GetChatSession(chatHistory, context);
 
         var result = "";
         await foreach (var output in session.ChatAsync(lastMessage, genParams))
@@ -144,8 +142,10 @@ public async Task<ChatCompletionResponse> CreateChatCompletionAsync(ChatCompleti
             result += output;
         }
 
-        var tokenizedInput = _context.Tokenize(lastMessage.Content);
-        var tokenizedOutput = _context.Tokenize(result);
+        var prompt_tokens = context.Tokenize(lastMessage.Content).Length;
+        var completion_tokens = context.Tokenize(result).Length;
+
+        context.Dispose();
         return new ChatCompletionResponse
         {
             id = $"chatcmpl-{Guid.NewGuid():N}",
@@ -165,9 +165,9 @@ public async Task<ChatCompletionResponse> CreateChatCompletionAsync(ChatCompleti
             ],
             usage = new UsageInfo
             {
-                prompt_tokens = tokenizedInput.Length,
-                completion_tokens = tokenizedOutput.Length,
-                total_tokens = tokenizedInput.Length + tokenizedOutput.Length
+                prompt_tokens = prompt_tokens,
+                completion_tokens = completion_tokens,
+                total_tokens = prompt_tokens + completion_tokens
             }
         };
     }
@@ -182,6 +182,7 @@ public async IAsyncEnumerable<string> CreateChatCompletionStreamAsync(ChatComple
         var genParams = GetInferenceParams(request);
         var chatHistory = GetChatHistory(request.messages);
         var lastMessage = chatHistory.Messages.LastOrDefault();
+        var context = new LLamaContext(_model, _usedset.ModelParams);
         // No messages
         if (lastMessage is null)
         {
@@ -193,7 +194,7 @@ public async IAsyncEnumerable<string> CreateChatCompletionStreamAsync(ChatComple
 
         // Remove the last message
         chatHistory.Messages.RemoveAt(chatHistory.Messages.Count - 1);
-        var session = GetChatSession(chatHistory);
+        var session = GetChatSession(chatHistory, context);
 
         var id = $"chatcmpl-{Guid.NewGuid():N}";
         var created = DateTimeOffset.Now.ToUnixTimeSeconds();
@@ -246,7 +247,7 @@ public async IAsyncEnumerable<string> CreateChatCompletionStreamAsync(ChatComple
             yield return $"data: {chunk}\n\n";
         }
 
-        //session.Executor.Context.Dispose();
+        context.Dispose();
 
         // End
         chunk = JsonSerializer.Serialize(new ChatCompletionChunkResponse
@@ -273,13 +274,12 @@ public async IAsyncEnumerable<string> CreateChatCompletionStreamAsync(ChatComple
     /// Create and configure a chat session
     /// </summary>
     /// <param name="chatHistory">Chat history</param>
+    /// <param name="context"></param>
     /// <returns></returns>
-    private ChatSession GetChatSession(ChatHistory chatHistory)
+    private ChatSession GetChatSession(ChatHistory chatHistory, LLamaContext context)
     {
-        var executor = new InteractiveExecutor(_context);
+        var executor = new InteractiveExecutor(context);
         ChatSession session = new(executor, chatHistory);
-        // Clear the cache
-        session.Executor.Context.NativeHandle.KvCacheClear();
 
         // Set up the history transform and output transforms
         if (_usedset.WithTransform?.HistoryTransform != null)
@@ -415,13 +415,16 @@ public async Task<CompletionResponse> CreateCompletionAsync(CompletionRequest re
         var genParams = GetInferenceParams(request);
         var ex = new StatelessExecutor(_model, _usedset.ModelParams);
         var result = "";
+
+        var prompt_tokens = _model.Tokenize(request.prompt, true, false, _usedset.ModelParams.Encoding).Length;
+        var completion_tokens = 0;
         await foreach (var output in ex.InferAsync(request.prompt, genParams))
         {
             _logger.LogDebug("Message: {output}", output);
             result += output;
+            completion_tokens++;
         }
-        var tokenizedInput = _context.Tokenize(request.prompt);
-        var tokenizedOutput = _context.Tokenize(result);
+
         return new CompletionResponse
         {
             id = $"cmpl-{Guid.NewGuid():N}",
@@ -437,9 +440,9 @@ public async Task<CompletionResponse> CreateCompletionAsync(CompletionRequest re
             },
             usage = new UsageInfo
             {
-                prompt_tokens = tokenizedInput.Length,
-                completion_tokens = tokenizedOutput.Length,
-                total_tokens = tokenizedInput.Length + tokenizedOutput.Length
+                prompt_tokens = prompt_tokens,
+                completion_tokens = completion_tokens,
+                total_tokens = prompt_tokens + completion_tokens
             }
         };
     }
@@ -565,7 +568,6 @@ public void DisposeModel()
         if (GlobalSettings.IsModelLoaded)
         {
             _embedder?.Dispose();
-            _context.Dispose();
             _model.Dispose();
             GlobalSettings.IsModelLoaded = false;
             _loadModelIndex = -1;
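
Note on the change (an editorial sketch, not part of the diff): each request now creates its own LLamaContext, but both new call sites call context.Dispose() only on the happy path, so an exception thrown by ChatAsync or Tokenize would keep the native context alive until finalization. Below is a minimal variant of the non-streaming path with a try/finally guard. It assumes the members visible in the diff (_model, _usedset, GetChatSession) and LLamaSharp's ChatHistory.Message and InferenceParams types; RunChatWithContextAsync is a hypothetical helper for illustration, not code from this PR.

// Hypothetical member of LLmModelService (file-level usings: LLama, LLama.Common).
// Sketch of the per-request context pattern with a dispose guard.
private async Task<(string Result, int PromptTokens, int CompletionTokens)> RunChatWithContextAsync(
    ChatHistory chatHistory, ChatHistory.Message lastMessage, InferenceParams genParams)
{
    // One fresh context per request, as introduced by this diff.
    var context = new LLamaContext(_model, _usedset.ModelParams);
    try
    {
        var session = GetChatSession(chatHistory, context);
        var result = "";
        await foreach (var output in session.ChatAsync(lastMessage, genParams))
        {
            result += output;
        }

        var promptTokens = context.Tokenize(lastMessage.Content).Length;
        var completionTokens = context.Tokenize(result).Length;
        return (result, promptTokens, completionTokens);
    }
    finally
    {
        // Runs on success and on exceptions alike, releasing the native context.
        context.Dispose();
    }
}

The same guard would also fit the streaming method: C# allows yield return inside a try block that has only a finally, so the context is still disposed when a consumer abandons the async enumerator.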