diff --git a/.github/_typos.toml b/.github/_typos.toml
index 5fa7634f9..000d278e9 100644
--- a/.github/_typos.toml
+++ b/.github/_typos.toml
@@ -15,6 +15,7 @@ extend-exclude = [
"encoder.json",
"appsettings.development.json",
"appsettings.Development.json",
+ "appsettings.*.json.*",
"AzureAISearchFilteringTest.cs",
"KernelMemory.sln.DotSettings"
]
diff --git a/KernelMemory.sln.DotSettings b/KernelMemory.sln.DotSettings
index 4d9e859e9..56b15a2c0 100644
--- a/KernelMemory.sln.DotSettings
+++ b/KernelMemory.sln.DotSettings
@@ -120,6 +120,7 @@
SHA
SK
SKHTTP
+ SSE
SSL
TTL
UI
diff --git a/clients/dotnet/SemanticKernelPlugin/MemoryPlugin.cs b/clients/dotnet/SemanticKernelPlugin/MemoryPlugin.cs
index 4446459ab..a4db5b6a0 100644
--- a/clients/dotnet/SemanticKernelPlugin/MemoryPlugin.cs
+++ b/clients/dotnet/SemanticKernelPlugin/MemoryPlugin.cs
@@ -404,6 +404,7 @@ public async Task AskAsync(
MemoryAnswer answer = await this._memory.AskAsync(
question: question,
index: index ?? this._defaultIndex,
+ options: new SearchOptions { Stream = false },
filter: TagsToMemoryFilter(tags ?? this._defaultRetrievalTags),
minRelevance: minRelevance,
cancellationToken: cancellationToken).ConfigureAwait(false);
diff --git a/clients/dotnet/WebClient/MemoryWebClient.cs b/clients/dotnet/WebClient/MemoryWebClient.cs
index a89b25a55..f767881fd 100644
--- a/clients/dotnet/WebClient/MemoryWebClient.cs
+++ b/clients/dotnet/WebClient/MemoryWebClient.cs
@@ -7,11 +7,13 @@
using System.Linq;
using System.Net;
using System.Net.Http;
+using System.Runtime.CompilerServices;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.KernelMemory.Context;
+using Microsoft.KernelMemory.HTTP;
using Microsoft.KernelMemory.Internals;
namespace Microsoft.KernelMemory;
@@ -337,28 +339,30 @@ public async Task SearchAsync(
}
///
- public async Task AskAsync(
+ public async IAsyncEnumerable AskStreamingAsync(
string question,
string? index = null,
MemoryFilter? filter = null,
ICollection? filters = null,
double minRelevance = 0,
+ SearchOptions? options = null,
IContext? context = null,
- CancellationToken cancellationToken = default)
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (filter != null)
{
- if (filters == null) { filters = []; }
-
+ filters ??= [];
filters.Add(filter);
}
+ var useStreaming = options?.Stream ?? false;
MemoryQuery request = new()
{
Index = index,
Question = question,
Filters = (filters is { Count: > 0 }) ? filters.ToList() : [],
MinRelevance = minRelevance,
+ Stream = useStreaming,
ContextArguments = (context?.Arguments ?? new Dictionary()).ToDictionary(),
};
using StringContent content = new(JsonSerializer.Serialize(request), Encoding.UTF8, "application/json");
@@ -367,8 +371,20 @@ public async Task AskAsync(
HttpResponseMessage response = await this._client.PostAsync(url, content, cancellationToken).ConfigureAwait(false);
response.EnsureSuccessStatusCode();
- var json = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
- return JsonSerializer.Deserialize(json, s_caseInsensitiveJsonOptions) ?? new MemoryAnswer();
+ if (useStreaming)
+ {
+ Stream stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
+ IAsyncEnumerable answers = SSE.ParseStreamAsync(stream, cancellationToken);
+ await foreach (MemoryAnswer answer in answers.ConfigureAwait(false))
+ {
+ yield return answer;
+ }
+ }
+ else
+ {
+ var json = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);
+ yield return JsonSerializer.Deserialize(json, s_caseInsensitiveJsonOptions) ?? new MemoryAnswer();
+ }
}
#region private
diff --git a/examples/001-dotnet-WebClient/Program.cs b/examples/001-dotnet-WebClient/Program.cs
index 0eec905c4..f1e6d4db4 100644
--- a/examples/001-dotnet-WebClient/Program.cs
+++ b/examples/001-dotnet-WebClient/Program.cs
@@ -15,7 +15,7 @@
* without extracting memories. */
internal static class Program
{
- private static MemoryWebClient? s_memory;
+ private static MemoryWebClient s_memory = null!;
private static readonly List s_toDelete = [];
// Change this to True and configure Azure Document Intelligence to test OCR and support for images
@@ -55,8 +55,8 @@ public static async Task Main()
// === RETRIEVAL =========
// =======================
- await AskSimpleQuestion();
- await AskSimpleQuestionAndShowSources();
+ await AskSimpleQuestionStreamingTheAnswer();
+ await AskSimpleQuestionStreamingAndShowSources();
await AskQuestionAboutImageContent();
await AskQuestionUsingFilter();
await AskQuestionsFilteringByUser();
@@ -249,16 +249,25 @@ private static async Task StoreJson()
// =======================
// Question without filters
- private static async Task AskSimpleQuestion()
+ private static async Task AskSimpleQuestionStreamingTheAnswer()
{
var question = "What's E = m*c^2?";
Console.WriteLine($"Question: {question}");
Console.WriteLine($"Expected result: formula explanation using the information loaded");
- var answer = await s_memory.AskAsync(question, minRelevance: 0.6);
- Console.WriteLine($"\nAnswer: {answer.Result}");
+ Console.Write("\nAnswer: ");
+ var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.6,
+ options: new SearchOptions { Stream = true });
- Console.WriteLine("\n====================================\n");
+ await foreach (var answer in answerStream)
+ {
+ // Print the token received from the LLM
+ Console.Write(answer.Result);
+ // Slow down the stream for demo purposes
+ await Task.Delay(25);
+ }
+
+ Console.WriteLine("\n\n====================================\n");
/* OUTPUT
@@ -275,17 +284,32 @@ due to the speed of light being a very large number when squared. This concept i
}
// Another question without filters and show sources
- private static async Task AskSimpleQuestionAndShowSources()
+ private static async Task AskSimpleQuestionStreamingAndShowSources()
{
var question = "What's Kernel Memory?";
Console.WriteLine($"Question: {question}");
Console.WriteLine($"Expected result: it should explain what KM project is (not generic kernel memory)");
- var answer = await s_memory.AskAsync(question, minRelevance: 0.5);
- Console.WriteLine($"\nAnswer: {answer.Result}\n\n Sources:\n");
+ Console.Write("\nAnswer: ");
+ var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.5,
+ options: new SearchOptions { Stream = true });
+
+ List sources = [];
+ await foreach (var answer in answerStream)
+ {
+ // Print the token received from the LLM
+ Console.Write(answer.Result);
+
+ // Collect sources
+ sources.AddRange(answer.RelevantSources);
+
+ // Slow down the stream for demo purposes
+ await Task.Delay(5);
+ }
// Show sources / citations
- foreach (var x in answer.RelevantSources)
+ Console.WriteLine("\n\nSources:\n");
+ foreach (var x in sources)
{
Console.WriteLine(x.SourceUrl != null
? $" - {x.SourceUrl} [{x.Partitions.First().LastUpdate:D}]"
diff --git a/examples/002-dotnet-Serverless/Program.cs b/examples/002-dotnet-Serverless/Program.cs
index ee5f680d6..26a9fb5c8 100644
--- a/examples/002-dotnet-Serverless/Program.cs
+++ b/examples/002-dotnet-Serverless/Program.cs
@@ -13,7 +13,7 @@
#pragma warning disable CS8602 // by design
public static class Program
{
- private static MemoryServerless? s_memory;
+ private static MemoryServerless s_memory = null!;
private static readonly List s_toDelete = [];
// Remember to configure Azure Document Intelligence to test OCR and support for images
@@ -107,8 +107,8 @@ public static async Task Main()
// === RETRIEVAL =========
// =======================
- await AskSimpleQuestion();
- await AskSimpleQuestionAndShowSources();
+ await AskSimpleQuestionStreamingTheAnswer();
+ await AskSimpleQuestionStreamingAndShowSources();
await AskQuestionAboutImageContent();
await AskQuestionUsingFilter();
await AskQuestionsFilteringByUser();
@@ -303,16 +303,25 @@ private static async Task StoreJson()
// =======================
// Question without filters
- private static async Task AskSimpleQuestion()
+ private static async Task AskSimpleQuestionStreamingTheAnswer()
{
var question = "What's E = m*c^2?";
Console.WriteLine($"Question: {question}");
Console.WriteLine($"Expected result: formula explanation using the information loaded");
- var answer = await s_memory.AskAsync(question, minRelevance: 0.6);
- Console.WriteLine($"\nAnswer: {answer.Result}");
+ Console.Write("\nAnswer: ");
+ var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.6,
+ options: new SearchOptions { Stream = true });
- Console.WriteLine("\n====================================\n");
+ await foreach (var answer in answerStream)
+ {
+ // Print the token received from the LLM
+ Console.Write(answer.Result);
+ // Slow down the stream for demo purposes
+ await Task.Delay(25);
+ }
+
+ Console.WriteLine("\n\n====================================\n");
/* OUTPUT
@@ -329,17 +338,32 @@ due to the speed of light being a very large number when squared. This concept i
}
// Another question without filters and show sources
- private static async Task AskSimpleQuestionAndShowSources()
+ private static async Task AskSimpleQuestionStreamingAndShowSources()
{
var question = "What's Kernel Memory?";
Console.WriteLine($"Question: {question}");
Console.WriteLine($"Expected result: it should explain what KM project is (not generic kernel memory)");
- var answer = await s_memory.AskAsync(question, minRelevance: 0.5);
- Console.WriteLine($"\nAnswer: {answer.Result}\n\n Sources:\n");
+ Console.Write("\nAnswer: ");
+ var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.5,
+ options: new SearchOptions { Stream = true });
+
+ List sources = [];
+ await foreach (var answer in answerStream)
+ {
+ // Print the token received from the LLM
+ Console.Write(answer.Result);
+
+ // Collect sources
+ sources.AddRange(answer.RelevantSources);
+
+ // Slow down the stream for demo purposes
+ await Task.Delay(5);
+ }
// Show sources / citations
- foreach (var x in answer.RelevantSources)
+ Console.WriteLine("\n\nSources:\n");
+ foreach (var x in sources)
{
Console.WriteLine(x.SourceUrl != null
? $" - {x.SourceUrl} [{x.Partitions.First().LastUpdate:D}]"
diff --git a/extensions/AzureAISearch/AzureAISearch.TestApplication/Program.cs b/extensions/AzureAISearch/AzureAISearch.TestApplication/Program.cs
index 93618ad48..3aee0565b 100644
--- a/extensions/AzureAISearch/AzureAISearch.TestApplication/Program.cs
+++ b/extensions/AzureAISearch/AzureAISearch.TestApplication/Program.cs
@@ -8,6 +8,7 @@
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.MemoryDb.AzureAISearch;
using Microsoft.KernelMemory.MemoryStorage;
+using AISearchOptions = Azure.Search.Documents.SearchOptions;
namespace Microsoft.AzureAISearch.TestApplication;
@@ -246,7 +247,7 @@ private static async Task> SearchByFieldValueAsync(
fieldValue1 = fieldValue1.Replace("'", "''", StringComparison.Ordinal);
fieldValue2 = fieldValue2.Replace("'", "''", StringComparison.Ordinal);
- SearchOptions options = new()
+ AISearchOptions options = new()
{
Filter = fieldIsCollection
? $"{fieldName}/any(s: s eq '{fieldValue1}') and {fieldName}/any(s: s eq '{fieldValue2}')"
diff --git a/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs b/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs
index fb838e862..b8ad546c6 100644
--- a/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs
+++ b/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs
@@ -19,6 +19,7 @@
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.DocumentStorage;
using Microsoft.KernelMemory.MemoryStorage;
+using AISearchOptions = Azure.Search.Documents.SearchOptions;
namespace Microsoft.KernelMemory.MemoryDb.AzureAISearch;
@@ -184,7 +185,7 @@ await client.IndexDocumentsAsync(
Exhaustive = false
};
- SearchOptions options = new()
+ AISearchOptions options = new()
{
VectorSearch = new()
{
@@ -246,7 +247,7 @@ public async IAsyncEnumerable GetListAsync(
{
var client = this.GetSearchClient(index);
- SearchOptions options = this.PrepareSearchOptions(null, withEmbeddings, filters, limit);
+ AISearchOptions options = this.PrepareSearchOptions(null, withEmbeddings, filters, limit);
Response>? searchResult = null;
try
@@ -596,13 +597,13 @@ at Azure.Search.Documents.SearchClient.SearchInternal[T](SearchOptions options,
return indexSchema;
}
- private SearchOptions PrepareSearchOptions(
- SearchOptions? options,
+ private AISearchOptions PrepareSearchOptions(
+ AISearchOptions? options,
bool withEmbeddings,
ICollection? filters = null,
int limit = 1)
{
- options ??= new SearchOptions();
+ options ??= new AISearchOptions();
// Define which fields to fetch
options.Select.Add(AzureAISearchMemoryRecord.IdField);
diff --git a/service/Abstractions/Abstractions.csproj b/service/Abstractions/Abstractions.csproj
index 6fab4f39c..caac4445d 100644
--- a/service/Abstractions/Abstractions.csproj
+++ b/service/Abstractions/Abstractions.csproj
@@ -4,7 +4,7 @@
net8.0
Microsoft.KernelMemory.Abstractions
Microsoft.KernelMemory
- $(NoWarn);KMEXP00;CA1711;CA1724;CS1574;SKEXP0001;
+ $(NoWarn);KMEXP00;SKEXP0001;CA1711;CA1724;CS1574;CA1812;
@@ -13,6 +13,7 @@
+
diff --git a/service/Abstractions/HTTP/SSE.cs b/service/Abstractions/HTTP/SSE.cs
new file mode 100644
index 000000000..e6d9254fd
--- /dev/null
+++ b/service/Abstractions/HTTP/SSE.cs
@@ -0,0 +1,68 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Text.Json;
+using System.Threading;
+
+namespace Microsoft.KernelMemory.HTTP;
+
+// See https://developer.mozilla.org/docs/Web/API/Server-sent_events/Using_server-sent_events
+public static class SSE
+{
+ public const string DataPrefix = "data: ";
+ public const string LastToken = "[DONE]";
+ public const string DoneMessage = $"{DataPrefix}{LastToken}";
+
+ public async static IAsyncEnumerable ParseStreamAsync(
+ Stream stream, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ {
+ using var reader = new StreamReader(stream, Encoding.UTF8);
+ StringBuilder buffer = new();
+
+ while (await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false) is { } line)
+ {
+ if (string.IsNullOrWhiteSpace(line)) // \n\n detected => Message delimiter
+ {
+ if (buffer.Length == 0) { continue; }
+
+ string message = buffer.ToString();
+ buffer.Clear();
+ if (message.Trim() == DoneMessage) { yield break; }
+
+ var memoryAnswer = ParseMessage(message);
+ if (memoryAnswer != null) { yield return memoryAnswer; }
+ }
+ else
+ {
+ buffer.AppendLine(line);
+ }
+ }
+
+ // Process any remaining text as the last message
+ if (buffer.Length > 0)
+ {
+ string message = buffer.ToString();
+ if (message.Trim() == DoneMessage) { yield break; }
+
+ var memoryAnswer = ParseMessage(message);
+ if (memoryAnswer != null) { yield return memoryAnswer; }
+ }
+ }
+
+ public static T? ParseMessage(string? message)
+ {
+ if (string.IsNullOrWhiteSpace(message)) { return default; }
+
+ string json = string.Join("",
+ message.Split('\n', StringSplitOptions.RemoveEmptyEntries)
+ .Where(line => line.StartsWith(DataPrefix, StringComparison.OrdinalIgnoreCase))
+ .Select(line => line[DataPrefix.Length..]));
+
+ return JsonSerializer.Deserialize(json);
+ }
+}
diff --git a/service/Abstractions/IKernelMemory.cs b/service/Abstractions/IKernelMemory.cs
index 89dc57009..f1152ff6a 100644
--- a/service/Abstractions/IKernelMemory.cs
+++ b/service/Abstractions/IKernelMemory.cs
@@ -211,21 +211,28 @@ public Task SearchAsync(
///
/// Search the given index for an answer to the given query.
+ ///
+ /// Use this method to work with IAsyncEnumerable and optionally stream the output.
+ /// - Note: you must set options.Stream = true to enable token streaming.
+ ///
+ /// Use the .AskAsync() extension method to receive the complete answer without streaming.
///
/// Question to answer
/// Optional index name
/// Filter to match
/// Filters to match (using inclusive OR logic). If 'filter' is provided too, the value is merged into this list.
/// Minimum Cosine Similarity required
+ /// Options for the request, such as whether to stream results
/// Unstructured data supporting custom business logic in the current request.
/// Async task cancellation token
/// Answer to the query, if possible
- public Task AskAsync(
+ public IAsyncEnumerable AskStreamingAsync(
string question,
string? index = null,
MemoryFilter? filter = null,
ICollection? filters = null,
double minRelevance = 0,
+ SearchOptions? options = null,
IContext? context = null,
CancellationToken cancellationToken = default);
}
diff --git a/service/Abstractions/KernelMemoryExtensions.cs b/service/Abstractions/KernelMemoryExtensions.cs
index 62443a890..0e5dfb53a 100644
--- a/service/Abstractions/KernelMemoryExtensions.cs
+++ b/service/Abstractions/KernelMemoryExtensions.cs
@@ -1,8 +1,10 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
+using System.Linq;
using System.Threading;
using System.Threading.Tasks;
+using Microsoft.KernelMemory.Context;
namespace Microsoft.KernelMemory;
@@ -11,6 +13,47 @@ namespace Microsoft.KernelMemory;
///
public static class KernelMemoryExtensions
{
+ ///
+ /// Search the given index for an answer to the given query
+ /// and return it without streaming the content.
+ ///
+ /// Memory instance
+ /// Question to answer
+ /// Optional index name
+ /// Filter to match
+ /// Filters to match (using inclusive OR logic). If 'filter' is provided too, the value is merged into this list.
+ /// Minimum Cosine Similarity required
+ /// Options for the request. The Stream setting is ignored: this method always disables streaming.
+ /// Unstructured data supporting custom business logic in the current request.
+ /// Async task cancellation token
+ /// Answer to the query, if possible
+ public static async Task AskAsync(
+ this IKernelMemory memory,
+ string question,
+ string? index = null,
+ MemoryFilter? filter = null,
+ ICollection? filters = null,
+ double minRelevance = 0,
+ SearchOptions? options = null,
+ IContext? context = null,
+ CancellationToken cancellationToken = default)
+ {
+ var optionsOverride = options.Clone() ?? new SearchOptions();
+ optionsOverride.Stream = false;
+
+ return await memory.AskStreamingAsync(
+ question: question,
+ index: index,
+ filter: filter,
+ filters: filters,
+ minRelevance: minRelevance,
+ options: optionsOverride,
+ context: context,
+ cancellationToken)
+ .FirstAsync(cancellationToken: cancellationToken)
+ .ConfigureAwait(false);
+ }
+
///
/// Return a list of synthetic memories of the specified type
///
diff --git a/service/Abstractions/Models/MemoryAnswer.cs b/service/Abstractions/Models/MemoryAnswer.cs
index eeb9d4d03..c78e695c6 100644
--- a/service/Abstractions/Models/MemoryAnswer.cs
+++ b/service/Abstractions/Models/MemoryAnswer.cs
@@ -11,15 +11,20 @@ namespace Microsoft.KernelMemory;
public class MemoryAnswer
{
- private static readonly JsonSerializerOptions s_indentedJsonOptions = new() { WriteIndented = true };
- private static readonly JsonSerializerOptions s_notIndentedJsonOptions = new() { WriteIndented = false };
- private static readonly JsonSerializerOptions s_caseInsensitiveJsonOptions = new() { PropertyNameCaseInsensitive = true };
+ ///
+ /// Used only when streaming. How to handle the current record.
+ ///
+ [JsonPropertyName("streamState")]
+ [JsonPropertyOrder(0)]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
+ public StreamStates? StreamState { get; set; } = null;
///
/// Client question.
///
[JsonPropertyName("question")]
[JsonPropertyOrder(1)]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public string Question { get; set; } = string.Empty;
[JsonPropertyName("noResult")]
@@ -48,23 +53,31 @@ public class MemoryAnswer
///
[JsonPropertyName("relevantSources")]
[JsonPropertyOrder(20)]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public List RelevantSources { get; set; } = [];
///
/// Serialize using .NET JSON serializer, e.g. to avoid ambiguity
/// with other serializers and other options
///
- /// Whether to keep the JSON readable, e.g. for debugging and views
+ /// Whether to reduce the payload size for SSE by omitting an empty question and empty sources from 'append' records
/// JSON serialization
- public string ToJson(bool indented = false)
+ public string ToJson(bool optimizeForStream)
{
- return JsonSerializer.Serialize(this, indented ? s_indentedJsonOptions : s_notIndentedJsonOptions);
- }
+ if (!optimizeForStream || this.StreamState != StreamStates.Append)
+ {
+ return JsonSerializer.Serialize(this);
+ }
- public MemoryAnswer FromJson(string json)
- {
- return JsonSerializer.Deserialize(json, s_caseInsensitiveJsonOptions)
- ?? new MemoryAnswer();
+ MemoryAnswer clone = JsonSerializer.Deserialize(JsonSerializer.Serialize(this))!;
+
+#pragma warning disable CA1820
+ if (clone.Question == string.Empty) { clone.Question = null!; }
+#pragma warning restore CA1820
+
+ if (clone.RelevantSources.Count == 0) { clone.RelevantSources = null!; }
+
+ return JsonSerializer.Serialize(clone);
}
public override string ToString()
@@ -72,7 +85,7 @@ public override string ToString()
var result = new StringBuilder();
result.AppendLine(this.Result);
- if (!this.NoResult)
+ if (!this.NoResult && this.RelevantSources is { Count: > 0 })
{
var sources = new Dictionary();
foreach (var x in this.RelevantSources)
diff --git a/service/Abstractions/Models/MemoryQuery.cs b/service/Abstractions/Models/MemoryQuery.cs
index f6e8dfae4..3676b945a 100644
--- a/service/Abstractions/Models/MemoryQuery.cs
+++ b/service/Abstractions/Models/MemoryQuery.cs
@@ -23,6 +23,10 @@ public class MemoryQuery
[JsonPropertyOrder(2)]
public double MinRelevance { get; set; } = 0;
+ [JsonPropertyName("stream")]
+ [JsonPropertyOrder(3)]
+ public bool Stream { get; set; } = false;
+
[JsonPropertyName("args")]
[JsonPropertyOrder(100)]
public Dictionary ContextArguments { get; set; } = [];
diff --git a/service/Abstractions/Models/SearchResult.cs b/service/Abstractions/Models/SearchResult.cs
index 019720c22..4e2e35128 100644
--- a/service/Abstractions/Models/SearchResult.cs
+++ b/service/Abstractions/Models/SearchResult.cs
@@ -26,10 +26,7 @@ public class SearchResult
[JsonPropertyOrder(2)]
public bool NoResult
{
- get
- {
- return this.Results.Count == 0;
- }
+ get => this.Results == null || this.Results.Count == 0;
private set { }
}
@@ -40,6 +37,7 @@ private set { }
///
[JsonPropertyName("results")]
[JsonPropertyOrder(3)]
+ [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)]
public List Results { get; set; } = [];
///
diff --git a/service/Abstractions/Models/StreamStates.cs b/service/Abstractions/Models/StreamStates.cs
new file mode 100644
index 000000000..041187dd7
--- /dev/null
+++ b/service/Abstractions/Models/StreamStates.cs
@@ -0,0 +1,58 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System;
+using System.Text.Json;
+using System.Text.Json.Serialization;
+
+namespace Microsoft.KernelMemory;
+
+[JsonConverter(typeof(StreamStatesConverter))]
+public enum StreamStates
+{
+ // Inform the client the stream ended due to an error.
+ Error = 0,
+
+ // When streaming, inform the client to discard any previous data
+ // and start collecting again using this record as the first one.
+ Reset = 1,
+
+ // When streaming, append the current result to the data
+ // already received so far.
+ Append = 2,
+
+ // Inform the client the end of the stream has been reached
+ // and that this is the last record to append.
+ Last = 3,
+}
+
+#pragma warning disable CA1308
+internal sealed class StreamStatesConverter : JsonConverter
+{
+ public override StreamStates Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ {
+ string value = reader.GetString()!;
+ return value.ToLowerInvariant() switch
+ {
+ "error" => StreamStates.Error,
+ "reset" => StreamStates.Reset,
+ "append" => StreamStates.Append,
+ "last" => StreamStates.Last,
+ _ => throw new JsonException($"Unknown {nameof(StreamStates)} value: {value}")
+ };
+ }
+
+ public override void Write(Utf8JsonWriter writer, StreamStates value, JsonSerializerOptions options)
+ {
+ string serializedValue = value switch
+ {
+ StreamStates.Error => "error",
+ StreamStates.Reset => "reset",
+ StreamStates.Append => "append",
+ StreamStates.Last => "last",
+ _ => throw new JsonException($"Cannot serialize {nameof(StreamStates)} value: {value}")
+ };
+
+ writer.WriteStringValue(serializedValue);
+ }
+}
+#pragma warning restore CA1308
diff --git a/service/Abstractions/Search/ISearchClient.cs b/service/Abstractions/Search/ISearchClient.cs
index a8da8f1cc..4329538e3 100644
--- a/service/Abstractions/Search/ISearchClient.cs
+++ b/service/Abstractions/Search/ISearchClient.cs
@@ -50,6 +50,24 @@ Task AskAsync(
IContext? context = null,
CancellationToken cancellationToken = default);
+ ///
+ /// Answer the given question, if possible, returning the answer as an async stream and grounding the response with relevant memories matching the given criteria.
+ ///
+ /// Index (aka collection) to search for grounding information
+ /// Question to answer
+ /// Filtering criteria to select memories to consider
+ /// Minimum relevance of the memories considered
+ /// Optional context carrying optional information used by internal logic
+ /// Async task cancellation token
+ /// Answer to the given question
+ IAsyncEnumerable AskStreamingAsync(
+ string index,
+ string question,
+ ICollection? filters = null,
+ double minRelevance = 0,
+ IContext? context = null,
+ CancellationToken cancellationToken = default);
+
///
/// List the available memory indexes (aka collections).
///
diff --git a/service/Abstractions/Search/SearchOptions.cs b/service/Abstractions/Search/SearchOptions.cs
new file mode 100644
index 000000000..4f88b1f0c
--- /dev/null
+++ b/service/Abstractions/Search/SearchOptions.cs
@@ -0,0 +1,29 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+#pragma warning disable IDE0130 // reduce number of "using" statements
+// ReSharper disable once CheckNamespace - reduce number of "using" statements
+namespace Microsoft.KernelMemory;
+
+// TODO: move minRelevance to this class
+// TODO: move filter to this class
+// TODO: move filters to this class
+public sealed class SearchOptions
+{
+ ///
+ /// Whether to stream results back to the client
+ ///
+ public bool Stream { get; set; } = false;
+}
+
+public static class SearchOptionsExtensions
+{
+ public static SearchOptions? Clone(this SearchOptions? options)
+ {
+ if (options == null) { return null; }
+
+ return new SearchOptions
+ {
+ Stream = options.Stream
+ };
+ }
+}
diff --git a/service/Core/Configuration/ServiceConfig.cs b/service/Core/Configuration/ServiceConfig.cs
index 4b89ff129..2a9ab78ec 100644
--- a/service/Core/Configuration/ServiceConfig.cs
+++ b/service/Core/Configuration/ServiceConfig.cs
@@ -24,6 +24,11 @@ public class ServiceConfig
///
public bool OpenApiEnabled { get; set; } = false;
+ ///
+ /// Whether to send a [DONE] message at the end of SSE streams.
+ ///
+ public bool SendSSEDoneMessage { get; set; } = true;
+
///
/// List of handlers to enable
///
diff --git a/service/Core/MemoryServerless.cs b/service/Core/MemoryServerless.cs
index 2aba4cf0d..240ddd52a 100644
--- a/service/Core/MemoryServerless.cs
+++ b/service/Core/MemoryServerless.cs
@@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
+using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@@ -262,30 +263,47 @@ public Task SearchAsync(
}
///
- public Task AskAsync(
+ public async IAsyncEnumerable AskStreamingAsync(
string question,
string? index = null,
MemoryFilter? filter = null,
ICollection? filters = null,
double minRelevance = 0,
+ SearchOptions? options = null,
IContext? context = null,
- CancellationToken cancellationToken = default)
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
this._contextProvider.InitContext(context);
if (filter != null)
{
- if (filters == null) { filters = []; }
-
+ filters ??= [];
filters.Add(filter);
}
index = IndexName.CleanName(index, this._defaultIndexName);
- return this._searchClient.AskAsync(
+
+ if (options is { Stream: true })
+ {
+ await foreach (var answer in this._searchClient.AskStreamingAsync(
+ index: index,
+ question: question,
+ filters: filters,
+ minRelevance: minRelevance,
+ context: context,
+ cancellationToken).ConfigureAwait(false))
+ {
+ yield return answer;
+ }
+
+ yield break;
+ }
+
+ yield return await this._searchClient.AskAsync(
index: index,
question: question,
filters: filters,
minRelevance: minRelevance,
context: context,
- cancellationToken: cancellationToken);
+ cancellationToken).ConfigureAwait(false);
}
}
diff --git a/service/Core/MemoryService.cs b/service/Core/MemoryService.cs
index 2ee34b163..5a88813da 100644
--- a/service/Core/MemoryService.cs
+++ b/service/Core/MemoryService.cs
@@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
+using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@@ -228,29 +229,46 @@ public Task SearchAsync(
}
///
- public Task AskAsync(
+ public async IAsyncEnumerable AskStreamingAsync(
string question,
string? index = null,
MemoryFilter? filter = null,
ICollection? filters = null,
double minRelevance = 0,
+ SearchOptions? options = null,
IContext? context = null,
- CancellationToken cancellationToken = default)
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
if (filter != null)
{
- if (filters == null) { filters = []; }
-
+ filters ??= [];
filters.Add(filter);
}
index = IndexName.CleanName(index, this._defaultIndexName);
- return this._searchClient.AskAsync(
+
+ if (options is { Stream: true })
+ {
+ await foreach (var answer in this._searchClient.AskStreamingAsync(
+ index: index,
+ question: question,
+ filters: filters,
+ minRelevance: minRelevance,
+ context: context,
+ cancellationToken).ConfigureAwait(false))
+ {
+ yield return answer;
+ }
+
+ yield break;
+ }
+
+ yield return await this._searchClient.AskAsync(
index: index,
question: question,
filters: filters,
minRelevance: minRelevance,
context: context,
- cancellationToken: cancellationToken);
+ cancellationToken).ConfigureAwait(false);
}
}
diff --git a/service/Core/Search/AnswerGenerator.cs b/service/Core/Search/AnswerGenerator.cs
index 14abb981b..352160851 100644
--- a/service/Core/Search/AnswerGenerator.cs
+++ b/service/Core/Search/AnswerGenerator.cs
@@ -2,8 +2,8 @@
using System;
using System.Collections.Generic;
-using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@@ -44,86 +44,75 @@ public AnswerGenerator(
{
throw new KernelMemoryException("Text generator not configured");
}
+
+ if (this._contentModeration == null || !this._config.UseContentModeration)
+ {
+ this._log.LogInformation("Content moderation is not enabled.");
+ }
}
- internal async Task GenerateAnswerAsync(
- string question, SearchClientResult result, IContext? context, CancellationToken cancellationToken)
+ internal async IAsyncEnumerable GenerateAnswerAsync(
+ string question, SearchClientResult result,
+ IContext? context, [EnumeratorCancellation] CancellationToken cancellationToken)
{
if (result.FactsAvailableCount > 0 && result.FactsUsedCount == 0)
{
this._log.LogError("Unable to inject memories in the prompt, not enough tokens available");
- result.AskResult.NoResultReason = "Unable to use memories";
- return result.AskResult;
+ yield return result.InsufficientTokensResult;
+ yield break;
}
if (result.FactsUsedCount == 0)
{
this._log.LogWarning("No memories available");
- result.AskResult.NoResultReason = "No memories available";
- return result.AskResult;
+ yield return result.NoFactsResult;
+ yield break;
}
- // Collect the LLM output
- var text = new StringBuilder();
- var charsGenerated = 0;
- var watch = new Stopwatch();
- watch.Restart();
- await foreach (var x in this.GenerateAnswerTokensAsync(question, result.Facts.ToString(), context, cancellationToken).ConfigureAwait(false))
+ var completeAnswer = new StringBuilder();
+ await foreach (var answerToken in this.GenerateAnswerTokensAsync(question, result.Facts.ToString(), context, cancellationToken).ConfigureAwait(false))
{
- text.Append(x);
-
- if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30)
- {
- charsGenerated = text.Length;
- this._log.LogTrace("{0} chars generated", charsGenerated);
- }
+ completeAnswer.Append(answerToken);
+ result.AskResult.Result = answerToken;
+ yield return result.AskResult;
}
- watch.Stop();
-
// Finalize the answer, checking if it's empty
- result.AskResult.Result = text.ToString();
- this._log.LogSensitive("Answer: {0}", result.AskResult.Result);
- result.AskResult.NoResult = ValueIsEquivalentTo(result.AskResult.Result, this._config.EmptyAnswer);
- if (result.AskResult.NoResult)
- {
- result.AskResult.NoResultReason = "No relevant memories found";
- this._log.LogTrace("Answer generated in {0} msecs. No relevant memories found", watch.ElapsedMilliseconds);
- }
- else
+ result.AskResult.Result = completeAnswer.ToString();
+ if (string.IsNullOrWhiteSpace(result.AskResult.Result)
+ || ValueIsEquivalentTo(result.AskResult.Result, this._config.EmptyAnswer))
{
- this._log.LogTrace("Answer generated in {0} msecs", watch.ElapsedMilliseconds);
+ this._log.LogInformation("No relevant memories found, returning empty answer.");
+ yield return result.NoFactsResult;
+ yield break;
}
- // Validate the LLM output
- if (this._contentModeration != null && this._config.UseContentModeration)
+ this._log.LogSensitive("Answer: {0}", result.AskResult.Result);
+
+ if (this._config.UseContentModeration
+ && this._contentModeration != null
+ && !await this._contentModeration.IsSafeAsync(result.AskResult.Result, cancellationToken).ConfigureAwait(false))
{
- var isSafe = await this._contentModeration.IsSafeAsync(result.AskResult.Result, cancellationToken).ConfigureAwait(false);
- if (!isSafe)
- {
- this._log.LogWarning("Unsafe answer detected. Returning error message instead.");
- this._log.LogSensitive("Unsafe answer: {0}", result.AskResult.Result);
- result.AskResult.NoResultReason = "Content moderation failure";
- result.AskResult.Result = this._config.ModeratedAnswer;
- }
+ this._log.LogWarning("Unsafe answer detected. Returning error message instead.");
+ yield return result.UnsafeAnswerResult;
}
-
- return result.AskResult;
}
private IAsyncEnumerable GenerateAnswerTokensAsync(string question, string facts, IContext? context, CancellationToken cancellationToken)
{
string prompt = context.GetCustomRagPromptOrDefault(this._answerPrompt);
+ string emptyAnswer = context.GetCustomEmptyAnswerTextOrDefault(this._config.EmptyAnswer);
int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens);
double temperature = context.GetCustomRagTemperatureOrDefault(this._config.Temperature);
double nucleusSampling = context.GetCustomRagNucleusSamplingOrDefault(this._config.TopP);
- prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase);
-
question = question.Trim();
question = question.EndsWith('?') ? question : $"{question}?";
+
+ prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase);
prompt = prompt.Replace("{{$input}}", question, StringComparison.OrdinalIgnoreCase);
- prompt = prompt.Replace("{{$notFound}}", this._config.EmptyAnswer, StringComparison.OrdinalIgnoreCase);
+ prompt = prompt.Replace("{{$notFound}}", emptyAnswer, StringComparison.OrdinalIgnoreCase);
+ this._log.LogInformation("New prompt: {0}", prompt);
var options = new TextGenerationOptions
{
diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs
index cf410ad54..f6b3ddae7 100644
--- a/service/Core/Search/SearchClient.cs
+++ b/service/Core/Search/SearchClient.cs
@@ -5,6 +5,8 @@
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
@@ -120,6 +122,81 @@ public async Task AskAsync(
double minRelevance = 0,
IContext? context = null,
CancellationToken cancellationToken = default)
+ {
+ var result = new MemoryAnswer();
+
+ var stream = this.AskStreamingAsync(
+ index: index, question: question, filters, minRelevance, context, cancellationToken)
+ .ConfigureAwait(false);
+
+ var done = false;
+ StringBuilder text = new(result.Result);
+ await foreach (var part in stream.ConfigureAwait(false))
+ {
+ if (done) { break; }
+
+ switch (part.StreamState)
+ {
+ case StreamStates.Error:
+ text.Clear();
+ result = part;
+
+ done = true;
+ break;
+
+ case StreamStates.Reset:
+ text.Clear();
+ text.Append(part.Result);
+ result = part;
+ break;
+
+ case StreamStates.Append:
+ result.NoResult = part.NoResult;
+ result.NoResultReason = part.NoResultReason;
+
+ text.Append(part.Result);
+ result.Result = text.ToString();
+
+ if (result.RelevantSources != null && part.RelevantSources != null)
+ {
+ result.RelevantSources = result.RelevantSources.Union(part.RelevantSources).ToList();
+ }
+
+ break;
+
+ case StreamStates.Last:
+ result.NoResult = part.NoResult;
+ result.NoResultReason = part.NoResultReason;
+
+ text.Append(part.Result);
+ result.Result = text.ToString();
+
+ if (result.RelevantSources != null && part.RelevantSources != null)
+ {
+ result.RelevantSources = result.RelevantSources.Union(part.RelevantSources).ToList();
+ }
+
+ done = true;
+ break;
+
+ default:
+ throw new ArgumentOutOfRangeException(nameof(part.StreamState));
+ }
+ }
+
+ result.Question = question;
+ result.StreamState = null;
+ return result;
+ }
+
+ ///
+ public async IAsyncEnumerable AskStreamingAsync(
+ string index,
+ string question,
+ ICollection? filters = null,
+ double minRelevance = 0,
+ IContext? context = null,
+ [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
string emptyAnswer = context.GetCustomEmptyAnswerTextOrDefault(this._config.EmptyAnswer);
string answerPrompt = context.GetCustomRagPromptOrDefault(this._answerPrompt);
@@ -129,9 +206,11 @@ public async Task AskAsync(
? this._config.MaxAskPromptSize
: this._textGenerator.MaxTokenTotal;
+ // Prepare results (empty, error, etc.)
SearchClientResult result = SearchClientResult.AskResultInstance(
question: question,
emptyAnswer: emptyAnswer,
+ moderatedAnswer: this._config.ModeratedAnswer,
maxGroundingFacts: limit,
tokensAvailable: maxTokens
- this._textGenerator.CountTokens(answerPrompt)
@@ -142,7 +221,8 @@ public async Task AskAsync(
if (string.IsNullOrEmpty(question))
{
this._log.LogWarning("No question provided");
- return result.AskResult;
+ yield return result.NoQuestionResult;
+ yield break;
}
this._log.LogTrace("Fetching relevant memories");
@@ -171,7 +251,22 @@ public async Task AskAsync(
this._log.LogTrace("{Count} records processed", result.RecordCount);
- return await this._answerGenerator.GenerateAnswerAsync(question, result, context, cancellationToken).ConfigureAwait(false);
+ var first = true;
+ await foreach (var answer in this._answerGenerator.GenerateAnswerAsync(question, result, context, cancellationToken).ConfigureAwait(false))
+ {
+ yield return answer;
+
+ if (first)
+ {
+ // Remove redundant data, sent only once in the first record, to reduce payload
+ first = false;
+
+ // Note: we keep the sources in the other collections (e.g. result.ErrorResult.RelevantSources),
+ // so in case of a stream reset the sources are sent again.
+ result.AskResult.RelevantSources.Clear();
+ result.AskResult.Question = null!;
+ }
+ }
}
///
@@ -248,29 +343,17 @@ private SearchClientResult ProcessMemoryRecord(
this._log.LogTrace("Adding content #{0} with relevance {1}", result.FactsUsedCount, recordRelevance);
}
- Citation? citation;
- if (result.Mode == SearchMode.SearchMode)
+ var citation = result.Mode switch
{
- citation = result.SearchResult.Results.FirstOrDefault(x => x.Link == linkToFile);
- if (citation == null)
- {
- citation = new Citation();
- result.SearchResult.Results.Add(citation);
- }
- }
- else if (result.Mode == SearchMode.AskMode)
- {
- // If the file is already in the list of citations, only add the partition
- citation = result.AskResult.RelevantSources.FirstOrDefault(x => x.Link == linkToFile);
- if (citation == null)
- {
- citation = new Citation();
- result.AskResult.RelevantSources.Add(citation);
- }
- }
- else
+ SearchMode.SearchMode => result.SearchResult.Results.FirstOrDefault(x => x.Link == linkToFile),
+ SearchMode.AskMode => result.AskResult.RelevantSources.FirstOrDefault(x => x.Link == linkToFile),
+ _ => throw new ArgumentOutOfRangeException(nameof(result.Mode))
+ };
+
+ if (citation == null)
{
- throw new ArgumentOutOfRangeException(nameof(result.Mode));
+ citation = new Citation();
+ result.AddSource(citation);
}
citation.Index = index;
diff --git a/service/Core/Search/SearchClientResult.cs b/service/Core/Search/SearchClientResult.cs
index c509809a7..66ff58f15 100644
--- a/service/Core/Search/SearchClientResult.cs
+++ b/service/Core/Search/SearchClientResult.cs
@@ -23,10 +23,16 @@ internal class SearchClientResult
public SearchState State { get; set; }
public int RecordCount { get; set; }
- // Use by in Search and Ask mode
- public MemoryAnswer AskResult { get; private init; } = new();
+ // Used by Search and Ask mode
public int MaxRecordCount { get; private init; }
+ public MemoryAnswer AskResult { get; private init; } = new();
+ public MemoryAnswer NoFactsResult { get; private init; } = new();
+ public MemoryAnswer NoQuestionResult { get; private init; } = new();
+ public MemoryAnswer UnsafeAnswerResult { get; private init; } = new();
+ public MemoryAnswer InsufficientTokensResult { get; private init; } = new();
+ public MemoryAnswer ErrorResult { get; private init; } = new();
+
// Use by Ask mode
public SearchResult SearchResult { get; private init; } = new();
public StringBuilder Facts { get; } = new();
@@ -37,7 +43,9 @@ internal class SearchClientResult
///
/// Create new instance in Ask mode
///
- public static SearchClientResult AskResultInstance(string question, string emptyAnswer, int maxGroundingFacts, int tokensAvailable)
+ public static SearchClientResult AskResultInstance(
+ string question, string emptyAnswer, string moderatedAnswer,
+ int maxGroundingFacts, int tokensAvailable)
{
return new SearchClientResult
{
@@ -46,14 +54,64 @@ public static SearchClientResult AskResultInstance(string question, string empty
MaxRecordCount = maxGroundingFacts,
AskResult = new MemoryAnswer
{
+ StreamState = StreamStates.Append,
+ Question = question,
+ NoResult = false
+ },
+ NoFactsResult = new MemoryAnswer
+ {
+ StreamState = StreamStates.Reset,
+ Question = question,
+ NoResult = true,
+ NoResultReason = "No relevant memories available",
+ Result = emptyAnswer
+ },
+ NoQuestionResult = new MemoryAnswer
+ {
+ StreamState = StreamStates.Reset,
Question = question,
NoResult = true,
NoResultReason = "No question provided",
- Result = emptyAnswer,
+ Result = emptyAnswer
+ },
+ InsufficientTokensResult = new MemoryAnswer
+ {
+ StreamState = StreamStates.Reset,
+ Question = question,
+ NoResult = true,
+ NoResultReason = "Unable to use memory, max tokens reached",
+ Result = emptyAnswer
+ },
+ UnsafeAnswerResult = new MemoryAnswer
+ {
+ StreamState = StreamStates.Reset,
+ Question = question,
+ NoResult = true,
+ NoResultReason = "Content moderation",
+ Result = moderatedAnswer
+ },
+ ErrorResult = new MemoryAnswer
+ {
+ StreamState = StreamStates.Error,
+ Question = question,
+ NoResult = true,
+ NoResultReason = "An error occurred"
}
};
}
+ ///
+ /// Add the source to the search results and to the Ask, InsufficientTokens, UnsafeAnswer and Error answers
+ ///
+ public void AddSource(Citation citation)
+ {
+ this.SearchResult.Results?.Add(citation);
+ this.AskResult.RelevantSources?.Add(citation);
+ this.InsufficientTokensResult.RelevantSources?.Add(citation);
+ this.UnsafeAnswerResult.RelevantSources?.Add(citation);
+ this.ErrorResult.RelevantSources?.Add(citation);
+ }
+
///
/// Create new instance in Search mode
///
diff --git a/service/Service.AspNetCore/WebAPIEndpoints.cs b/service/Service.AspNetCore/WebAPIEndpoints.cs
index cef2388b8..8e63c1aa3 100644
--- a/service/Service.AspNetCore/WebAPIEndpoints.cs
+++ b/service/Service.AspNetCore/WebAPIEndpoints.cs
@@ -3,6 +3,8 @@
using System;
using System.Collections.Generic;
using System.IO;
+using System.Linq;
+using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
@@ -15,6 +17,7 @@
using Microsoft.KernelMemory.Configuration;
using Microsoft.KernelMemory.Context;
using Microsoft.KernelMemory.DocumentStorage;
+using Microsoft.KernelMemory.HTTP;
using Microsoft.KernelMemory.Service.AspNetCore.Models;
namespace Microsoft.KernelMemory.Service.AspNetCore;
@@ -31,11 +34,10 @@ public static IEndpointRouteBuilder AddKernelMemoryEndpoints(
builder.AddGetIndexesEndpoint(apiPrefix).AddFilters(filters);
builder.AddDeleteIndexesEndpoint(apiPrefix).AddFilters(filters);
builder.AddDeleteDocumentsEndpoint(apiPrefix).AddFilters(filters);
- builder.AddAskEndpoint(apiPrefix).AddFilters(filters);
+ builder.AddAskEndpoint(apiPrefix, kmConfig?.Service.SendSSEDoneMessage ?? true).AddFilters(filters);
builder.AddSearchEndpoint(apiPrefix).AddFilters(filters);
builder.AddUploadStatusEndpoint(apiPrefix).AddFilters(filters);
builder.AddGetDownloadEndpoint(apiPrefix).AddFilters(filters);
-
return builder;
}
@@ -212,13 +214,14 @@ await service.DeleteDocumentAsync(documentId: documentId, index: index, cancella
}
public static RouteHandlerBuilder AddAskEndpoint(
- this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter[]? filters = null)
+ this IEndpointRouteBuilder builder, string apiPrefix = "/", bool sseSendDoneMessage = true, IEndpointFilter[]? filters = null)
{
RouteGroupBuilder group = builder.MapGroup(apiPrefix);
// Ask endpoint
var route = group.MapPost(Constants.HttpAskEndpoint,
- async Task (
+ async Task (
+ HttpContext httpContext,
MemoryQuery query,
IKernelMemory service,
ILogger log,
@@ -228,20 +231,75 @@ async Task (
// Allow internal classes to access custom arguments via IContextProvider
contextProvider.InitContextArgs(query.ContextArguments);
- log.LogTrace("New search request, index '{0}', minRelevance {1}", query.Index, query.MinRelevance);
- MemoryAnswer answer = await service.AskAsync(
- question: query.Question,
- index: query.Index,
- filters: query.Filters,
- minRelevance: query.MinRelevance,
- context: contextProvider.GetContext(),
- cancellationToken: cancellationToken)
- .ConfigureAwait(false);
- return Results.Ok(answer);
+ log.LogTrace("New ask request, index '{0}', minRelevance {1}", query.Index, query.MinRelevance);
+
+ IAsyncEnumerable answerStream = service.AskStreamingAsync(
+ question: query.Question,
+ index: query.Index,
+ filters: query.Filters,
+ minRelevance: query.MinRelevance,
+ options: new SearchOptions { Stream = query.Stream },
+ context: contextProvider.GetContext(),
+ cancellationToken: cancellationToken);
+
+ httpContext.Response.StatusCode = StatusCodes.Status200OK;
+
+ try
+ {
+ if (query.Stream)
+ {
+ httpContext.Response.ContentType = "text/event-stream; charset=utf-8";
+ await foreach (var answer in answerStream.ConfigureAwait(false))
+ {
+ string json = answer.ToJson(true);
+ await httpContext.Response.WriteAsync($"{SSE.DataPrefix}{json}\n\n", cancellationToken).ConfigureAwait(false);
+ await httpContext.Response.Body.FlushAsync(cancellationToken).ConfigureAwait(false);
+ }
+ }
+ else
+ {
+ httpContext.Response.ContentType = "application/json; charset=utf-8";
+ MemoryAnswer answer = await answerStream.FirstAsync(cancellationToken: cancellationToken).ConfigureAwait(false);
+ string json = answer.ToJson(false);
+ await httpContext.Response.WriteAsync(json, cancellationToken).ConfigureAwait(false);
+ }
+ }
+ catch (Exception e)
+ {
+ log.LogError(e, "An error occurred while preparing the response");
+
+ // Attempt to set the status code, in case the output hasn't started yet
+ httpContext.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
+
+ var json = query.Stream
+ ? JsonSerializer.Serialize(new MemoryAnswer
+ {
+ StreamState = StreamStates.Error,
+ Question = query.Question,
+ NoResult = true,
+ NoResultReason = $"Error: {e.Message} [{e.GetType().FullName}]"
+ })
+ : JsonSerializer.Serialize(new ProblemDetails
+ {
+ Status = StatusCodes.Status503ServiceUnavailable,
+ Title = "Service Unavailable",
+ Detail = $"{e.Message} [{e.GetType().FullName}]"
+ });
+
+ await httpContext.Response.WriteAsync(query.Stream ? $"{SSE.DataPrefix}{json}\n\n" : json, cancellationToken).ConfigureAwait(false);
+ }
+
+ if (query.Stream && sseSendDoneMessage)
+ {
+ await httpContext.Response.WriteAsync($"{SSE.DoneMessage}\n\n", cancellationToken: cancellationToken).ConfigureAwait(false);
+ }
+
+ await httpContext.Response.Body.FlushAsync(cancellationToken).ConfigureAwait(false);
})
.Produces(StatusCodes.Status200OK)
.Produces(StatusCodes.Status401Unauthorized)
- .Produces(StatusCodes.Status403Forbidden);
+ .Produces(StatusCodes.Status403Forbidden)
+ .Produces(StatusCodes.Status503ServiceUnavailable);
return route;
}
diff --git a/service/Service/appsettings.json b/service/Service/appsettings.json
index 506f52c2d..db88ad35c 100644
--- a/service/Service/appsettings.json
+++ b/service/Service/appsettings.json
@@ -48,6 +48,8 @@
// If not set the solution defaults to 30,000,000 bytes (~28.6 MB)
// Note: this applies only to KM HTTP service.
"MaxUploadSizeMb": null,
+ // Whether to send a [DONE] message at the end of SSE streams.
+ "SendSSEDoneMessage": true,
// Whether to run the asynchronous pipeline handlers
// Use these booleans to deploy the web service and the handlers on same/different VMs
"RunHandlers": true,
diff --git a/service/tests/Abstractions.UnitTests/Http/SSETest.cs b/service/tests/Abstractions.UnitTests/Http/SSETest.cs
new file mode 100644
index 000000000..ffa463737
--- /dev/null
+++ b/service/tests/Abstractions.UnitTests/Http/SSETest.cs
@@ -0,0 +1,149 @@
+// Copyright (c) Microsoft. All rights reserved.
+
+using System.Text;
+using Microsoft.KernelMemory;
+using Microsoft.KernelMemory.HTTP;
+
+namespace Microsoft.KM.Abstractions.UnitTests.Http;
+
+public class SSETest
+{
+ [Theory]
+ [InlineData(null)]
+ [InlineData("")]
+ [InlineData(" ")]
+ [InlineData(" \n")]
+ public void ItParsesEmptyStrings(string input)
+ {
+ Assert.Null(SSE.ParseMessage(input));
+ }
+
+ [Fact]
+ public void ItParsesSingleLineMessage()
+ {
+ // Arrange
+ var message = """
+ data: { "question": "q" }
+ """;
+
+ // Act
+ var x = SSE.ParseMessage(message);
+
+ // Assert
+ Assert.NotNull(x);
+ Assert.Equal("q", x.Question);
+ }
+
+ [Fact]
+ public void ItParsesSingleLineMessageWithSeparator()
+ {
+ // Arrange
+ var message = """
+ data: { "question": "q" }
+
+
+ """;
+
+ // Act
+ var x = SSE.ParseMessage(message);
+
+ // Assert
+ Assert.NotNull(x);
+ Assert.Equal("q", x.Question);
+ }
+
+ [Fact]
+ public void ItParsesMultiLineMessage()
+ {
+ // Arrange
+ var message = """
+ data: { "question": "q"
+ data: , "noResultReason": "abc"
+ data: }
+ """;
+
+ // Act
+ var x = SSE.ParseMessage(message);
+
+ // Assert
+ Assert.NotNull(x);
+ Assert.Equal("q", x.Question);
+ Assert.Equal("abc", x.NoResultReason);
+ }
+
+ [Theory]
+ [InlineData("data: [DONE]")]
+ [InlineData("data: [DONE]\n")]
+ [InlineData("data: [DONE]\n\n")]
+ public async Task ItParsesEmptyStreams(string input)
+ {
+ // Arrange
+ using var stream = new MemoryStream(Encoding.UTF8.GetBytes(input));
+
+ // Act
+ var result = SSE.ParseStreamAsync(stream);
+
+ // Assert
+ var messages = new List();
+ await foreach (var message in result)
+ {
+ messages.Add(message);
+ }
+
+ Assert.Equal(0, messages.Count);
+ }
+
+ [Theory]
+ [InlineData("data: { \"question\": \"qq\" }")]
+ [InlineData("data: { \"question\": \"qq\" }\n")]
+ [InlineData("data: { \"question\": \"qq\" }\n\n")]
+ [InlineData("data: { \"question\": \"qq\" }\n\ndata: [DONE]")]
+ [InlineData("data: { \"question\": \"qq\" }\n\ndata: [DONE]\n")]
+ [InlineData("data: { \"question\": \"qq\" }\n\ndata: [DONE]\n\n")]
+ public async Task ItParsesStreamsWithASingleMessage(string input)
+ {
+ // Arrange
+ using var stream = new MemoryStream(Encoding.UTF8.GetBytes(input));
+
+ // Act
+ var result = SSE.ParseStreamAsync(stream);
+
+ // Assert
+ var messages = new List();
+ await foreach (var message in result)
+ {
+ messages.Add(message);
+ }
+
+ Assert.Equal(1, messages.Count);
+ Assert.NotNull(messages[0]);
+ Assert.Equal("qq", messages[0].Question);
+ }
+
+ [Theory]
+ [InlineData("data: { \"question\": \"qq\" }\n\ndata: { \"question\": \"kk\" }\n\n")]
+ [InlineData("data: { \"question\": \"qq\" }\n\ndata: { \"question\": \"kk\" }\n\ndata: [DONE]")]
+ [InlineData("data: { \"question\": \"qq\" }\n\ndata: { \"question\": \"kk\" }\n\ndata: [DONE]\n")]
+ [InlineData("data: { \"question\": \"qq\" }\n\ndata: { \"question\": \"kk\" }\n\ndata: [DONE]\n\n")]
+ public async Task ItParsesStreamsWithMultipleMessage(string input)
+ {
+ // Arrange
+ using var stream = new MemoryStream(Encoding.UTF8.GetBytes(input));
+
+ // Act
+ var result = SSE.ParseStreamAsync(stream);
+
+ // Assert
+ var messages = new List();
+ await foreach (var message in result)
+ {
+ messages.Add(message);
+ }
+
+ Assert.Equal(2, messages.Count);
+ Assert.NotNull(messages[0]);
+ Assert.Equal("qq", messages[0].Question);
+ Assert.NotNull(messages[1]);
+ Assert.Equal("kk", messages[1].Question);
+ }
+}