diff --git a/service/Abstractions/Models/MemoryAnswer.cs b/service/Abstractions/Models/MemoryAnswer.cs
index 67e6df64b..9274b76ea 100644
--- a/service/Abstractions/Models/MemoryAnswer.cs
+++ b/service/Abstractions/Models/MemoryAnswer.cs
@@ -6,6 +6,7 @@
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
+using Microsoft.KernelMemory.Models;
namespace Microsoft.KernelMemory;
@@ -41,6 +42,14 @@ public class MemoryAnswer
[JsonPropertyOrder(10)]
public string Result { get; set; } = string.Empty;
+ ///
+ /// The tokens used by the model to generate the answer.
+ ///
+ /// Not all the models and text generators return token usage information.
+ [JsonPropertyName("tokenUsage")]
+ [JsonPropertyOrder(11)]
+ public TokenUsage TokenUsage { get; set; } = new();
+
/// <summary>
/// List of the relevant sources used to produce the answer.
/// Key = Document ID
diff --git a/service/Abstractions/Models/TokenUsage.cs b/service/Abstractions/Models/TokenUsage.cs
new file mode 100644
index 000000000..09bee9ecc
--- /dev/null
+++ b/service/Abstractions/Models/TokenUsage.cs
@@ -0,0 +1,29 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Text.Json.Serialization;
+
+namespace Microsoft.KernelMemory.Models;
+
+/// <summary>
+/// Represents the usage of tokens in a request and response cycle.
+/// </summary>
+public class TokenUsage
+{
+ /// <summary>
+ /// The number of tokens in the request message input, spanning all message content items.
+ /// </summary>
+ [JsonPropertyOrder(0)]
+ public int InputTokenCount { get; set; }
+
+ /// <summary>
+ /// The combined number of output tokens in the generated completion, as consumed by the model.
+ /// </summary>
+ [JsonPropertyOrder(1)]
+ public int OutputTokenCount { get; set; }
+
+ /// <summary>
+ /// The total number of combined input (prompt) and output (completion) tokens used.
+ /// </summary>
+ [JsonPropertyOrder(2)]
+ public int TotalTokenCount => this.InputTokenCount + this.OutputTokenCount;
+}
diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs
index ae37144db..b055733b5 100644
--- a/service/Core/Search/SearchClient.cs
+++ b/service/Core/Search/SearchClient.cs
@@ -268,7 +268,7 @@ public async Task AskAsync(
var fact = PromptUtils.RenderFactTemplate(
template: factTemplate,
factContent: partitionText,
- source: (fileName == "content.url" ? webPageUrl : fileName),
+ source: fileName == "content.url" ? webPageUrl : fileName,
relevance: relevance.ToString("P1", CultureInfo.CurrentCulture),
recordId: memory.Id,
tags: memory.Tags,
@@ -336,11 +336,15 @@ public async Task AskAsync(
return noAnswerFound;
}
+ var prompt = this.CreatePrompt(question, facts.ToString(), context);
+ answer.TokenUsage.InputTokenCount = this._textGenerator.CountTokens(prompt);
+
var text = new StringBuilder();
var charsGenerated = 0;
+
var watch = new Stopwatch();
watch.Restart();
- await foreach (var x in this.GenerateAnswer(question, facts.ToString(), context, cancellationToken).ConfigureAwait(false))
+ await foreach (var x in this.GenerateAnswer(prompt, context, cancellationToken).ConfigureAwait(false))
{
text.Append(x);
@@ -354,6 +358,8 @@ public async Task AskAsync(
watch.Stop();
answer.Result = text.ToString();
+ answer.TokenUsage.OutputTokenCount = this._textGenerator.CountTokens(answer.Result);
+
this._log.LogSensitive("Answer: {0}", answer.Result);
answer.NoResult = ValueIsEquivalentTo(answer.Result, this._config.EmptyAnswer);
if (answer.NoResult)
@@ -381,12 +387,9 @@ public async Task AskAsync(
return answer;
}
- private IAsyncEnumerable GenerateAnswer(string question, string facts, IContext? context, CancellationToken token)
+ private string CreatePrompt(string question, string facts, IContext? context)
{
string prompt = context.GetCustomRagPromptOrDefault(this._answerPrompt);
- int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens);
- double temperature = context.GetCustomRagTemperatureOrDefault(this._config.Temperature);
- double nucleusSampling = context.GetCustomRagNucleusSamplingOrDefault(this._config.TopP);
prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase);
@@ -395,6 +398,15 @@ private IAsyncEnumerable GenerateAnswer(string question, string facts, I
prompt = prompt.Replace("{{$input}}", question, StringComparison.OrdinalIgnoreCase);
prompt = prompt.Replace("{{$notFound}}", this._config.EmptyAnswer, StringComparison.OrdinalIgnoreCase);
+ return prompt;
+ }
+
+ private IAsyncEnumerable GenerateAnswer(string prompt, IContext? context, CancellationToken token)
+ {
+ int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens);
+ double temperature = context.GetCustomRagTemperatureOrDefault(this._config.Temperature);
+ double nucleusSampling = context.GetCustomRagNucleusSamplingOrDefault(this._config.TopP);
+
var options = new TextGenerationOptions
{
MaxTokens = maxTokens,