Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add token usage tracking #872

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions service/Abstractions/Models/MemoryAnswer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using Microsoft.KernelMemory.Models;

namespace Microsoft.KernelMemory;

Expand Down Expand Up @@ -41,6 +42,14 @@ public class MemoryAnswer
[JsonPropertyOrder(10)]
public string Result { get; set; } = string.Empty;

/// <summary>
/// The tokens used by the model to generate the answer: input (prompt)
/// tokens, output (completion) tokens, and their total.
/// </summary>
/// <remarks>
/// Not all models and text generators return token usage information.
/// NOTE(review): counts appear to be computed locally via the text
/// generator's tokenizer (see SearchClient.CountTokens usage) rather than
/// reported by the model service — confirm; fields default to zero.
/// </remarks>
[JsonPropertyName("tokenUsage")]
[JsonPropertyOrder(11)]
public TokenUsage TokenUsage { get; set; } = new();

/// <summary>
/// List of the relevant sources used to produce the answer.
/// Key = Document ID
Expand Down
29 changes: 29 additions & 0 deletions service/Abstractions/Models/TokenUsage.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json.Serialization;

namespace Microsoft.KernelMemory.Models;

/// <summary>
/// Token accounting for a single request/response cycle: how many tokens
/// went into the prompt and how many came back in the completion.
/// </summary>
public class TokenUsage
{
    /// <summary>
    /// Number of tokens contained in the request (prompt) input,
    /// spanning all message content items.
    /// </summary>
    [JsonPropertyOrder(0)]
    public int InputTokenCount { get; set; }

    /// <summary>
    /// Number of output tokens produced by the model for the
    /// generated completion, combined across all content items.
    /// </summary>
    [JsonPropertyOrder(1)]
    public int OutputTokenCount { get; set; }

    /// <summary>
    /// Sum of input (prompt) and output (completion) tokens used.
    /// Derived from the two counters above; not independently settable.
    /// </summary>
    [JsonPropertyOrder(2)]
    public int TotalTokenCount
    {
        get => this.InputTokenCount + this.OutputTokenCount;
    }
}
24 changes: 18 additions & 6 deletions service/Core/Search/SearchClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ public async Task<MemoryAnswer> AskAsync(
var fact = PromptUtils.RenderFactTemplate(
template: factTemplate,
factContent: partitionText,
source: (fileName == "content.url" ? webPageUrl : fileName),
source: fileName == "content.url" ? webPageUrl : fileName,
relevance: relevance.ToString("P1", CultureInfo.CurrentCulture),
recordId: memory.Id,
tags: memory.Tags,
Expand Down Expand Up @@ -336,11 +336,15 @@ public async Task<MemoryAnswer> AskAsync(
return noAnswerFound;
}

var prompt = this.CreatePrompt(question, facts.ToString(), context);
answer.TokenUsage.InputTokenCount = this._textGenerator.CountTokens(prompt);

var text = new StringBuilder();
var charsGenerated = 0;

var watch = new Stopwatch();
watch.Restart();
await foreach (var x in this.GenerateAnswer(question, facts.ToString(), context, cancellationToken).ConfigureAwait(false))
await foreach (var x in this.GenerateAnswer(prompt, context, cancellationToken).ConfigureAwait(false))
{
text.Append(x);

Expand All @@ -354,6 +358,8 @@ public async Task<MemoryAnswer> AskAsync(
watch.Stop();

answer.Result = text.ToString();
answer.TokenUsage.OutputTokenCount = this._textGenerator.CountTokens(answer.Result);

this._log.LogSensitive("Answer: {0}", answer.Result);
answer.NoResult = ValueIsEquivalentTo(answer.Result, this._config.EmptyAnswer);
if (answer.NoResult)
Expand Down Expand Up @@ -381,12 +387,9 @@ public async Task<MemoryAnswer> AskAsync(
return answer;
}

private IAsyncEnumerable<string> GenerateAnswer(string question, string facts, IContext? context, CancellationToken token)
private string CreatePrompt(string question, string facts, IContext? context)
{
string prompt = context.GetCustomRagPromptOrDefault(this._answerPrompt);
int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens);
double temperature = context.GetCustomRagTemperatureOrDefault(this._config.Temperature);
double nucleusSampling = context.GetCustomRagNucleusSamplingOrDefault(this._config.TopP);

prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase);

Expand All @@ -395,6 +398,15 @@ private IAsyncEnumerable<string> GenerateAnswer(string question, string facts, I
prompt = prompt.Replace("{{$input}}", question, StringComparison.OrdinalIgnoreCase);
prompt = prompt.Replace("{{$notFound}}", this._config.EmptyAnswer, StringComparison.OrdinalIgnoreCase);

return prompt;
}

private IAsyncEnumerable<string> GenerateAnswer(string prompt, IContext? context, CancellationToken token)
{
int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens);
double temperature = context.GetCustomRagTemperatureOrDefault(this._config.Temperature);
double nucleusSampling = context.GetCustomRagNucleusSamplingOrDefault(this._config.TopP);

var options = new TextGenerationOptions
{
MaxTokens = maxTokens,
Expand Down
Loading