From 6d95553a553a7c79cc0d74af4a10a6773da6e993 Mon Sep 17 00:00:00 2001
From: Marco Minerva
Date: Wed, 30 Oct 2024 17:16:58 +0100
Subject: [PATCH 1/2] Add token usage tracking

- MemoryAnswer.cs: Imported TokenUsage class and added TokenUsage property.
- SearchClient.cs: Refactored GenerateAnswer method, added prompt creation, token count logic, and updated RenderFactTemplate call.
- TokenUsage.cs: Created TokenUsage class to track token counts.
---
 service/Abstractions/Models/MemoryAnswer.cs |  9 +++++++
 service/Abstractions/Models/TokeUsage.cs    | 29 +++++++++++++++++++++
 service/Core/Search/SearchClient.cs         | 26 +++++++++++++-----
 3 files changed, 58 insertions(+), 6 deletions(-)
 create mode 100644 service/Abstractions/Models/TokeUsage.cs

diff --git a/service/Abstractions/Models/MemoryAnswer.cs b/service/Abstractions/Models/MemoryAnswer.cs
index 67e6df64b..9274b76ea 100644
--- a/service/Abstractions/Models/MemoryAnswer.cs
+++ b/service/Abstractions/Models/MemoryAnswer.cs
@@ -6,6 +6,7 @@
 using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
+using Microsoft.KernelMemory.Models;
 
 namespace Microsoft.KernelMemory;
 
@@ -41,6 +42,14 @@ public class MemoryAnswer
     [JsonPropertyOrder(10)]
     public string Result { get; set; } = string.Empty;
 
+    /// <summary>
+    /// The tokens used by the model to generate the answer.
+    /// </summary>
+    /// <remarks>Not all the models and text generators return token usage information.</remarks>
+    [JsonPropertyName("tokenUsage")]
+    [JsonPropertyOrder(11)]
+    public TokenUsage TokenUsage { get; set; } = new();
+
     /// <summary>
     /// List of the relevant sources used to produce the answer.
     /// Key = Document ID
diff --git a/service/Abstractions/Models/TokeUsage.cs b/service/Abstractions/Models/TokeUsage.cs
new file mode 100644
index 000000000..09bee9ecc
--- /dev/null
+++ b/service/Abstractions/Models/TokeUsage.cs
@@ -0,0 +1,29 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Text.Json.Serialization;
+
+namespace Microsoft.KernelMemory.Models;
+
+/// <summary>
+/// Represents the usage of tokens in a request and response cycle.
+/// </summary>
+public class TokenUsage
+{
+    /// <summary>
+    /// The number of tokens in the request message input, spanning all message content items.
+    /// </summary>
+    [JsonPropertyOrder(0)]
+    public int InputTokenCount { get; set; }
+
+    /// <summary>
+    /// The combined number of output tokens in the generated completion, as consumed by the model.
+    /// </summary>
+    [JsonPropertyOrder(1)]
+    public int OutputTokenCount { get; set; }
+
+    /// <summary>
+    /// The total number of combined input (prompt) and output (completion) tokens used.
+    /// </summary>
+    [JsonPropertyOrder(2)]
+    public int TotalTokenCount => this.InputTokenCount + this.OutputTokenCount;
+}
diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs
index ae37144db..1e1576a46 100644
--- a/service/Core/Search/SearchClient.cs
+++ b/service/Core/Search/SearchClient.cs
@@ -268,7 +268,7 @@ public async Task<MemoryAnswer> AskAsync(
             var fact = PromptUtils.RenderFactTemplate(
                 template: factTemplate,
                 factContent: partitionText,
-                source: (fileName == "content.url" ? webPageUrl : fileName),
+                source: fileName == "content.url" ? webPageUrl : fileName,
                 relevance: relevance.ToString("P1", CultureInfo.CurrentCulture),
                 recordId: memory.Id,
                 tags: memory.Tags,
@@ -336,11 +336,15 @@ public async Task<MemoryAnswer> AskAsync(
             return noAnswerFound;
         }
 
+        var (prompt, tokenCount) = this.CreatePrompt(question, facts.ToString(), context);
+        answer.TokenUsage.InputTokenCount = tokenCount;
+
         var text = new StringBuilder();
         var charsGenerated = 0;
+
         var watch = new Stopwatch();
         watch.Restart();
-        await foreach (var x in this.GenerateAnswer(question, facts.ToString(), context, cancellationToken).ConfigureAwait(false))
+        await foreach (var x in this.GenerateAnswer(prompt, context, cancellationToken).ConfigureAwait(false))
         {
             text.Append(x);
 
@@ -354,6 +358,8 @@
 
         watch.Stop();
         answer.Result = text.ToString();
+        answer.TokenUsage.OutputTokenCount = this._textGenerator.CountTokens(answer.Result);
+
         this._log.LogSensitive("Answer: {0}", answer.Result);
         answer.NoResult = ValueIsEquivalentTo(answer.Result, this._config.EmptyAnswer);
         if (answer.NoResult)
@@ -381,12 +387,9 @@ public async Task<MemoryAnswer> AskAsync(
         return answer;
     }
 
-    private IAsyncEnumerable<string> GenerateAnswer(string question, string facts, IContext? context, CancellationToken token)
+    private (string Text, int TokenCount) CreatePrompt(string question, string facts, IContext? context)
     {
         string prompt = context.GetCustomRagPromptOrDefault(this._answerPrompt);
-        int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens);
-        double temperature = context.GetCustomRagTemperatureOrDefault(this._config.Temperature);
-        double nucleusSampling = context.GetCustomRagNucleusSamplingOrDefault(this._config.TopP);
 
         prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase);
 
@@ -395,6 +398,17 @@ private IAsyncEnumerable<string> GenerateAnswer(string question, string facts, I
         prompt = prompt.Replace("{{$input}}", question, StringComparison.OrdinalIgnoreCase);
         prompt = prompt.Replace("{{$notFound}}", this._config.EmptyAnswer, StringComparison.OrdinalIgnoreCase);
 
+        var tokenCount = this._textGenerator.CountTokens(prompt);
+
+        return (prompt, tokenCount);
+    }
+
+    private IAsyncEnumerable<string> GenerateAnswer(string prompt, IContext? context, CancellationToken token)
+    {
+        int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens);
+        double temperature = context.GetCustomRagTemperatureOrDefault(this._config.Temperature);
+        double nucleusSampling = context.GetCustomRagNucleusSamplingOrDefault(this._config.TopP);
+
         var options = new TextGenerationOptions
         {
             MaxTokens = maxTokens,

From ee1269fde326e296f461423904e72e46bf8cd2ee Mon Sep 17 00:00:00 2001
From: Marco Minerva
Date: Thu, 31 Oct 2024 09:34:42 +0100
Subject: [PATCH 2/2] Small refactoring

---
 service/Core/Search/SearchClient.cs | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs
index 1e1576a46..b055733b5 100644
--- a/service/Core/Search/SearchClient.cs
+++ b/service/Core/Search/SearchClient.cs
@@ -336,8 +336,8 @@ public async Task<MemoryAnswer> AskAsync(
             return noAnswerFound;
         }
 
-        var (prompt, tokenCount) = this.CreatePrompt(question, facts.ToString(), context);
-        answer.TokenUsage.InputTokenCount = tokenCount;
+        var prompt = this.CreatePrompt(question, facts.ToString(), context);
+        answer.TokenUsage.InputTokenCount = this._textGenerator.CountTokens(prompt);
 
         var text = new StringBuilder();
         var charsGenerated = 0;
@@ -387,7 +387,7 @@ public async Task<MemoryAnswer> AskAsync(
         return answer;
     }
 
-    private (string Text, int TokenCount) CreatePrompt(string question, string facts, IContext? context)
+    private string CreatePrompt(string question, string facts, IContext? context)
     {
         string prompt = context.GetCustomRagPromptOrDefault(this._answerPrompt);
 
@@ -398,9 +398,7 @@ public async Task<MemoryAnswer> AskAsync(
         prompt = prompt.Replace("{{$input}}", question, StringComparison.OrdinalIgnoreCase);
         prompt = prompt.Replace("{{$notFound}}", this._config.EmptyAnswer, StringComparison.OrdinalIgnoreCase);
 
-        var tokenCount = this._textGenerator.CountTokens(prompt);
-
-        return (prompt, tokenCount);
+        return prompt;
     }
 
     private IAsyncEnumerable<string> GenerateAnswer(string prompt, IContext? context, CancellationToken token)
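
Taken together, the two patches let MemoryAnswer report how many tokens a RAG call consumed: SearchClient counts the tokens of the fully rendered prompt before calling the model (InputTokenCount), counts the tokens of the final generated text afterwards (OutputTokenCount), and TotalTokenCount is derived from the two. The standalone C# sketch below mirrors that flow outside of Kernel Memory; the whitespace-based CountTokens and the GenerateAnswerAsync stub are hypothetical stand-ins for ITextGenerator, and only the TokenUsage shape comes from the patches.

// Standalone sketch, not part of the patches above: CountTokens and
// GenerateAnswerAsync are hypothetical stand-ins for ITextGenerator.
using System;
using System.Threading.Tasks;

public class TokenUsage
{
    // Same shape as the model added in service/Abstractions/Models/TokeUsage.cs.
    public int InputTokenCount { get; set; }
    public int OutputTokenCount { get; set; }
    public int TotalTokenCount => this.InputTokenCount + this.OutputTokenCount;
}

public static class TokenUsageSketch
{
    // Stand-in tokenizer: the real code calls ITextGenerator.CountTokens.
    private static int CountTokens(string text) =>
        text.Split((char[]?)null, StringSplitOptions.RemoveEmptyEntries).Length;

    // Stand-in for the streaming text generation step.
    private static Task<string> GenerateAnswerAsync(string prompt) =>
        Task.FromResult("An answer assembled from the provided facts.");

    public static async Task Main()
    {
        // 1. Render the RAG prompt, as CreatePrompt does with {{$facts}}, {{$input}} and {{$notFound}}.
        string prompt = "Facts: ...\nQuestion: What is Kernel Memory?\nAnswer:";

        // 2. Count input tokens on the full prompt before generation.
        var usage = new TokenUsage { InputTokenCount = CountTokens(prompt) };

        // 3. Generate the answer, then count output tokens on the final text.
        string answer = await GenerateAnswerAsync(prompt);
        usage.OutputTokenCount = CountTokens(answer);

        Console.WriteLine($"input={usage.InputTokenCount} output={usage.OutputTokenCount} total={usage.TotalTokenCount}");
    }
}

Note that in the patches both counts are computed locally via this._textGenerator.CountTokens rather than read back from the model's own usage metadata, which is consistent with the MemoryAnswer remark that not all models and text generators return token usage information.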