From f082f5da2b7920fb991f5bb695fd8b7f34f2d7b8 Mon Sep 17 00:00:00 2001 From: Carlo Date: Thu, 1 Aug 2024 15:37:57 +0200 Subject: [PATCH 01/30] added ask_stream endpoint you can now call azure openAI and OpenAI with streaming --- service/Abstractions/Constants.cs | 4 + service/Abstractions/Context/IContext.cs | 10 + service/Abstractions/IKernelMemory.cs | 9 + service/Abstractions/Search/ISearchClient.cs | 9 + service/Core/MemoryServerless.cs | 26 +++ service/Core/MemoryService.cs | 26 +++ service/Core/Search/SearchClient.cs | 176 +++++++++++++++++- service/Service.AspNetCore/WebAPIEndpoints.cs | 53 ++++++ service/Service/Program.cs | 7 +- 9 files changed, 315 insertions(+), 5 deletions(-) diff --git a/service/Abstractions/Constants.cs b/service/Abstractions/Constants.cs index 77ea23e34..644df9964 100644 --- a/service/Abstractions/Constants.cs +++ b/service/Abstractions/Constants.cs @@ -48,6 +48,8 @@ public static class EmbeddingGeneration public static class Rag { + // Used to override EosToken config for streaming + public const string EosToken = "custom_rag_eos_token_str"; // Used to override No Answer config public const string EmptyAnswer = "custom_rag_empty_answer_str"; @@ -112,6 +114,8 @@ public static class Summary public const string ReservedPayloadVectorGeneratorField = "vector_generator"; // Endpoints + + public const string HttpAskChunkEndpoint = "/ask_stream"; public const string HttpAskEndpoint = "/ask"; public const string HttpSearchEndpoint = "/search"; public const string HttpDownloadEndpoint = "/download"; diff --git a/service/Abstractions/Context/IContext.cs b/service/Abstractions/Context/IContext.cs index 4835fd794..c093c8cde 100644 --- a/service/Abstractions/Context/IContext.cs +++ b/service/Abstractions/Context/IContext.cs @@ -100,6 +100,16 @@ public static string GetCustomEmptyAnswerTextOrDefault(this IContext? context, s return defaultValue; } + public static string GetCustomEosTokenOrDefault(this IContext? context, string defaultValue) + { + if (context.TryGetArg(Constants.CustomContext.Rag.EosToken, out var customValue)) + { + return customValue; + } + + return defaultValue; + } + public static string GetCustomRagFactTemplateOrDefault(this IContext? context, string defaultValue) { if (context.TryGetArg(Constants.CustomContext.Rag.FactTemplate, out var customValue)) diff --git a/service/Abstractions/IKernelMemory.cs b/service/Abstractions/IKernelMemory.cs index 89dc57009..64e89b1b6 100644 --- a/service/Abstractions/IKernelMemory.cs +++ b/service/Abstractions/IKernelMemory.cs @@ -228,4 +228,13 @@ public Task AskAsync( double minRelevance = 0, IContext? context = null, CancellationToken cancellationToken = default); + + public Task> AskAsyncChunk( + string question, + string? index = null, + MemoryFilter? filter = null, + ICollection? filters = null, + double minRelevance = 0, + IContext? context = null, + CancellationToken cancellationToken = default); } diff --git a/service/Abstractions/Search/ISearchClient.cs b/service/Abstractions/Search/ISearchClient.cs index a8da8f1cc..a7a37230c 100644 --- a/service/Abstractions/Search/ISearchClient.cs +++ b/service/Abstractions/Search/ISearchClient.cs @@ -50,6 +50,15 @@ Task AskAsync( IContext? context = null, CancellationToken cancellationToken = default); + + IAsyncEnumerable AskAsyncChunk( + string index, + string question, + ICollection? filters = null, + double minRelevance = 0, + IContext? context = null, + CancellationToken cancellationToken = default); + /// /// List the available memory indexes (aka collections). /// diff --git a/service/Core/MemoryServerless.cs b/service/Core/MemoryServerless.cs index 1b0be3981..e9e23a115 100644 --- a/service/Core/MemoryServerless.cs +++ b/service/Core/MemoryServerless.cs @@ -277,4 +277,30 @@ public Task AskAsync( context: context, cancellationToken: cancellationToken); } + + public Task> AskAsyncChunk( + string question, + string? index = null, + MemoryFilter? filter = null, + ICollection? filters = null, + double minRelevance = 0, + IContext? context = null, + CancellationToken cancellationToken = default) + { + if (filter != null) + { + if (filters == null) { filters = new List(); } + + filters.Add(filter); + } + + index = IndexName.CleanName(index, this._defaultIndexName); + return Task.FromResult>(this._searchClient.AskAsyncChunk( + index: index, + question: question, + filters: filters, + minRelevance: minRelevance, + context: context, + cancellationToken: cancellationToken)); + } } diff --git a/service/Core/MemoryService.cs b/service/Core/MemoryService.cs index 169cb45b6..f7898beb1 100644 --- a/service/Core/MemoryService.cs +++ b/service/Core/MemoryService.cs @@ -253,4 +253,30 @@ public Task AskAsync( context: context, cancellationToken: cancellationToken); } + + public Task> AskAsyncChunk( + string question, + string? index = null, + MemoryFilter? filter = null, + ICollection? filters = null, + double minRelevance = 0, + IContext? context = null, + CancellationToken cancellationToken = default) + { + if (filter != null) + { + if (filters == null) { filters = new List(); } + + filters.Add(filter); + } + + index = IndexName.CleanName(index, this._defaultIndexName); + return Task.FromResult>(this._searchClient.AskAsyncChunk( + index: index, + question: question, + filters: filters, + minRelevance: minRelevance, + context: context, + cancellationToken: cancellationToken)); + } } diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs index 4a229f22e..65d7350af 100644 --- a/service/Core/Search/SearchClient.cs +++ b/service/Core/Search/SearchClient.cs @@ -8,6 +8,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; +using System.Runtime.CompilerServices; using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.AI; using Microsoft.KernelMemory.Context; @@ -336,7 +337,6 @@ public async Task AskAsync( await foreach (var x in this.GenerateAnswer(question, facts.ToString(), context, cancellationToken).ConfigureAwait(false)) { text.Append(x); - if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30) { charsGenerated = text.Length; @@ -361,6 +361,179 @@ public async Task AskAsync( return answer; } + public async IAsyncEnumerable AskAsyncChunk( + string index, + string question, + ICollection? filters = null, + double minRelevance = 0, + IContext? context = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + string emptyAnswer = context.GetCustomEmptyAnswerTextOrDefault(this._config.EmptyAnswer); + string eosToken = context.GetCustomEosTokenOrDefault("end"); + string answerPrompt = context.GetCustomRagPromptOrDefault(this._answerPrompt); + string factTemplate = context.GetCustomRagFactTemplateOrDefault(this._config.FactTemplate); + if (!factTemplate.EndsWith('\n')) { factTemplate += "\n"; } + + var noAnswerFound = new MemoryAnswer + { + Question = question, + NoResult = true, + Result = emptyAnswer, + }; + + if (string.IsNullOrEmpty(question)) + { + this._log.LogWarning("No question provided"); + noAnswerFound.NoResultReason = "No question provided"; + yield return noAnswerFound; + yield break; + } + + var facts = new StringBuilder(); + var maxTokens = this._config.MaxAskPromptSize > 0 + ? this._config.MaxAskPromptSize + : this._textGenerator.MaxTokenTotal; + var tokensAvailable = maxTokens + - this._textGenerator.CountTokens(answerPrompt) + - this._textGenerator.CountTokens(question) + - this._config.AnswerTokens; + + var factsUsedCount = 0; + var factsAvailableCount = 0; + var answer = noAnswerFound; + + this._log.LogTrace("Fetching relevant memories"); + IAsyncEnumerable<(MemoryRecord, double)> matches = this._memoryDb.GetSimilarListAsync( + index: index, + text: question, + filters: filters, + minRelevance: minRelevance, + limit: this._config.MaxMatchesCount, + withEmbeddings: false, + cancellationToken: cancellationToken); + + // Memories are sorted by relevance, starting from the most relevant + await foreach ((MemoryRecord memory, double relevance) in matches.ConfigureAwait(false)) + { + // Note: a document can be composed by multiple files + string documentId = memory.GetDocumentId(this._log); + + // Identify the file in case there are multiple files + string fileId = memory.GetFileId(this._log); + + // Note: this is not a URL and perhaps could be dropped. For now it acts as a unique identifier. See also SourceUrl. + string linkToFile = $"{index}/{documentId}/{fileId}"; + + string fileName = memory.GetFileName(this._log); + + string webPageUrl = memory.GetWebPageUrl(index); + + var partitionText = memory.GetPartitionText(this._log).Trim(); + if (string.IsNullOrEmpty(partitionText)) + { + this._log.LogError("The document partition is empty, doc: {0}", memory.Id); + continue; + } + + factsAvailableCount++; + + var fact = PromptUtils.RenderFactTemplate( + template: factTemplate, + factContent: partitionText, + source: (fileName == "content.url" ? webPageUrl : fileName), + relevance: relevance.ToString("P1", CultureInfo.CurrentCulture), + recordId: memory.Id, + tags: memory.Tags, + metadata: memory.Payload); + + // Use the partition/chunk only if there's room for it + var size = this._textGenerator.CountTokens(fact); + if (size >= tokensAvailable) + { + // Stop after reaching the max number of tokens + break; + } + + factsUsedCount++; + this._log.LogTrace("Adding text {0} with relevance {1}", factsUsedCount, relevance); + + facts.Append(fact); + tokensAvailable -= size; + + // If the file is already in the list of citations, only add the partition + var citation = answer.RelevantSources.FirstOrDefault(x => x.Link == linkToFile); + if (citation == null) + { + citation = new Citation(); + answer.RelevantSources.Add(citation); + } + + // Add the partition to the list of citations + citation.Index = index; + citation.DocumentId = documentId; + citation.FileId = fileId; + citation.Link = linkToFile; + citation.SourceContentType = memory.GetFileContentType(this._log); + citation.SourceName = fileName; + citation.SourceUrl = memory.GetWebPageUrl(index); + + citation.Partitions.Add(new Citation.Partition + { + Text = partitionText, + Relevance = (float)relevance, + PartitionNumber = memory.GetPartitionNumber(this._log), + SectionNumber = memory.GetSectionNumber(), + LastUpdate = memory.GetLastUpdate(), + Tags = memory.Tags, + }); + + // In cases where a buggy storage connector is returning too many records + if (factsUsedCount >= this._config.MaxMatchesCount) + { + break; + } + } + + if (factsAvailableCount > 0 && factsUsedCount == 0) + { + this._log.LogError("Unable to inject memories in the prompt, not enough tokens available"); + noAnswerFound.NoResultReason = "Unable to use memories"; + yield return noAnswerFound; + yield break; + } + + if (factsUsedCount == 0) + { + this._log.LogWarning("No memories available"); + noAnswerFound.NoResultReason = "No memories available"; + yield return noAnswerFound; + yield break; + } + var charsGenerated = 0; + await foreach (var x in this.GenerateAnswer(question, facts.ToString(), context, cancellationToken).ConfigureAwait(true)) + { + var text = new StringBuilder(); + text.Append(x); + if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30) + { + charsGenerated = text.Length; + this._log.LogTrace("{0} chars generated", charsGenerated); + } + var newAnswer = new MemoryAnswer + { + Question = question, + NoResult = false, + Result = text.ToString() + }; + this._log.LogInformation("Chunk: '{0}", newAnswer.Result); + yield return newAnswer; + } + answer.Result = eosToken; + this._log.LogInformation("Eos token: '{0}", answer.Result); + yield return answer; + } + private IAsyncEnumerable GenerateAnswer(string question, string facts, IContext? context, CancellationToken token) { string prompt = context.GetCustomRagPromptOrDefault(this._answerPrompt); @@ -371,7 +544,6 @@ private IAsyncEnumerable GenerateAnswer(string question, string facts, I prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase); question = question.Trim(); - question = question.EndsWith('?') ? question : $"{question}?"; prompt = prompt.Replace("{{$input}}", question, StringComparison.OrdinalIgnoreCase); prompt = prompt.Replace("{{$notFound}}", this._config.EmptyAnswer, StringComparison.OrdinalIgnoreCase); diff --git a/service/Service.AspNetCore/WebAPIEndpoints.cs b/service/Service.AspNetCore/WebAPIEndpoints.cs index eeb431c34..a5ebc4e56 100644 --- a/service/Service.AspNetCore/WebAPIEndpoints.cs +++ b/service/Service.AspNetCore/WebAPIEndpoints.cs @@ -5,6 +5,7 @@ using System.IO; using System.Threading; using System.Threading.Tasks; +using System.Text.Json; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.HttpResults; @@ -29,6 +30,7 @@ public static IEndpointRouteBuilder AddKernelMemoryEndpoints( builder.AddDeleteIndexesEndpoint(apiPrefix, authFilter); builder.AddDeleteDocumentsEndpoint(apiPrefix, authFilter); builder.AddAskEndpoint(apiPrefix, authFilter); + builder.AddAskChunkEndpoint(apiPrefix, authFilter); builder.AddSearchEndpoint(apiPrefix, authFilter); builder.AddUploadStatusEndpoint(apiPrefix, authFilter); builder.AddGetDownloadEndpoint(apiPrefix, authFilter); @@ -234,7 +236,58 @@ async Task ( if (authFilter != null) { route.AddEndpointFilter(authFilter); } } + public static void AddAskChunkEndpoint( + this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter? authFilter = null) + { + RouteGroupBuilder group = builder.MapGroup(apiPrefix); + + // Ask endpoint + var route = group.MapPost(Constants.HttpAskChunkEndpoint, + async Task ( + HttpContext httpContext, + MemoryQuery query, + IKernelMemory service, + ILogger log, + IContextProvider contextProvider, + CancellationToken cancellationToken) => + { + // Allow internal classes to access custom arguments via IContextProvider + contextProvider.InitContextArgs(query.ContextArguments); + httpContext.Response.ContentType = "text/event-stream"; + log.LogTrace("New search request, index '{0}', minRelevance {1}", query.Index, query.MinRelevance); + try + { + var answerStream = await service.AskAsyncChunk( + question: query.Question, + index: query.Index, + filters: query.Filters, + minRelevance: query.MinRelevance, + context: contextProvider.GetContext(), + cancellationToken: cancellationToken + ).ConfigureAwait(false); + + await foreach (var answer in answerStream.ConfigureAwait(false)) + { + string json = JsonSerializer.Serialize(answer); + await httpContext.Response.WriteAsync($"data: {json}\n\n", cancellationToken).ConfigureAwait(false); + await httpContext.Response.Body.FlushAsync(cancellationToken).ConfigureAwait(false); + } + } + catch (Exception ex) + { + log.LogError(ex, "Error occurred while streaming resp"); + string errorJson = JsonSerializer.Serialize(new { error = "An error occurred while processing the request." }); + await httpContext.Response.WriteAsync($"data: {errorJson}\n\n", cancellationToken).ConfigureAwait(false); + await httpContext.Response.Body.FlushAsync(cancellationToken).ConfigureAwait(false); + } + }) + .Produces(StatusCodes.Status204NoContent) + .Produces(StatusCodes.Status401Unauthorized) + .Produces(StatusCodes.Status403Forbidden) + .RequireCors("KM-CORS"); + if (authFilter != null) { route.AddEndpointFilter(authFilter); } + } public static void AddSearchEndpoint( this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter? authFilter = null) { diff --git a/service/Service/Program.cs b/service/Service/Program.cs index afbf214e1..67941464a 100644 --- a/service/Service/Program.cs +++ b/service/Service/Program.cs @@ -98,7 +98,7 @@ public static void Main(string[] args) }); // CORS - bool enableCORS = false; + bool enableCORS = true; const string CORSPolicyName = "KM-CORS"; if (enableCORS && config.Service.RunWebService) { @@ -107,10 +107,11 @@ public static void Main(string[] args) options.AddPolicy(name: CORSPolicyName, policy => { policy + //.WithOrigins("http://127.0.0.1:5500") these three are to test in local for js and html, supposing you're opening 5500 port + //.AllowAnyHeader() + //.AllowAnyMethod() .WithMethods("HEAD", "GET", "POST", "PUT", "DELETE") .WithExposedHeaders("Content-Type", "Content-Length", "Last-Modified"); - // .AllowAnyOrigin() - // .WithOrigins(...) // .AllowAnyHeader() // .WithHeaders(...) }); From 274b3f89f77c5438e552d183f6ab8e8afb8b49e7 Mon Sep 17 00:00:00 2001 From: carlodek Date: Thu, 1 Aug 2024 17:03:21 +0200 Subject: [PATCH 02/30] added streaming client --- clients/dotnet/WebClient/MemoryWebClient.cs | 51 +++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/clients/dotnet/WebClient/MemoryWebClient.cs b/clients/dotnet/WebClient/MemoryWebClient.cs index 42f5fd73c..fe0a9a780 100644 --- a/clients/dotnet/WebClient/MemoryWebClient.cs +++ b/clients/dotnet/WebClient/MemoryWebClient.cs @@ -11,6 +11,7 @@ using System.Text.Json; using System.Threading; using System.Threading.Tasks; +using System.Runtime.CompilerServices; using Microsoft.KernelMemory.Context; using Microsoft.KernelMemory.Internals; @@ -371,6 +372,56 @@ public async Task AskAsync( return JsonSerializer.Deserialize(json, s_caseInsensitiveJsonOptions) ?? new MemoryAnswer(); } + /// + public async Task> AskAsyncChunk( + string question, + string? index = null, + MemoryFilter? filter = null, + ICollection? filters = null, + double minRelevance = 0, + IContext? context = null, + CancellationToken cancellationToken = default) + { + if (filter != null) + { + if (filters == null) { filters = new List(); } + + filters.Add(filter); + } + + MemoryQuery request = new() + { + Index = index, + Question = question, + Filters = (filters is { Count: > 0 }) ? filters.ToList() : new(), + MinRelevance = minRelevance, + ContextArguments = (context?.Arguments ?? new Dictionary()).ToDictionary(), + }; + using StringContent content = new(JsonSerializer.Serialize(request), Encoding.UTF8, "application/json"); + + var url = Constants.HttpAskChunkEndpoint.CleanUrlPath(); + HttpResponseMessage response = await this._client.PostAsync(url, content, cancellationToken).ConfigureAwait(false); + response.EnsureSuccessStatusCode(); + var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + using (var reader = new StreamReader(stream, Encoding.UTF8)) + { + string? line; + while ((line = await reader.ReadLineAsync().ConfigureAwait(false)) != null) + { + if (line.StartsWith("data:")) + { + var jsonData = line.Substring(6); + var chunk = JsonSerializer.Deserialize(jsonData); + yield return chunk; + if (chunk?.Result == "end") + { + break; + } + } + } + } + } + #region private private static (string contentType, long contentLength, DateTimeOffset lastModified) GetFileDetails(HttpResponseMessage response) From 3edabdb4ac92e49f823ea83f3fa1b05f3fed6005 Mon Sep 17 00:00:00 2001 From: carlodek Date: Wed, 7 Aug 2024 10:49:43 +0200 Subject: [PATCH 03/30] added qdrant --- Directory.Packages.props | 1 + service/Service/Service.csproj | 1 + 2 files changed, 2 insertions(+) diff --git a/Directory.Packages.props b/Directory.Packages.props index 5995c57ed..0b42c609b 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -33,6 +33,7 @@ + diff --git a/service/Service/Service.csproj b/service/Service/Service.csproj index bc65d41b4..671452cd8 100644 --- a/service/Service/Service.csproj +++ b/service/Service/Service.csproj @@ -17,6 +17,7 @@ + From 488b8d50b9e6404b264114c1fda53ac206201048 Mon Sep 17 00:00:00 2001 From: carlodek Date: Fri, 9 Aug 2024 16:53:55 +0200 Subject: [PATCH 04/30] docker for us --- .dockerignore | 2 +- Dockerfile | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/.dockerignore b/.dockerignore index 6b77549bb..a507f0414 100644 --- a/.dockerignore +++ b/.dockerignore @@ -11,4 +11,4 @@ docker-compose.*.yml # To make sure that the local appsettings files are not copied when building the image locally. # These files are not in GitHub thanks to .gitignore -**/*appsettings.*.json + diff --git a/Dockerfile b/Dockerfile index 166da8f10..11274a0d3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,7 +13,7 @@ ARG PLATFORM=$BUILDPLATFORM # ARG BUILDPLATFORM FROM --platform=$PLATFORM mcr.microsoft.com/dotnet/sdk:$BUILD_IMAGE_TAG AS build -ARG BUILD_CONFIGURATION=Release +ARG BUILD_CONFIGURATION=Debug COPY . /src/ WORKDIR "/src/service/Service" @@ -46,15 +46,14 @@ COPY --from=build --chown=km:km --chmod=0550 /app/publish . ######################################################################### LABEL org.opencontainers.image.authors="Devis Lucato, https://github.com/dluc" -MAINTAINER Devis Lucato "https://github.com/dluc" +LABEL MAINTAINER="Carlo De Chellis" # Define current user USER $USER # Used by .NET and KM to load appsettings.Production.json -ENV ASPNETCORE_ENVIRONMENT Production -ENV ASPNETCORE_URLS http://+:9001 -ENV ASPNETCORE_HTTP_PORTS 9001 +ENV ASPNETCORE_ENVIRONMENT=Development +ENV ASPNETCORE_HTTP_PORTS=9001 EXPOSE 9001 From 5c549163ceb43ed73d19131dcb8ad594115897a3 Mon Sep 17 00:00:00 2001 From: carlodek Date: Mon, 28 Oct 2024 15:22:51 +0100 Subject: [PATCH 05/30] added gitignore for AzureOpenAI --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 2b68e0885..e51731407 100644 --- a/.gitignore +++ b/.gitignore @@ -503,3 +503,4 @@ swa-cli.config.json **/copilot-chat-app/webapp/node_modules *.orig +service/Service/AzureOpenAiDevelopment.json From 01f1fdc8c43f9275bcd0aea7b93443f57abe9d55 Mon Sep 17 00:00:00 2001 From: carlodek Date: Tue, 29 Oct 2024 16:48:54 +0100 Subject: [PATCH 06/30] custom notFound response implemented correctly --- service/Core/Search/SearchClient.cs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs index 1640dce15..08b302635 100644 --- a/service/Core/Search/SearchClient.cs +++ b/service/Core/Search/SearchClient.cs @@ -341,7 +341,7 @@ public async Task AskAsync( var charsGenerated = 0; var watch = new Stopwatch(); watch.Restart(); - await foreach (var x in this.GenerateAnswer(question, facts.ToString(), context, cancellationToken).ConfigureAwait(false)) + await foreach (var x in this.GenerateAnswer(question, emptyAnswer, facts.ToString(), context, cancellationToken).ConfigureAwait(false)) { text.Append(x); if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30) @@ -531,7 +531,7 @@ public async IAsyncEnumerable AskAsyncChunk( yield break; } var charsGenerated = 0; - await foreach (var x in this.GenerateAnswer(question, facts.ToString(), context, cancellationToken).ConfigureAwait(true)) + await foreach (var x in this.GenerateAnswer(question, emptyAnswer,facts.ToString(), context, cancellationToken).ConfigureAwait(true)) { var text = new StringBuilder(); text.Append(x); @@ -554,7 +554,7 @@ public async IAsyncEnumerable AskAsyncChunk( yield return answer; } - private IAsyncEnumerable GenerateAnswer(string question, string facts, IContext? context, CancellationToken token) + private IAsyncEnumerable GenerateAnswer(string question, string emptyAnswer, string facts, IContext? context, CancellationToken token) { string prompt = context.GetCustomRagPromptOrDefault(this._answerPrompt); int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens); @@ -565,7 +565,7 @@ private IAsyncEnumerable GenerateAnswer(string question, string facts, I question = question.Trim(); prompt = prompt.Replace("{{$input}}", question, StringComparison.OrdinalIgnoreCase); - prompt = prompt.Replace("{{$notFound}}", this._config.EmptyAnswer, StringComparison.OrdinalIgnoreCase); + prompt = prompt.Replace("{{$notFound}}", emptyAnswer, StringComparison.OrdinalIgnoreCase); var options = new TextGenerationOptions { From dcb54bac82da98c74048fc2eb947e0fa256f5faa Mon Sep 17 00:00:00 2001 From: carlodek Date: Mon, 4 Nov 2024 16:40:46 +0100 Subject: [PATCH 07/30] added response if no memory is found --- service/Core/Search/SearchClient.cs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs index 08b302635..e4bd8be6b 100644 --- a/service/Core/Search/SearchClient.cs +++ b/service/Core/Search/SearchClient.cs @@ -520,16 +520,13 @@ public async IAsyncEnumerable AskAsyncChunk( this._log.LogError("Unable to inject memories in the prompt, not enough tokens available"); noAnswerFound.NoResultReason = "Unable to use memories"; yield return noAnswerFound; + answer.Result = eosToken; + yield return answer; yield break; } if (factsUsedCount == 0) - { this._log.LogWarning("No memories available"); - noAnswerFound.NoResultReason = "No memories available"; - yield return noAnswerFound; - yield break; - } var charsGenerated = 0; await foreach (var x in this.GenerateAnswer(question, emptyAnswer,facts.ToString(), context, cancellationToken).ConfigureAwait(true)) { @@ -562,7 +559,7 @@ private IAsyncEnumerable GenerateAnswer(string question, string emptyAns double nucleusSampling = context.GetCustomRagNucleusSamplingOrDefault(this._config.TopP); prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase); - + this._log.LogInformation("prompt: {0}", prompt); question = question.Trim(); prompt = prompt.Replace("{{$input}}", question, StringComparison.OrdinalIgnoreCase); prompt = prompt.Replace("{{$notFound}}", emptyAnswer, StringComparison.OrdinalIgnoreCase); From 5b926ce6dfbca92ad154832fe776849685ac1bd2 Mon Sep 17 00:00:00 2001 From: carlodek Date: Wed, 6 Nov 2024 11:06:06 +0100 Subject: [PATCH 08/30] content moderation --- service/Core/Search/SearchClient.cs | 37 +++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs index e4bd8be6b..03b12f30f 100644 --- a/service/Core/Search/SearchClient.cs +++ b/service/Core/Search/SearchClient.cs @@ -368,16 +368,12 @@ public async Task AskAsync( if (this._contentModeration != null && this._config.UseContentModeration) { - var isSafe = await this._contentModeration.IsSafeAsync(answer.Result, cancellationToken).ConfigureAwait(false); - if (!isSafe) - { - this._log.LogWarning("Unsafe answer detected. Returning error message instead."); - this._log.LogSensitive("Unsafe answer: {0}", answer.Result); - answer.NoResultReason = "Content moderation failure"; + var moderatedAnswer = await this.isSafe(question, answer.Result, cancellationToken).ConfigureAwait(false); + if(moderatedAnswer.NoResultReason == "Content moderation failure"){ + answer.NoResultReason = moderatedAnswer.NoResultReason; answer.Result = this._config.ModeratedAnswer; } } - return answer; } @@ -528,10 +524,12 @@ public async IAsyncEnumerable AskAsyncChunk( if (factsUsedCount == 0) this._log.LogWarning("No memories available"); var charsGenerated = 0; + var wholeText = new StringBuilder(); await foreach (var x in this.GenerateAnswer(question, emptyAnswer,facts.ToString(), context, cancellationToken).ConfigureAwait(true)) { var text = new StringBuilder(); text.Append(x); + wholeText.Append(x); if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30) { charsGenerated = text.Length; @@ -546,11 +544,36 @@ public async IAsyncEnumerable AskAsyncChunk( this._log.LogInformation("Chunk: '{0}", newAnswer.Result); yield return newAnswer; } + if (this._contentModeration != null && this._config.UseContentModeration) + { + var moderatedAnswer = await this.isSafe(question, wholeText.ToString(), cancellationToken).ConfigureAwait(false); + if(moderatedAnswer.NoResultReason == "Content moderation failure") + yield return moderatedAnswer; + } answer.Result = eosToken; this._log.LogInformation("Eos token: '{0}", answer.Result); yield return answer; } + private async Task isSafe(string question, string answerToDetect, [EnumeratorCancellation] CancellationToken cancellationToken) + { + var isSafe = await this._contentModeration.IsSafeAsync(answerToDetect, cancellationToken).ConfigureAwait(false); + var newAnswer = new MemoryAnswer + { + Question = question, + NoResult = false, + Result = answerToDetect + }; + if (!isSafe) + { + this._log.LogWarning("Unsafe answer detected. Returning error message instead."); + this._log.LogSensitive("Unsafe answer: {0}", answerToDetect); + newAnswer.NoResultReason = "Content moderation failure"; + newAnswer.Result = this._config.ModeratedAnswer; + } + return newAnswer; + } + private IAsyncEnumerable GenerateAnswer(string question, string emptyAnswer, string facts, IContext? context, CancellationToken token) { string prompt = context.GetCustomRagPromptOrDefault(this._answerPrompt); From 202cecfd3a402478db3e60d86c022c40d7c285d9 Mon Sep 17 00:00:00 2001 From: carlodek Date: Mon, 11 Nov 2024 12:56:27 +0100 Subject: [PATCH 09/30] removed circular import --- service/Core/Core.csproj | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/service/Core/Core.csproj b/service/Core/Core.csproj index 2940e888d..c60d5bdb5 100644 --- a/service/Core/Core.csproj +++ b/service/Core/Core.csproj @@ -10,21 +10,6 @@ - - - - - - - - - - - - - - - From 7158bb86433f6e45d18d722e0a5864c01459dc3b Mon Sep 17 00:00:00 2001 From: carlodek Date: Thu, 28 Nov 2024 16:18:09 +0100 Subject: [PATCH 10/30] changed docker file --- Dockerfile | 1 - 1 file changed, 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index c76f7f011..d373d20a4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -44,7 +44,6 @@ COPY --from=build --chown=km:km --chmod=0550 /app/publish . ######################################################################### LABEL org.opencontainers.image.authors="Devis Lucato, https://github.com/dluc" -LABEL MAINTAINER="Carlo De Chellis" # Define current user USER $USER From f9d9eaad99900ef4de97474ce5888f369b6b46b8 Mon Sep 17 00:00:00 2001 From: carlodek Date: Thu, 28 Nov 2024 17:15:21 +0100 Subject: [PATCH 11/30] code cleaning --- service/Core/MemoryService.cs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/service/Core/MemoryService.cs b/service/Core/MemoryService.cs index 420804d37..2ddf6b433 100644 --- a/service/Core/MemoryService.cs +++ b/service/Core/MemoryService.cs @@ -265,7 +265,10 @@ public IAsyncEnumerable AskAsyncChunk( { if (filter != null) { - if (filters == null) { filters = new List(); } + if (filters == null) + { + filters = new List(); + } filters.Add(filter); } From 1afbb85152de2fcdbca07ef0eb349f5f2d1d5352 Mon Sep 17 00:00:00 2001 From: carlodek Date: Thu, 28 Nov 2024 17:36:55 +0100 Subject: [PATCH 12/30] cleaned code for build automation --- clients/dotnet/WebClient/MemoryWebClient.cs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/clients/dotnet/WebClient/MemoryWebClient.cs b/clients/dotnet/WebClient/MemoryWebClient.cs index 1de4d58dd..d4e157cf4 100644 --- a/clients/dotnet/WebClient/MemoryWebClient.cs +++ b/clients/dotnet/WebClient/MemoryWebClient.cs @@ -403,19 +403,23 @@ public async IAsyncEnumerable AskAsyncChunk( HttpResponseMessage response = await this._client.PostAsync(url, content, cancellationToken).ConfigureAwait(false); response.EnsureSuccessStatusCode(); var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + var eosToken = context.GetCustomEosTokenOrDefault("#DONE#"); using (var reader = new StreamReader(stream, Encoding.UTF8)) { string? line; - while ((line = await reader.ReadLineAsync().ConfigureAwait(false)) != null) + while ((line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) != null) { - if (line.StartsWith("data:")) + if (line.StartsWith("data:", StringComparison.Ordinal)) { var jsonData = line.Substring(6); var chunk = JsonSerializer.Deserialize(jsonData); - yield return chunk; - if (chunk?.Result == "end") + if (chunk != null) { - break; + yield return chunk; + if (chunk.Result == eosToken) + { + break; + } } } } From b5978f7bd59b4fec413d431cc76161293d7bb85e Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 28 Nov 2024 18:28:29 -0800 Subject: [PATCH 13/30] Update .dockerignore --- .dockerignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.dockerignore b/.dockerignore index 209984e35..ed5e212e7 100644 --- a/.dockerignore +++ b/.dockerignore @@ -16,4 +16,4 @@ infra # To make sure that the local appsettings files are not copied when building the image locally. # These files are not in GitHub thanks to .gitignore - +**/*appsettings.*.json From e727099fa6acf26fe5c174fc68b3980bb21988aa Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 28 Nov 2024 18:28:46 -0800 Subject: [PATCH 14/30] Update .gitignore --- .gitignore | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitignore b/.gitignore index f110b6e46..34fc91e88 100644 --- a/.gitignore +++ b/.gitignore @@ -504,4 +504,3 @@ swa-cli.config.json *.orig .azure -service/Service/AzureOpenAiDevelopment.json From 4523be17821fb7263e724a7813884cf2ca19bc82 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 28 Nov 2024 18:29:05 -0800 Subject: [PATCH 15/30] Update Directory.Packages.props --- Directory.Packages.props | 2 -- 1 file changed, 2 deletions(-) diff --git a/Directory.Packages.props b/Directory.Packages.props index 4122a1d78..3a9b51c7c 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -29,8 +29,6 @@ - - From 008b50e8cc3dff806fb75755c3c3d83fc57b9e2e Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 28 Nov 2024 18:30:50 -0800 Subject: [PATCH 16/30] Update Dockerfile --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index d373d20a4..2950d367b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,7 +11,7 @@ ARG RUN_IMAGE_TAG="8.0-alpine" FROM mcr.microsoft.com/dotnet/sdk:$BUILD_IMAGE_TAG AS build -ARG BUILD_CONFIGURATION=Debug +ARG BUILD_CONFIGURATION=Release COPY . /src/ WORKDIR "/src/service/Service" From 3d7118e92b0c9eb90d8f9fdbd0fc602886c8eddb Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 28 Nov 2024 18:31:08 -0800 Subject: [PATCH 17/30] Update Dockerfile --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 2950d367b..b6ed9f7f9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -49,7 +49,7 @@ LABEL org.opencontainers.image.authors="Devis Lucato, https://github.com/dluc" USER $USER # Used by .NET and KM to load appsettings.Production.json -ENV ASPNETCORE_ENVIRONMENT=Development +ENV ASPNETCORE_ENVIRONMENT=Production ENV ASPNETCORE_HTTP_PORTS=9001 EXPOSE 9001 From f9eeb51102082505b630b6f7fa54aa2363d9091e Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 28 Nov 2024 18:32:04 -0800 Subject: [PATCH 18/30] Update Dockerfile --- Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/Dockerfile b/Dockerfile index b6ed9f7f9..9c50b5ca2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -50,6 +50,7 @@ USER $USER # Used by .NET and KM to load appsettings.Production.json ENV ASPNETCORE_ENVIRONMENT=Production +ENV ASPNETCORE_URLS=http://+:9001 ENV ASPNETCORE_HTTP_PORTS=9001 EXPOSE 9001 From 9079d283a3bfa1b4ba1fc9774188714895396269 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 28 Nov 2024 18:32:54 -0800 Subject: [PATCH 19/30] Update service/Core/Core.csproj --- service/Core/Core.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/Core/Core.csproj b/service/Core/Core.csproj index 5b94743a3..1c5c69eaf 100644 --- a/service/Core/Core.csproj +++ b/service/Core/Core.csproj @@ -14,7 +14,7 @@ - + From 11ce13cc3da7899e764c03d77ab92245627a44ca Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 28 Nov 2024 18:33:44 -0800 Subject: [PATCH 20/30] Update service/Core/Search/SearchClient.cs --- service/Core/Search/SearchClient.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs index eff836552..50c96bff4 100644 --- a/service/Core/Search/SearchClient.cs +++ b/service/Core/Search/SearchClient.cs @@ -347,7 +347,7 @@ private SearchClientResult ProcessMemoryRecord( string fact = PromptUtils.RenderFactTemplate( template: factTemplate!, factContent: partitionText, - source: fileNameForLlm, + source: fileNameForLLM, relevance: recordRelevance.ToString("P1", CultureInfo.CurrentCulture), recordId: record.Id, tags: record.Tags, From 4ce9571c908ab11db1c5b7bf13f65762544e2389 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 28 Nov 2024 18:33:53 -0800 Subject: [PATCH 21/30] Update service/Core/Search/SearchClient.cs --- service/Core/Search/SearchClient.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs index 50c96bff4..1b2d0bce5 100644 --- a/service/Core/Search/SearchClient.cs +++ b/service/Core/Search/SearchClient.cs @@ -333,7 +333,7 @@ private SearchClientResult ProcessMemoryRecord( string fileDownloadUrl = record.GetWebPageUrl(index); // Name of the file to show to the LLM, avoiding "content.url" - string fileNameForLlm = (fileName == "content.url" ? fileDownloadUrl : fileName); + string fileNameForLLM = (fileName == "content.url" ? fileDownloadUrl : fileName); if (result.Mode == SearchMode.SearchMode) { From 2656f2ef3ecd017d1164bc26258ab1f4c2948aa4 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 28 Nov 2024 18:34:52 -0800 Subject: [PATCH 22/30] Update service/Service/Program.cs --- service/Service/Program.cs | 3 --- 1 file changed, 3 deletions(-) diff --git a/service/Service/Program.cs b/service/Service/Program.cs index 5c819af88..64c50bbbf 100644 --- a/service/Service/Program.cs +++ b/service/Service/Program.cs @@ -128,9 +128,6 @@ public static void Main(string[] args) options.AddPolicy(name: CORSPolicyName, policy => { policy - //.WithOrigins("http://127.0.0.1:5500") these three are to test in local for js and html, supposing you're opening 5500 port - //.AllowAnyHeader() - //.AllowAnyMethod() .WithMethods("HEAD", "GET", "POST", "PUT", "DELETE") .WithExposedHeaders("Content-Type", "Content-Length", "Last-Modified"); // .AllowAnyHeader() From dfc1ebc6a1554f067621b8803c1d01e9a363c9bb Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 28 Nov 2024 18:35:20 -0800 Subject: [PATCH 23/30] Update service/Service/Program.cs --- service/Service/Program.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/Service/Program.cs b/service/Service/Program.cs index 64c50bbbf..54e9ef379 100644 --- a/service/Service/Program.cs +++ b/service/Service/Program.cs @@ -119,7 +119,7 @@ public static void Main(string[] args) }); // CORS - bool enableCORS = true; + bool enableCORS = false; const string CORSPolicyName = "KM-CORS"; if (enableCORS && config.Service.RunWebService) { From a4341367ea752a63d63082376204bf1a29e03650 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 28 Nov 2024 18:35:59 -0800 Subject: [PATCH 24/30] Update service/Service/Program.cs --- service/Service/Program.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/service/Service/Program.cs b/service/Service/Program.cs index 54e9ef379..4be864ef4 100644 --- a/service/Service/Program.cs +++ b/service/Service/Program.cs @@ -130,6 +130,8 @@ public static void Main(string[] args) policy .WithMethods("HEAD", "GET", "POST", "PUT", "DELETE") .WithExposedHeaders("Content-Type", "Content-Length", "Last-Modified"); + // .AllowAnyOrigin() + // .WithOrigins(...) // .AllowAnyHeader() // .WithHeaders(...) }); From b2485d89a5076fd0e7491c88c1d2ef9bf11cc444 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 28 Nov 2024 18:36:17 -0800 Subject: [PATCH 25/30] Update service/Core/Search/SearchClient.cs --- service/Core/Search/SearchClient.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs index 1b2d0bce5..e316d5502 100644 --- a/service/Core/Search/SearchClient.cs +++ b/service/Core/Search/SearchClient.cs @@ -323,7 +323,7 @@ private SearchClientResult ProcessMemoryRecord( // Identify the file in case there are multiple files string fileId = record.GetFileId(this._log); - // Note: this is not a URL and perhaps could be dropped. For now, it acts as a unique identifier. See also SourceUrl. + // Note: this is not a URL and perhaps could be dropped. For now it acts as a unique identifier. See also SourceUrl. string linkToFile = $"{index}/{documentId}/{fileId}"; // Note: this is "content.url" when importing web pages From f16e9b322663fd1077baf99db791f0d4a1d827d3 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Thu, 28 Nov 2024 18:38:54 -0800 Subject: [PATCH 26/30] Update service/Service.AspNetCore/WebAPIEndpoints.cs --- service/Service.AspNetCore/WebAPIEndpoints.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/Service.AspNetCore/WebAPIEndpoints.cs b/service/Service.AspNetCore/WebAPIEndpoints.cs index 1f9f1eff5..a9e607d1b 100644 --- a/service/Service.AspNetCore/WebAPIEndpoints.cs +++ b/service/Service.AspNetCore/WebAPIEndpoints.cs @@ -263,7 +263,7 @@ async Task ( { // Allow internal classes to access custom arguments via IContextProvider contextProvider.InitContextArgs(query.ContextArguments); - httpContext.Response.ContentType = "text/event-stream"; + httpContext.Response.ContentType = "text/event-stream; charset=utf-8"; log.LogTrace("New search request, index '{0}', minRelevance {1}", query.Index, query.MinRelevance); try { From 44a12bffa7c716cc9fb96f88439c4715d84a0909 Mon Sep 17 00:00:00 2001 From: carlodek <56030624+carlodek@users.noreply.github.com> Date: Fri, 29 Nov 2024 09:15:00 +0100 Subject: [PATCH 27/30] Update SearchClient.cs changed x in token --- service/Core/Search/SearchClient.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs index e316d5502..215dab47e 100644 --- a/service/Core/Search/SearchClient.cs +++ b/service/Core/Search/SearchClient.cs @@ -254,11 +254,11 @@ public async IAsyncEnumerable AskAsyncChunk( this._log.LogTrace("{Count} records processed", result.RecordCount); var charsGenerated = 0; var wholeText = new StringBuilder(); - await foreach (var x in this._answerGenerator.GenerateAnswerTokensAsync(question, result.Facts.ToString(), context, cancellationToken).ConfigureAwait(true)) + await foreach (var token in this._answerGenerator.GenerateAnswerTokensAsync(question, result.Facts.ToString(), context, cancellationToken).ConfigureAwait(true)) { var text = new StringBuilder(); - text.Append(x); - wholeText.Append(x); + text.Append(token); + wholeText.Append(token); if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30) { charsGenerated = text.Length; From 987599afbb30971cbab2dc0b2b059e97ac27227a Mon Sep 17 00:00:00 2001 From: carlodek <56030624+carlodek@users.noreply.github.com> Date: Fri, 29 Nov 2024 09:24:20 +0100 Subject: [PATCH 28/30] Update Service.csproj removed unnecessary blank line From 92c0d33b4f8fa112de3ece137bc5b0b341a249da Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Fri, 29 Nov 2024 14:36:34 -0800 Subject: [PATCH 29/30] Undo changes --- service/Service/Service.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/Service/Service.csproj b/service/Service/Service.csproj index 11c7f7261..11739eb76 100644 --- a/service/Service/Service.csproj +++ b/service/Service/Service.csproj @@ -53,4 +53,4 @@ - + \ No newline at end of file From a4a4bc7fe89d32bfa03cd8d5b455795ffc38adc7 Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Sat, 30 Nov 2024 19:53:45 -0800 Subject: [PATCH 30/30] Refactoring - reduce code duplication - reduce stream payload size - stream reset on moderation - handle errors - add streaming examples --- .github/_typos.toml | 1 + KernelMemory.sln.DotSettings | 1 + .../SemanticKernelPlugin/MemoryPlugin.cs | 1 + clients/dotnet/WebClient/MemoryWebClient.cs | 75 ++----- examples/001-dotnet-WebClient/Program.cs | 46 +++- examples/002-dotnet-Serverless/Program.cs | 46 +++- .../AzureAISearch.TestApplication/Program.cs | 3 +- .../AzureAISearch/AzureAISearchMemory.cs | 11 +- service/Abstractions/Abstractions.csproj | 3 +- service/Abstractions/Constants.cs | 5 - service/Abstractions/Context/IContext.cs | 10 - service/Abstractions/HTTP/SSE.cs | 68 ++++++ service/Abstractions/IKernelMemory.cs | 18 +- .../Abstractions/KernelMemoryExtensions.cs | 43 ++++ service/Abstractions/Models/MemoryAnswer.cs | 37 +++- service/Abstractions/Models/MemoryQuery.cs | 4 + service/Abstractions/Models/SearchResult.cs | 6 +- service/Abstractions/Models/StreamStates.cs | 58 +++++ service/Abstractions/Search/ISearchClient.cs | 12 +- service/Abstractions/Search/SearchOptions.cs | 29 +++ service/Core/Configuration/ServiceConfig.cs | 5 + service/Core/MemoryServerless.cs | 46 ++-- service/Core/MemoryService.cs | 45 ++-- service/Core/Search/AnswerGenerator.cs | 85 +++----- service/Core/Search/SearchClient.cs | 206 +++++++----------- service/Core/Search/SearchClientResult.cs | 66 +++++- service/Service.AspNetCore/WebAPIEndpoints.cs | 126 +++++------ service/Service/appsettings.json | 2 + .../Abstractions.UnitTests/Http/SSETest.cs | 149 +++++++++++++ 29 files changed, 786 insertions(+), 421 deletions(-) create mode 100644 service/Abstractions/HTTP/SSE.cs create mode 100644 service/Abstractions/Models/StreamStates.cs create mode 100644 service/Abstractions/Search/SearchOptions.cs create mode 100644 service/tests/Abstractions.UnitTests/Http/SSETest.cs diff --git a/.github/_typos.toml b/.github/_typos.toml index 5fa7634f9..000d278e9 100644 --- a/.github/_typos.toml +++ b/.github/_typos.toml @@ -15,6 +15,7 @@ extend-exclude = [ "encoder.json", "appsettings.development.json", "appsettings.Development.json", + "appsettings.*.json.*", "AzureAISearchFilteringTest.cs", "KernelMemory.sln.DotSettings" ] diff --git a/KernelMemory.sln.DotSettings b/KernelMemory.sln.DotSettings index 4d9e859e9..56b15a2c0 100644 --- a/KernelMemory.sln.DotSettings +++ b/KernelMemory.sln.DotSettings @@ -120,6 +120,7 @@ SHA SK SKHTTP + SSE SSL TTL UI diff --git a/clients/dotnet/SemanticKernelPlugin/MemoryPlugin.cs b/clients/dotnet/SemanticKernelPlugin/MemoryPlugin.cs index 4446459ab..a4db5b6a0 100644 --- a/clients/dotnet/SemanticKernelPlugin/MemoryPlugin.cs +++ b/clients/dotnet/SemanticKernelPlugin/MemoryPlugin.cs @@ -404,6 +404,7 @@ public async Task AskAsync( MemoryAnswer answer = await this._memory.AskAsync( question: question, index: index ?? this._defaultIndex, + options: new SearchOptions { Stream = false }, filter: TagsToMemoryFilter(tags ?? this._defaultRetrievalTags), minRelevance: minRelevance, cancellationToken: cancellationToken).ConfigureAwait(false); diff --git a/clients/dotnet/WebClient/MemoryWebClient.cs b/clients/dotnet/WebClient/MemoryWebClient.cs index d4e157cf4..f767881fd 100644 --- a/clients/dotnet/WebClient/MemoryWebClient.cs +++ b/clients/dotnet/WebClient/MemoryWebClient.cs @@ -7,12 +7,13 @@ using System.Linq; using System.Net; using System.Net.Http; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Threading; using System.Threading.Tasks; -using System.Runtime.CompilerServices; using Microsoft.KernelMemory.Context; +using Microsoft.KernelMemory.HTTP; using Microsoft.KernelMemory.Internals; namespace Microsoft.KernelMemory; @@ -338,28 +339,30 @@ public async Task SearchAsync( } /// - public async Task AskAsync( + public async IAsyncEnumerable AskStreamingAsync( string question, string? index = null, MemoryFilter? filter = null, ICollection? filters = null, double minRelevance = 0, + SearchOptions? options = null, IContext? context = null, - CancellationToken cancellationToken = default) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { if (filter != null) { - if (filters == null) { filters = []; } - + filters ??= []; filters.Add(filter); } + var useStreaming = options?.Stream ?? false; MemoryQuery request = new() { Index = index, Question = question, Filters = (filters is { Count: > 0 }) ? filters.ToList() : [], MinRelevance = minRelevance, + Stream = useStreaming, ContextArguments = (context?.Arguments ?? new Dictionary()).ToDictionary(), }; using StringContent content = new(JsonSerializer.Serialize(request), Encoding.UTF8, "application/json"); @@ -368,62 +371,20 @@ public async Task AskAsync( HttpResponseMessage response = await this._client.PostAsync(url, content, cancellationToken).ConfigureAwait(false); response.EnsureSuccessStatusCode(); - var json = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); - return JsonSerializer.Deserialize(json, s_caseInsensitiveJsonOptions) ?? new MemoryAnswer(); - } - - /// - public async IAsyncEnumerable AskAsyncChunk( - string question, - string? index = null, - MemoryFilter? filter = null, - ICollection? filters = null, - double minRelevance = 0, - IContext? context = null, - [EnumeratorCancellation] CancellationToken cancellationToken = default) - { - if (filter != null) + if (useStreaming) { - if (filters == null) { filters = new List(); } - - filters.Add(filter); - } - - MemoryQuery request = new() - { - Index = index, - Question = question, - Filters = (filters is { Count: > 0 }) ? filters.ToList() : new(), - MinRelevance = minRelevance, - ContextArguments = (context?.Arguments ?? new Dictionary()).ToDictionary(), - }; - using StringContent content = new(JsonSerializer.Serialize(request), Encoding.UTF8, "application/json"); - - var url = Constants.HttpAskChunkEndpoint.CleanUrlPath(); - HttpResponseMessage response = await this._client.PostAsync(url, content, cancellationToken).ConfigureAwait(false); - response.EnsureSuccessStatusCode(); - var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); - var eosToken = context.GetCustomEosTokenOrDefault("#DONE#"); - using (var reader = new StreamReader(stream, Encoding.UTF8)) - { - string? line; - while ((line = await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false)) != null) + Stream stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + IAsyncEnumerable answers = SSE.ParseStreamAsync(stream, cancellationToken); + await foreach (MemoryAnswer answer in answers.ConfigureAwait(false)) { - if (line.StartsWith("data:", StringComparison.Ordinal)) - { - var jsonData = line.Substring(6); - var chunk = JsonSerializer.Deserialize(jsonData); - if (chunk != null) - { - yield return chunk; - if (chunk.Result == eosToken) - { - break; - } - } - } + yield return answer; } } + else + { + var json = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false); + yield return JsonSerializer.Deserialize(json, s_caseInsensitiveJsonOptions) ?? new MemoryAnswer(); + } } #region private diff --git a/examples/001-dotnet-WebClient/Program.cs b/examples/001-dotnet-WebClient/Program.cs index 0eec905c4..f1e6d4db4 100644 --- a/examples/001-dotnet-WebClient/Program.cs +++ b/examples/001-dotnet-WebClient/Program.cs @@ -15,7 +15,7 @@ * without extracting memories. */ internal static class Program { - private static MemoryWebClient? s_memory; + private static MemoryWebClient s_memory = null!; private static readonly List s_toDelete = []; // Change this to True and configure Azure Document Intelligence to test OCR and support for images @@ -55,8 +55,8 @@ public static async Task Main() // === RETRIEVAL ========= // ======================= - await AskSimpleQuestion(); - await AskSimpleQuestionAndShowSources(); + await AskSimpleQuestionStreamingTheAnswer(); + await AskSimpleQuestionStreamingAndShowSources(); await AskQuestionAboutImageContent(); await AskQuestionUsingFilter(); await AskQuestionsFilteringByUser(); @@ -249,16 +249,25 @@ private static async Task StoreJson() // ======================= // Question without filters - private static async Task AskSimpleQuestion() + private static async Task AskSimpleQuestionStreamingTheAnswer() { var question = "What's E = m*c^2?"; Console.WriteLine($"Question: {question}"); Console.WriteLine($"Expected result: formula explanation using the information loaded"); - var answer = await s_memory.AskAsync(question, minRelevance: 0.6); - Console.WriteLine($"\nAnswer: {answer.Result}"); + Console.Write("\nAnswer: "); + var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.6, + options: new SearchOptions { Stream = true }); - Console.WriteLine("\n====================================\n"); + await foreach (var answer in answerStream) + { + // Print token received by LLM + Console.Write(answer.Result); + // Slow down the stream for demo purpose + await Task.Delay(25); + } + + Console.WriteLine("\n\n====================================\n"); /* OUTPUT @@ -275,17 +284,32 @@ due to the speed of light being a very large number when squared. This concept i } // Another question without filters and show sources - private static async Task AskSimpleQuestionAndShowSources() + private static async Task AskSimpleQuestionStreamingAndShowSources() { var question = "What's Kernel Memory?"; Console.WriteLine($"Question: {question}"); Console.WriteLine($"Expected result: it should explain what KM project is (not generic kernel memory)"); - var answer = await s_memory.AskAsync(question, minRelevance: 0.5); - Console.WriteLine($"\nAnswer: {answer.Result}\n\n Sources:\n"); + Console.Write("\nAnswer: "); + var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.5, + options: new SearchOptions { Stream = true }); + + List sources = []; + await foreach (var answer in answerStream) + { + // Print token received by LLM + Console.Write(answer.Result); + + // Collect sources + sources.AddRange(answer.RelevantSources); + + // Slow down the stream for demo purpose + await Task.Delay(5); + } // Show sources / citations - foreach (var x in answer.RelevantSources) + Console.WriteLine("\n\nSources:\n"); + foreach (var x in sources) { Console.WriteLine(x.SourceUrl != null ? $" - {x.SourceUrl} [{x.Partitions.First().LastUpdate:D}]" diff --git a/examples/002-dotnet-Serverless/Program.cs b/examples/002-dotnet-Serverless/Program.cs index ee5f680d6..26a9fb5c8 100644 --- a/examples/002-dotnet-Serverless/Program.cs +++ b/examples/002-dotnet-Serverless/Program.cs @@ -13,7 +13,7 @@ #pragma warning disable CS8602 // by design public static class Program { - private static MemoryServerless? s_memory; + private static MemoryServerless s_memory = null!; private static readonly List s_toDelete = []; // Remember to configure Azure Document Intelligence to test OCR and support for images @@ -107,8 +107,8 @@ public static async Task Main() // === RETRIEVAL ========= // ======================= - await AskSimpleQuestion(); - await AskSimpleQuestionAndShowSources(); + await AskSimpleQuestionStreamingTheAnswer(); + await AskSimpleQuestionStreamingAndShowSources(); await AskQuestionAboutImageContent(); await AskQuestionUsingFilter(); await AskQuestionsFilteringByUser(); @@ -303,16 +303,25 @@ private static async Task StoreJson() // ======================= // Question without filters - private static async Task AskSimpleQuestion() + private static async Task AskSimpleQuestionStreamingTheAnswer() { var question = "What's E = m*c^2?"; Console.WriteLine($"Question: {question}"); Console.WriteLine($"Expected result: formula explanation using the information loaded"); - var answer = await s_memory.AskAsync(question, minRelevance: 0.6); - Console.WriteLine($"\nAnswer: {answer.Result}"); + Console.Write("\nAnswer: "); + var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.6, + options: new SearchOptions { Stream = true }); - Console.WriteLine("\n====================================\n"); + await foreach (var answer in answerStream) + { + // Print token received by LLM + Console.Write(answer.Result); + // Slow down the stream for demo purpose + await Task.Delay(25); + } + + Console.WriteLine("\n\n====================================\n"); /* OUTPUT @@ -329,17 +338,32 @@ due to the speed of light being a very large number when squared. This concept i } // Another question without filters and show sources - private static async Task AskSimpleQuestionAndShowSources() + private static async Task AskSimpleQuestionStreamingAndShowSources() { var question = "What's Kernel Memory?"; Console.WriteLine($"Question: {question}"); Console.WriteLine($"Expected result: it should explain what KM project is (not generic kernel memory)"); - var answer = await s_memory.AskAsync(question, minRelevance: 0.5); - Console.WriteLine($"\nAnswer: {answer.Result}\n\n Sources:\n"); + Console.Write("\nAnswer: "); + var answerStream = s_memory.AskStreamingAsync(question, minRelevance: 0.5, + options: new SearchOptions { Stream = true }); + + List sources = []; + await foreach (var answer in answerStream) + { + // Print token received by LLM + Console.Write(answer.Result); + + // Collect sources + sources.AddRange(answer.RelevantSources); + + // Slow down the stream for demo purpose + await Task.Delay(5); + } // Show sources / citations - foreach (var x in answer.RelevantSources) + Console.WriteLine("\n\nSources:\n"); + foreach (var x in sources) { Console.WriteLine(x.SourceUrl != null ? $" - {x.SourceUrl} [{x.Partitions.First().LastUpdate:D}]" diff --git a/extensions/AzureAISearch/AzureAISearch.TestApplication/Program.cs b/extensions/AzureAISearch/AzureAISearch.TestApplication/Program.cs index 93618ad48..3aee0565b 100644 --- a/extensions/AzureAISearch/AzureAISearch.TestApplication/Program.cs +++ b/extensions/AzureAISearch/AzureAISearch.TestApplication/Program.cs @@ -8,6 +8,7 @@ using Microsoft.KernelMemory; using Microsoft.KernelMemory.MemoryDb.AzureAISearch; using Microsoft.KernelMemory.MemoryStorage; +using AISearchOptions = Azure.Search.Documents.SearchOptions; namespace Microsoft.AzureAISearch.TestApplication; @@ -246,7 +247,7 @@ private static async Task> SearchByFieldValueAsync( fieldValue1 = fieldValue1.Replace("'", "''", StringComparison.Ordinal); fieldValue2 = fieldValue2.Replace("'", "''", StringComparison.Ordinal); - SearchOptions options = new() + AISearchOptions options = new() { Filter = fieldIsCollection ? $"{fieldName}/any(s: s eq '{fieldValue1}') and {fieldName}/any(s: s eq '{fieldValue2}')" diff --git a/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs b/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs index fb838e862..b8ad546c6 100644 --- a/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs +++ b/extensions/AzureAISearch/AzureAISearch/AzureAISearchMemory.cs @@ -19,6 +19,7 @@ using Microsoft.KernelMemory.Diagnostics; using Microsoft.KernelMemory.DocumentStorage; using Microsoft.KernelMemory.MemoryStorage; +using AISearchOptions = Azure.Search.Documents.SearchOptions; namespace Microsoft.KernelMemory.MemoryDb.AzureAISearch; @@ -184,7 +185,7 @@ await client.IndexDocumentsAsync( Exhaustive = false }; - SearchOptions options = new() + AISearchOptions options = new() { VectorSearch = new() { @@ -246,7 +247,7 @@ public async IAsyncEnumerable GetListAsync( { var client = this.GetSearchClient(index); - SearchOptions options = this.PrepareSearchOptions(null, withEmbeddings, filters, limit); + AISearchOptions options = this.PrepareSearchOptions(null, withEmbeddings, filters, limit); Response>? searchResult = null; try @@ -596,13 +597,13 @@ at Azure.Search.Documents.SearchClient.SearchInternal[T](SearchOptions options, return indexSchema; } - private SearchOptions PrepareSearchOptions( - SearchOptions? options, + private AISearchOptions PrepareSearchOptions( + AISearchOptions? options, bool withEmbeddings, ICollection? filters = null, int limit = 1) { - options ??= new SearchOptions(); + options ??= new AISearchOptions(); // Define which fields to fetch options.Select.Add(AzureAISearchMemoryRecord.IdField); diff --git a/service/Abstractions/Abstractions.csproj b/service/Abstractions/Abstractions.csproj index 6fab4f39c..caac4445d 100644 --- a/service/Abstractions/Abstractions.csproj +++ b/service/Abstractions/Abstractions.csproj @@ -4,7 +4,7 @@ net8.0 Microsoft.KernelMemory.Abstractions Microsoft.KernelMemory - $(NoWarn);KMEXP00;CA1711;CA1724;CS1574;SKEXP0001; + $(NoWarn);KMEXP00;SKEXP0001;CA1711;CA1724;CS1574;CA1812; @@ -13,6 +13,7 @@ + diff --git a/service/Abstractions/Constants.cs b/service/Abstractions/Constants.cs index 8bcc676ef..408c1b8dc 100644 --- a/service/Abstractions/Constants.cs +++ b/service/Abstractions/Constants.cs @@ -61,9 +61,6 @@ public static class TextGeneration public static class Rag { - // Used to override EosToken config for streaming - public const string EosToken = "custom_rag_eos_token_str"; - // Used to override No Answer config public const string EmptyAnswer = "custom_rag_empty_answer_str"; @@ -131,8 +128,6 @@ public static class Summary public const string ReservedPayloadVectorGeneratorField = "vector_generator"; // Endpoints - - public const string HttpAskChunkEndpoint = "/ask_stream"; public const string HttpAskEndpoint = "/ask"; public const string HttpSearchEndpoint = "/search"; public const string HttpDownloadEndpoint = "/download"; diff --git a/service/Abstractions/Context/IContext.cs b/service/Abstractions/Context/IContext.cs index d3ce73356..45ebc3ddb 100644 --- a/service/Abstractions/Context/IContext.cs +++ b/service/Abstractions/Context/IContext.cs @@ -100,16 +100,6 @@ public static string GetCustomEmptyAnswerTextOrDefault(this IContext? context, s return defaultValue; } - public static string GetCustomEosTokenOrDefault(this IContext? context, string defaultValue) - { - if (context.TryGetArg(Constants.CustomContext.Rag.EosToken, out var customValue)) - { - return customValue; - } - - return defaultValue; - } - public static string GetCustomRagFactTemplateOrDefault(this IContext? context, string defaultValue) { if (context.TryGetArg(Constants.CustomContext.Rag.FactTemplate, out var customValue)) diff --git a/service/Abstractions/HTTP/SSE.cs b/service/Abstractions/HTTP/SSE.cs new file mode 100644 index 000000000..e6d9254fd --- /dev/null +++ b/service/Abstractions/HTTP/SSE.cs @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Text; +using System.Text.Json; +using System.Threading; + +namespace Microsoft.KernelMemory.HTTP; + +// See https://developer.mozilla.org/docs/Web/API/Server-sent_events/Using_server-sent_events +public static class SSE +{ + public const string DataPrefix = "data: "; + public const string LastToken = "[DONE]"; + public const string DoneMessage = $"{DataPrefix}{LastToken}"; + + public async static IAsyncEnumerable ParseStreamAsync( + Stream stream, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + using var reader = new StreamReader(stream, Encoding.UTF8); + StringBuilder buffer = new(); + + while (await reader.ReadLineAsync(cancellationToken).ConfigureAwait(false) is { } line) + { + if (string.IsNullOrWhiteSpace(line)) // \n\n detected => Message delimiter + { + if (buffer.Length == 0) { continue; } + + string message = buffer.ToString(); + buffer.Clear(); + if (message.Trim() == DoneMessage) { yield break; } + + var memoryAnswer = ParseMessage(message); + if (memoryAnswer != null) { yield return memoryAnswer; } + } + else + { + buffer.AppendLine(line); + } + } + + // Process any remaining text as the last message + if (buffer.Length > 0) + { + string message = buffer.ToString(); + if (message.Trim() == DoneMessage) { yield break; } + + var memoryAnswer = ParseMessage(message); + if (memoryAnswer != null) { yield return memoryAnswer; } + } + } + + public static T? ParseMessage(string? message) + { + if (string.IsNullOrWhiteSpace(message)) { return default; } + + string json = string.Join("", + message.Split('\n', StringSplitOptions.RemoveEmptyEntries) + .Where(line => line.StartsWith(DataPrefix, StringComparison.OrdinalIgnoreCase)) + .Select(line => line[DataPrefix.Length..])); + + return JsonSerializer.Deserialize(json); + } +} diff --git a/service/Abstractions/IKernelMemory.cs b/service/Abstractions/IKernelMemory.cs index 11e183c84..f1152ff6a 100644 --- a/service/Abstractions/IKernelMemory.cs +++ b/service/Abstractions/IKernelMemory.cs @@ -211,30 +211,28 @@ public Task SearchAsync( /// /// Search the given index for an answer to the given query. + /// + /// Use this method to work with IAsyncEnumerable and optionally stream the output. + /// - Note: you must set options.Stream = true to enable token streaming. + /// + /// Use the .AskAsync() extension method to receive the complete answer without streaming. /// /// Question to answer /// Optional index name /// Filter to match /// Filters to match (using inclusive OR logic). If 'filter' is provided too, the value is merged into this list. /// Minimum Cosine Similarity required + /// Options for the request, such as whether to stream results /// Unstructured data supporting custom business logic in the current request. /// Async task cancellation token /// Answer to the query, if possible - public Task AskAsync( - string question, - string? index = null, - MemoryFilter? filter = null, - ICollection? filters = null, - double minRelevance = 0, - IContext? context = null, - CancellationToken cancellationToken = default); - - public IAsyncEnumerable AskAsyncChunk( + public IAsyncEnumerable AskStreamingAsync( string question, string? index = null, MemoryFilter? filter = null, ICollection? filters = null, double minRelevance = 0, + SearchOptions? options = null, IContext? context = null, CancellationToken cancellationToken = default); } diff --git a/service/Abstractions/KernelMemoryExtensions.cs b/service/Abstractions/KernelMemoryExtensions.cs index 62443a890..0e5dfb53a 100644 --- a/service/Abstractions/KernelMemoryExtensions.cs +++ b/service/Abstractions/KernelMemoryExtensions.cs @@ -1,8 +1,10 @@ // Copyright (c) Microsoft. All rights reserved. using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; +using Microsoft.KernelMemory.Context; namespace Microsoft.KernelMemory; @@ -11,6 +13,47 @@ namespace Microsoft.KernelMemory; /// public static class KernelMemoryExtensions { + /// + /// Search the given index for an answer to the given query + /// and return it without streaming the content. + /// + /// Memory instance + /// Question to answer + /// Optional index name + /// Filter to match + /// Filters to match (using inclusive OR logic). If 'filter' is provided too, the value is merged into this list. + /// Minimum Cosine Similarity required + /// Options for the request, such as whether to stream results + /// Unstructured data supporting custom business logic in the current request. + /// Async task cancellation token + /// Answer to the query, if possible + public static async Task AskAsync( + this IKernelMemory memory, + string question, + string? index = null, + MemoryFilter? filter = null, + ICollection? filters = null, + double minRelevance = 0, + SearchOptions? options = null, + IContext? context = null, + CancellationToken cancellationToken = default) + { + var optionsOverride = options.Clone() ?? new SearchOptions(); + optionsOverride.Stream = false; + + return await memory.AskStreamingAsync( + question: question, + index: index, + filter: filter, + filters: filters, + minRelevance: minRelevance, + options: optionsOverride, + context: context, + cancellationToken) + .FirstAsync(cancellationToken: cancellationToken) + .ConfigureAwait(false); + } + /// /// Return a list of synthetic memories of the specified type /// diff --git a/service/Abstractions/Models/MemoryAnswer.cs b/service/Abstractions/Models/MemoryAnswer.cs index eeb9d4d03..c78e695c6 100644 --- a/service/Abstractions/Models/MemoryAnswer.cs +++ b/service/Abstractions/Models/MemoryAnswer.cs @@ -11,15 +11,20 @@ namespace Microsoft.KernelMemory; public class MemoryAnswer { - private static readonly JsonSerializerOptions s_indentedJsonOptions = new() { WriteIndented = true }; - private static readonly JsonSerializerOptions s_notIndentedJsonOptions = new() { WriteIndented = false }; - private static readonly JsonSerializerOptions s_caseInsensitiveJsonOptions = new() { PropertyNameCaseInsensitive = true }; + /// + /// Used only when streaming. How to handle the current record. + /// + [JsonPropertyName("streamState")] + [JsonPropertyOrder(0)] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + public StreamStates? StreamState { get; set; } = null; /// /// Client question. /// [JsonPropertyName("question")] [JsonPropertyOrder(1)] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public string Question { get; set; } = string.Empty; [JsonPropertyName("noResult")] @@ -48,23 +53,31 @@ public class MemoryAnswer /// [JsonPropertyName("relevantSources")] [JsonPropertyOrder(20)] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public List RelevantSources { get; set; } = []; /// /// Serialize using .NET JSON serializer, e.g. to avoid ambiguity /// with other serializers and other options /// - /// Whether to keep the JSON readable, e.g. for debugging and views + /// Whether to reduce the payload size for SSE /// JSON serialization - public string ToJson(bool indented = false) + public string ToJson(bool optimizeForStream) { - return JsonSerializer.Serialize(this, indented ? s_indentedJsonOptions : s_notIndentedJsonOptions); - } + if (!optimizeForStream || this.StreamState != StreamStates.Append) + { + return JsonSerializer.Serialize(this); + } - public MemoryAnswer FromJson(string json) - { - return JsonSerializer.Deserialize(json, s_caseInsensitiveJsonOptions) - ?? new MemoryAnswer(); + MemoryAnswer clone = JsonSerializer.Deserialize(JsonSerializer.Serialize(this))!; + +#pragma warning disable CA1820 + if (clone.Question == string.Empty) { clone.Question = null!; } +#pragma warning restore CA1820 + + if (clone.RelevantSources.Count == 0) { clone.RelevantSources = null!; } + + return JsonSerializer.Serialize(clone); } public override string ToString() @@ -72,7 +85,7 @@ public override string ToString() var result = new StringBuilder(); result.AppendLine(this.Result); - if (!this.NoResult) + if (!this.NoResult && this.RelevantSources is { Count: > 0 }) { var sources = new Dictionary(); foreach (var x in this.RelevantSources) diff --git a/service/Abstractions/Models/MemoryQuery.cs b/service/Abstractions/Models/MemoryQuery.cs index f6e8dfae4..3676b945a 100644 --- a/service/Abstractions/Models/MemoryQuery.cs +++ b/service/Abstractions/Models/MemoryQuery.cs @@ -23,6 +23,10 @@ public class MemoryQuery [JsonPropertyOrder(2)] public double MinRelevance { get; set; } = 0; + [JsonPropertyName("stream")] + [JsonPropertyOrder(3)] + public bool Stream { get; set; } = false; + [JsonPropertyName("args")] [JsonPropertyOrder(100)] public Dictionary ContextArguments { get; set; } = []; diff --git a/service/Abstractions/Models/SearchResult.cs b/service/Abstractions/Models/SearchResult.cs index 019720c22..4e2e35128 100644 --- a/service/Abstractions/Models/SearchResult.cs +++ b/service/Abstractions/Models/SearchResult.cs @@ -26,10 +26,7 @@ public class SearchResult [JsonPropertyOrder(2)] public bool NoResult { - get - { - return this.Results.Count == 0; - } + get => this.Results == null || this.Results.Count == 0; private set { } } @@ -40,6 +37,7 @@ private set { } /// [JsonPropertyName("results")] [JsonPropertyOrder(3)] + [JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] public List Results { get; set; } = []; /// diff --git a/service/Abstractions/Models/StreamStates.cs b/service/Abstractions/Models/StreamStates.cs new file mode 100644 index 000000000..041187dd7 --- /dev/null +++ b/service/Abstractions/Models/StreamStates.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Text.Json; +using System.Text.Json.Serialization; + +namespace Microsoft.KernelMemory; + +[JsonConverter(typeof(StreamStatesConverter))] +public enum StreamStates +{ + // Inform the client the stream ended to an error. + Error = 0, + + // When streaming, inform the client to discard any previous data + // and start collecting again using this record as the first one. + Reset = 1, + + // When streaming, append the current result to the data + // already received so far. + Append = 2, + + // Inform the client the end of the stream has been reached + // and that this is the last record to append. + Last = 3, +} + +#pragma warning disable CA1308 +internal sealed class StreamStatesConverter : JsonConverter +{ + public override StreamStates Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + { + string value = reader.GetString()!; + return value.ToLowerInvariant() switch + { + "error" => StreamStates.Error, + "reset" => StreamStates.Reset, + "append" => StreamStates.Append, + "last" => StreamStates.Last, + _ => throw new JsonException($"Unknown {nameof(StreamStates)} value: {value}") + }; + } + + public override void Write(Utf8JsonWriter writer, StreamStates value, JsonSerializerOptions options) + { + string serializedValue = value switch + { + StreamStates.Error => "error", + StreamStates.Reset => "reset", + StreamStates.Append => "append", + StreamStates.Last => "last", + _ => throw new JsonException($"Cannot serialize {nameof(StreamStates)} value: {value}") + }; + + writer.WriteStringValue(serializedValue); + } +} +#pragma warning restore CA1308 diff --git a/service/Abstractions/Search/ISearchClient.cs b/service/Abstractions/Search/ISearchClient.cs index b52f9134e..4329538e3 100644 --- a/service/Abstractions/Search/ISearchClient.cs +++ b/service/Abstractions/Search/ISearchClient.cs @@ -50,7 +50,17 @@ Task AskAsync( IContext? context = null, CancellationToken cancellationToken = default); - IAsyncEnumerable AskAsyncChunk( + /// + /// Answer the given question, if possible, grounding the response with relevant memories matching the given criteria. + /// + /// Index (aka collection) to search for grounding information + /// Question to answer + /// Filtering criteria to select memories to consider + /// Minimum relevance of the memories considered + /// Optional context carrying optional information used by internal logic + /// Async task cancellation token + /// Answer to the given question + IAsyncEnumerable AskStreamingAsync( string index, string question, ICollection? filters = null, diff --git a/service/Abstractions/Search/SearchOptions.cs b/service/Abstractions/Search/SearchOptions.cs new file mode 100644 index 000000000..4f88b1f0c --- /dev/null +++ b/service/Abstractions/Search/SearchOptions.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft. All rights reserved. + +#pragma warning disable IDE0130 // reduce number of "using" statements +// ReSharper disable once CheckNamespace - reduce number of "using" statements +namespace Microsoft.KernelMemory; + +// TODO: move minRelevance to this class +// TODO: move filter to this class +// TODO: move filters to this class +public sealed class SearchOptions +{ + /// + /// Whether to stream results back to the client + /// + public bool Stream { get; set; } = false; +} + +public static class SearchOptionsExtensions +{ + public static SearchOptions? Clone(this SearchOptions? options) + { + if (options == null) { return null; } + + return new SearchOptions + { + Stream = options.Stream + }; + } +} diff --git a/service/Core/Configuration/ServiceConfig.cs b/service/Core/Configuration/ServiceConfig.cs index 4b89ff129..2a9ab78ec 100644 --- a/service/Core/Configuration/ServiceConfig.cs +++ b/service/Core/Configuration/ServiceConfig.cs @@ -24,6 +24,11 @@ public class ServiceConfig /// public bool OpenApiEnabled { get; set; } = false; + /// + /// Whether to send a [DONE] message at the end of SSE streams. + /// + public bool SendSSEDoneMessage { get; set; } = true; + /// /// List of handlers to enable /// diff --git a/service/Core/MemoryServerless.cs b/service/Core/MemoryServerless.cs index 92874211b..240ddd52a 100644 --- a/service/Core/MemoryServerless.cs +++ b/service/Core/MemoryServerless.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -262,56 +263,47 @@ public Task SearchAsync( } /// - public Task AskAsync( + public async IAsyncEnumerable AskStreamingAsync( string question, string? index = null, MemoryFilter? filter = null, ICollection? filters = null, double minRelevance = 0, + SearchOptions? options = null, IContext? context = null, - CancellationToken cancellationToken = default) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { this._contextProvider.InitContext(context); if (filter != null) { - if (filters == null) { filters = []; } - + filters ??= []; filters.Add(filter); } index = IndexName.CleanName(index, this._defaultIndexName); - return this._searchClient.AskAsync( - index: index, - question: question, - filters: filters, - minRelevance: minRelevance, - context: context, - cancellationToken: cancellationToken); - } - public IAsyncEnumerable AskAsyncChunk( - string question, - string? index = null, - MemoryFilter? filter = null, - ICollection? filters = null, - double minRelevance = 0, - IContext? context = null, - CancellationToken cancellationToken = default) - { - if (filter != null) + if (options is { Stream: true }) { - if (filters == null) { filters = new List(); } + await foreach (var answer in this._searchClient.AskStreamingAsync( + index: index, + question: question, + filters: filters, + minRelevance: minRelevance, + context: context, + cancellationToken).ConfigureAwait(false)) + { + yield return answer; + } - filters.Add(filter); + yield break; } - index = IndexName.CleanName(index, this._defaultIndexName); - return this._searchClient.AskAsyncChunk( + yield return await this._searchClient.AskAsync( index: index, question: question, filters: filters, minRelevance: minRelevance, context: context, - cancellationToken: cancellationToken); + cancellationToken).ConfigureAwait(false); } } diff --git a/service/Core/MemoryService.cs b/service/Core/MemoryService.cs index 2ddf6b433..5a88813da 100644 --- a/service/Core/MemoryService.cs +++ b/service/Core/MemoryService.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -228,58 +229,46 @@ public Task SearchAsync( } /// - public Task AskAsync( + public async IAsyncEnumerable AskStreamingAsync( string question, string? index = null, MemoryFilter? filter = null, ICollection? filters = null, double minRelevance = 0, + SearchOptions? options = null, IContext? context = null, - CancellationToken cancellationToken = default) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { if (filter != null) { - if (filters == null) { filters = []; } - + filters ??= []; filters.Add(filter); } index = IndexName.CleanName(index, this._defaultIndexName); - return this._searchClient.AskAsync( - index: index, - question: question, - filters: filters, - minRelevance: minRelevance, - context: context, - cancellationToken: cancellationToken); - } - public IAsyncEnumerable AskAsyncChunk( - string question, - string? index = null, - MemoryFilter? filter = null, - ICollection? filters = null, - double minRelevance = 0, - IContext? context = null, - CancellationToken cancellationToken = default) - { - if (filter != null) + if (options is { Stream: true }) { - if (filters == null) + await foreach (var answer in this._searchClient.AskStreamingAsync( + index: index, + question: question, + filters: filters, + minRelevance: minRelevance, + context: context, + cancellationToken).ConfigureAwait(false)) { - filters = new List(); + yield return answer; } - filters.Add(filter); + yield break; } - index = IndexName.CleanName(index, this._defaultIndexName); - return this._searchClient.AskAsyncChunk( + yield return await this._searchClient.AskAsync( index: index, question: question, filters: filters, minRelevance: minRelevance, context: context, - cancellationToken: cancellationToken); + cancellationToken).ConfigureAwait(false); } } diff --git a/service/Core/Search/AnswerGenerator.cs b/service/Core/Search/AnswerGenerator.cs index e3b13e7e3..352160851 100644 --- a/service/Core/Search/AnswerGenerator.cs +++ b/service/Core/Search/AnswerGenerator.cs @@ -2,8 +2,8 @@ using System; using System.Collections.Generic; -using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -44,89 +44,72 @@ public AnswerGenerator( { throw new KernelMemoryException("Text generator not configured"); } + + if (this._contentModeration == null || !this._config.UseContentModeration) + { + this._log.LogInformation("Content moderation is not enabled."); + } } - internal async Task GenerateAnswerAsync( - string question, SearchClientResult result, IContext? context, CancellationToken cancellationToken) + internal async IAsyncEnumerable GenerateAnswerAsync( + string question, SearchClientResult result, + IContext? context, [EnumeratorCancellation] CancellationToken cancellationToken) { if (result.FactsAvailableCount > 0 && result.FactsUsedCount == 0) { this._log.LogError("Unable to inject memories in the prompt, not enough tokens available"); - result.AskResult.NoResultReason = "Unable to use memories"; - return result.AskResult; + yield return result.InsufficientTokensResult; + yield break; } if (result.FactsUsedCount == 0) { this._log.LogWarning("No memories available"); - result.AskResult.NoResultReason = "No memories available"; - return result.AskResult; + yield return result.NoFactsResult; + yield break; } - // Collect the LLM output - var text = new StringBuilder(); - var charsGenerated = 0; - var watch = new Stopwatch(); - watch.Restart(); - await foreach (var x in this.GenerateAnswerTokensAsync(question, result.Facts.ToString(), context, cancellationToken).ConfigureAwait(false)) + var completeAnswer = new StringBuilder(); + await foreach (var answerToken in this.GenerateAnswerTokensAsync(question, result.Facts.ToString(), context, cancellationToken).ConfigureAwait(false)) { - text.Append(x); - - if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30) - { - charsGenerated = text.Length; - this._log.LogTrace("{0} chars generated", charsGenerated); - } + completeAnswer.Append(answerToken); + result.AskResult.Result = answerToken; + yield return result.AskResult; } - watch.Stop(); - // Finalize the answer, checking if it's empty - result.AskResult.Result = text.ToString(); - this._log.LogSensitive("Answer: {0}", result.AskResult.Result); - result.AskResult.NoResult = ValueIsEquivalentTo(result.AskResult.Result, this._config.EmptyAnswer); - if (result.AskResult.NoResult) - { - result.AskResult.NoResultReason = "No relevant memories found"; - this._log.LogTrace("Answer generated in {0} msecs. No relevant memories found", watch.ElapsedMilliseconds); - } - else + result.AskResult.Result = completeAnswer.ToString(); + if (string.IsNullOrWhiteSpace(result.AskResult.Result) + || ValueIsEquivalentTo(result.AskResult.Result, this._config.EmptyAnswer)) { - this._log.LogTrace("Answer generated in {0} msecs", watch.ElapsedMilliseconds); + this._log.LogInformation("No relevant memories found, returning empty answer."); + yield return result.NoFactsResult; + yield break; } - return await this.ModeratedAnswerAsync(result.AskResult, cancellationToken).ConfigureAwait(false); - } + this._log.LogSensitive("Answer: {0}", result.AskResult.Result); - internal async Task ModeratedAnswerAsync(MemoryAnswer currentResult, CancellationToken cancellationToken) - { - if (this._contentModeration != null && this._config.UseContentModeration) + if (this._config.UseContentModeration + && this._contentModeration != null + && !await this._contentModeration.IsSafeAsync(result.AskResult.Result, cancellationToken).ConfigureAwait(false)) { - var isSafe = await this._contentModeration.IsSafeAsync(currentResult.Result, cancellationToken).ConfigureAwait(false); - if (!isSafe) - { - this._log.LogWarning("Unsafe answer detected. Returning error message instead."); - this._log.LogSensitive("Unsafe answer: {0}", currentResult.Result); - currentResult.NoResultReason = "Content moderation failure"; - currentResult.Result = this._config.ModeratedAnswer; - currentResult.NoResult = true; - } + this._log.LogWarning("Unsafe answer detected. Returning error message instead."); + yield return result.UnsafeAnswerResult; } - - return currentResult; } - internal IAsyncEnumerable GenerateAnswerTokensAsync(string question, string facts, IContext? context, CancellationToken cancellationToken) + private IAsyncEnumerable GenerateAnswerTokensAsync(string question, string facts, IContext? context, CancellationToken cancellationToken) { string prompt = context.GetCustomRagPromptOrDefault(this._answerPrompt); + string emptyAnswer = context.GetCustomEmptyAnswerTextOrDefault(this._config.EmptyAnswer); int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens); double temperature = context.GetCustomRagTemperatureOrDefault(this._config.Temperature); double nucleusSampling = context.GetCustomRagNucleusSamplingOrDefault(this._config.TopP); - string emptyAnswer = context.GetCustomEmptyAnswerTextOrDefault(this._config.EmptyAnswer); - prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase); question = question.Trim(); question = question.EndsWith('?') ? question : $"{question}?"; + + prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase); prompt = prompt.Replace("{{$input}}", question, StringComparison.OrdinalIgnoreCase); prompt = prompt.Replace("{{$notFound}}", emptyAnswer, StringComparison.OrdinalIgnoreCase); this._log.LogInformation("New prompt: {0}", prompt); diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs index 215dab47e..f6b3ddae7 100644 --- a/service/Core/Search/SearchClient.cs +++ b/service/Core/Search/SearchClient.cs @@ -123,60 +123,74 @@ public async Task AskAsync( IContext? context = null, CancellationToken cancellationToken = default) { - string emptyAnswer = context.GetCustomEmptyAnswerTextOrDefault(this._config.EmptyAnswer); - string answerPrompt = context.GetCustomRagPromptOrDefault(this._answerPrompt); - int limit = context.GetCustomRagMaxMatchesCountOrDefault(this._config.MaxMatchesCount); + var result = new MemoryAnswer(); - var maxTokens = this._config.MaxAskPromptSize > 0 - ? this._config.MaxAskPromptSize - : this._textGenerator.MaxTokenTotal; + var stream = this.AskStreamingAsync( + index: index, question: question, filters, minRelevance, context, cancellationToken) + .ConfigureAwait(false); - SearchClientResult result = SearchClientResult.AskResultInstance( - question: question, - emptyAnswer: emptyAnswer, - maxGroundingFacts: limit, - tokensAvailable: maxTokens - - this._textGenerator.CountTokens(answerPrompt) - - this._textGenerator.CountTokens(question) - - this._config.AnswerTokens - ); - - if (string.IsNullOrEmpty(question)) + var done = false; + StringBuilder text = new(result.Result); + await foreach (var part in stream.ConfigureAwait(false)) { - this._log.LogWarning("No question provided"); - return result.AskResult; - } + if (done) { break; } - this._log.LogTrace("Fetching relevant memories"); - IAsyncEnumerable<(MemoryRecord, double)> matches = this._memoryDb.GetSimilarListAsync( - index: index, - text: question, - filters: filters, - minRelevance: minRelevance, - limit: limit, - withEmbeddings: false, - cancellationToken: cancellationToken); + switch (part.StreamState) + { + case StreamStates.Error: + text.Clear(); + result = part; - string factTemplate = context.GetCustomRagFactTemplateOrDefault(this._config.FactTemplate); - if (!factTemplate.EndsWith('\n')) { factTemplate += "\n"; } + done = true; + break; - // Memories are sorted by relevance, starting from the most relevant - await foreach ((MemoryRecord memoryRecord, double recordRelevance) in matches.ConfigureAwait(false)) - { - result.State = SearchState.Continue; - result = this.ProcessMemoryRecord(result, index, memoryRecord, recordRelevance, factTemplate); + case StreamStates.Reset: + text.Clear(); + text.Append(part.Result); + result = part; + break; - if (result.State == SearchState.SkipRecord) { continue; } + case StreamStates.Append: + result.NoResult = part.NoResult; + result.NoResultReason = part.NoResultReason; - if (result.State == SearchState.Stop) { break; } - } + text.Append(part.Result); + result.Result = text.ToString(); - this._log.LogTrace("{Count} records processed", result.RecordCount); + if (result.RelevantSources != null && part.RelevantSources != null) + { + result.RelevantSources = result.RelevantSources.Union(part.RelevantSources).ToList(); + } + + break; - return await this._answerGenerator.GenerateAnswerAsync(question, result, context, cancellationToken).ConfigureAwait(false); + case StreamStates.Last: + result.NoResult = part.NoResult; + result.NoResultReason = part.NoResultReason; + + text.Append(part.Result); + result.Result = text.ToString(); + + if (result.RelevantSources != null && part.RelevantSources != null) + { + result.RelevantSources = result.RelevantSources.Union(part.RelevantSources).ToList(); + } + + done = true; + break; + + default: + throw new ArgumentOutOfRangeException(nameof(part.StreamState)); + } + } + + result.Question = question; + result.StreamState = null; + return result; } - public async IAsyncEnumerable AskAsyncChunk( + /// + public async IAsyncEnumerable AskStreamingAsync( string index, string question, ICollection? filters = null, @@ -184,30 +198,30 @@ public async IAsyncEnumerable AskAsyncChunk( IContext? context = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - this._log.LogInformation("question: '{0}'", question); string emptyAnswer = context.GetCustomEmptyAnswerTextOrDefault(this._config.EmptyAnswer); - string eosToken = context.GetCustomEosTokenOrDefault("#DONE#"); string answerPrompt = context.GetCustomRagPromptOrDefault(this._answerPrompt); int limit = context.GetCustomRagMaxMatchesCountOrDefault(this._config.MaxMatchesCount); + var maxTokens = this._config.MaxAskPromptSize > 0 ? this._config.MaxAskPromptSize : this._textGenerator.MaxTokenTotal; + + // Prepare results (empty, error, etc.) SearchClientResult result = SearchClientResult.AskResultInstance( question: question, emptyAnswer: emptyAnswer, + moderatedAnswer: this._config.ModeratedAnswer, maxGroundingFacts: limit, tokensAvailable: maxTokens - this._textGenerator.CountTokens(answerPrompt) - this._textGenerator.CountTokens(question) - this._config.AnswerTokens ); + if (string.IsNullOrEmpty(question)) { this._log.LogWarning("No question provided"); - yield return result.AskResult; - result.AskResult.Result = eosToken; - this._log.LogInformation("Eos token: '{0}", result.AskResult.Result); - yield return result.AskResult; + yield return result.NoQuestionResult; yield break; } @@ -217,7 +231,7 @@ public async IAsyncEnumerable AskAsyncChunk( text: question, filters: filters, minRelevance: minRelevance, - limit: this._config.MaxMatchesCount, + limit: limit, withEmbeddings: false, cancellationToken: cancellationToken); @@ -235,64 +249,24 @@ public async IAsyncEnumerable AskAsyncChunk( if (result.State == SearchState.Stop) { break; } } - if (result.FactsAvailableCount > 0 && result.FactsUsedCount == 0) - { - this._log.LogError("Unable to inject memories in the prompt, not enough tokens available"); - result.AskResult.NoResultReason = "Unable to use memories"; - yield return result.AskResult; - yield break; - } - - if (result.FactsUsedCount == 0) - { - this._log.LogWarning("No memories available"); - result.AskResult.NoResultReason = "No memories available"; - yield return result.AskResult; - yield break; - } - this._log.LogTrace("{Count} records processed", result.RecordCount); - var charsGenerated = 0; - var wholeText = new StringBuilder(); - await foreach (var token in this._answerGenerator.GenerateAnswerTokensAsync(question, result.Facts.ToString(), context, cancellationToken).ConfigureAwait(true)) + + var first = true; + await foreach (var answer in this._answerGenerator.GenerateAnswerAsync(question, result, context, cancellationToken).ConfigureAwait(false)) { - var text = new StringBuilder(); - text.Append(token); - wholeText.Append(token); - if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30) - { - charsGenerated = text.Length; - this._log.LogTrace("{0} chars generated", charsGenerated); - } + yield return answer; - var newAnswer = new MemoryAnswer + if (first) { - Question = question, - NoResult = false, - Result = text.ToString() - }; - this._log.LogInformation("Chunk: '{0}", newAnswer.Result); - yield return newAnswer; - } + // Remove redundant data, sent only once in the first record, to reduce payload + first = false; - var current = new MemoryAnswer - { - Question = question, - NoResult = false, - Result = wholeText.ToString(), - }; - MemoryAnswer moderatedResult = await this._answerGenerator.ModeratedAnswerAsync(current, cancellationToken).ConfigureAwait(false); - result.AskResult.Result = moderatedResult.Result; - result.AskResult.NoResult = moderatedResult.NoResult; - result.AskResult.NoResultReason = moderatedResult.NoResultReason; - if (result.AskResult.NoResultReason == "Content moderation failure") - { - yield return result.AskResult; + // Note: we keep the sources in the other collections (e.g. AskResult.ErrorResult.RelevantSources), + // so in case of a stream reset the sources are sent again. + result.AskResult.RelevantSources.Clear(); + result.AskResult.Question = null!; + } } - - result.AskResult.Result = eosToken; - this._log.LogInformation("Eos token: '{0}", result.AskResult.Result); - yield return result.AskResult; } /// @@ -369,29 +343,17 @@ private SearchClientResult ProcessMemoryRecord( this._log.LogTrace("Adding content #{0} with relevance {1}", result.FactsUsedCount, recordRelevance); } - Citation? citation; - if (result.Mode == SearchMode.SearchMode) + var citation = result.Mode switch { - citation = result.SearchResult.Results.FirstOrDefault(x => x.Link == linkToFile); - if (citation == null) - { - citation = new Citation(); - result.SearchResult.Results.Add(citation); - } - } - else if (result.Mode == SearchMode.AskMode) - { - // If the file is already in the list of citations, only add the partition - citation = result.AskResult.RelevantSources.FirstOrDefault(x => x.Link == linkToFile); - if (citation == null) - { - citation = new Citation(); - result.AskResult.RelevantSources.Add(citation); - } - } - else + SearchMode.SearchMode => result.SearchResult.Results.FirstOrDefault(x => x.Link == linkToFile), + SearchMode.AskMode => result.AskResult.RelevantSources.FirstOrDefault(x => x.Link == linkToFile), + _ => throw new ArgumentOutOfRangeException(nameof(result.Mode)) + }; + + if (citation == null) { - throw new ArgumentOutOfRangeException(nameof(result.Mode)); + citation = new Citation(); + result.AddSource(citation); } citation.Index = index; diff --git a/service/Core/Search/SearchClientResult.cs b/service/Core/Search/SearchClientResult.cs index c509809a7..66ff58f15 100644 --- a/service/Core/Search/SearchClientResult.cs +++ b/service/Core/Search/SearchClientResult.cs @@ -23,10 +23,16 @@ internal class SearchClientResult public SearchState State { get; set; } public int RecordCount { get; set; } - // Use by in Search and Ask mode - public MemoryAnswer AskResult { get; private init; } = new(); + // Use by Search and Ask mode public int MaxRecordCount { get; private init; } + public MemoryAnswer AskResult { get; private init; } = new(); + public MemoryAnswer NoFactsResult { get; private init; } = new(); + public MemoryAnswer NoQuestionResult { get; private init; } = new(); + public MemoryAnswer UnsafeAnswerResult { get; private init; } = new(); + public MemoryAnswer InsufficientTokensResult { get; private init; } = new(); + public MemoryAnswer ErrorResult { get; private init; } = new(); + // Use by Ask mode public SearchResult SearchResult { get; private init; } = new(); public StringBuilder Facts { get; } = new(); @@ -37,7 +43,9 @@ internal class SearchClientResult /// /// Create new instance in Ask mode /// - public static SearchClientResult AskResultInstance(string question, string emptyAnswer, int maxGroundingFacts, int tokensAvailable) + public static SearchClientResult AskResultInstance( + string question, string emptyAnswer, string moderatedAnswer, + int maxGroundingFacts, int tokensAvailable) { return new SearchClientResult { @@ -46,14 +54,64 @@ public static SearchClientResult AskResultInstance(string question, string empty MaxRecordCount = maxGroundingFacts, AskResult = new MemoryAnswer { + StreamState = StreamStates.Append, + Question = question, + NoResult = false + }, + NoFactsResult = new MemoryAnswer + { + StreamState = StreamStates.Reset, + Question = question, + NoResult = true, + NoResultReason = "No relevant memories available", + Result = emptyAnswer + }, + NoQuestionResult = new MemoryAnswer + { + StreamState = StreamStates.Reset, Question = question, NoResult = true, NoResultReason = "No question provided", - Result = emptyAnswer, + Result = emptyAnswer + }, + InsufficientTokensResult = new MemoryAnswer + { + StreamState = StreamStates.Reset, + Question = question, + NoResult = true, + NoResultReason = "Unable to use memory, max tokens reached", + Result = emptyAnswer + }, + UnsafeAnswerResult = new MemoryAnswer + { + StreamState = StreamStates.Reset, + Question = question, + NoResult = true, + NoResultReason = "Content moderation", + Result = moderatedAnswer + }, + ErrorResult = new MemoryAnswer + { + StreamState = StreamStates.Error, + Question = question, + NoResult = true, + NoResultReason = "An error occurred" } }; } + /// + /// Add source to all the collections + /// + public void AddSource(Citation citation) + { + this.SearchResult.Results?.Add(citation); + this.AskResult.RelevantSources?.Add(citation); + this.InsufficientTokensResult.RelevantSources?.Add(citation); + this.UnsafeAnswerResult.RelevantSources?.Add(citation); + this.ErrorResult.RelevantSources?.Add(citation); + } + /// /// Create new instance in Search mode /// diff --git a/service/Service.AspNetCore/WebAPIEndpoints.cs b/service/Service.AspNetCore/WebAPIEndpoints.cs index a9e607d1b..8e63c1aa3 100644 --- a/service/Service.AspNetCore/WebAPIEndpoints.cs +++ b/service/Service.AspNetCore/WebAPIEndpoints.cs @@ -3,9 +3,10 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; -using System.Text.Json; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; @@ -16,6 +17,7 @@ using Microsoft.KernelMemory.Configuration; using Microsoft.KernelMemory.Context; using Microsoft.KernelMemory.DocumentStorage; +using Microsoft.KernelMemory.HTTP; using Microsoft.KernelMemory.Service.AspNetCore.Models; namespace Microsoft.KernelMemory.Service.AspNetCore; @@ -32,8 +34,7 @@ public static IEndpointRouteBuilder AddKernelMemoryEndpoints( builder.AddGetIndexesEndpoint(apiPrefix).AddFilters(filters); builder.AddDeleteIndexesEndpoint(apiPrefix).AddFilters(filters); builder.AddDeleteDocumentsEndpoint(apiPrefix).AddFilters(filters); - builder.AddAskEndpoint(apiPrefix).AddFilters(filters); - builder.AddAskChunkEndpoint(apiPrefix).AddFilters(filters); + builder.AddAskEndpoint(apiPrefix, kmConfig?.Service.SendSSEDoneMessage ?? true).AddFilters(filters); builder.AddSearchEndpoint(apiPrefix).AddFilters(filters); builder.AddUploadStatusEndpoint(apiPrefix).AddFilters(filters); builder.AddGetDownloadEndpoint(apiPrefix).AddFilters(filters); @@ -213,13 +214,14 @@ await service.DeleteDocumentAsync(documentId: documentId, index: index, cancella } public static RouteHandlerBuilder AddAskEndpoint( - this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter[]? filters = null) + this IEndpointRouteBuilder builder, string apiPrefix = "/", bool sseSendDoneMessage = true, IEndpointFilter[]? filters = null) { RouteGroupBuilder group = builder.MapGroup(apiPrefix); // Ask endpoint var route = group.MapPost(Constants.HttpAskEndpoint, - async Task ( + async Task ( + HttpContext httpContext, MemoryQuery query, IKernelMemory service, ILogger log, @@ -229,73 +231,75 @@ async Task ( // Allow internal classes to access custom arguments via IContextProvider contextProvider.InitContextArgs(query.ContextArguments); - log.LogTrace("New search request, index '{0}', minRelevance {1}", query.Index, query.MinRelevance); - MemoryAnswer answer = await service.AskAsync( - question: query.Question, - index: query.Index, - filters: query.Filters, - minRelevance: query.MinRelevance, - context: contextProvider.GetContext(), - cancellationToken: cancellationToken) - .ConfigureAwait(false); - return Results.Ok(answer); - }) - .Produces(StatusCodes.Status200OK) - .Produces(StatusCodes.Status401Unauthorized) - .Produces(StatusCodes.Status403Forbidden); + log.LogTrace("New ask request, index '{0}', minRelevance {1}", query.Index, query.MinRelevance); - return route; - } + IAsyncEnumerable answerStream = service.AskStreamingAsync( + question: query.Question, + index: query.Index, + filters: query.Filters, + minRelevance: query.MinRelevance, + options: new SearchOptions { Stream = query.Stream }, + context: contextProvider.GetContext(), + cancellationToken: cancellationToken); + + httpContext.Response.StatusCode = StatusCodes.Status200OK; - public static RouteHandlerBuilder AddAskChunkEndpoint( - this IEndpointRouteBuilder builder, string apiPrefix = "/", IEndpointFilter? authFilter = null) - { - RouteGroupBuilder group = builder.MapGroup(apiPrefix); - // Ask endpoint - var route = group.MapPost(Constants.HttpAskChunkEndpoint, - async Task ( - HttpContext httpContext, - MemoryQuery query, - IKernelMemory service, - ILogger log, - IContextProvider contextProvider, - CancellationToken cancellationToken) => - { - // Allow internal classes to access custom arguments via IContextProvider - contextProvider.InitContextArgs(query.ContextArguments); - httpContext.Response.ContentType = "text/event-stream; charset=utf-8"; - log.LogTrace("New search request, index '{0}', minRelevance {1}", query.Index, query.MinRelevance); try { - var answerStream = service.AskAsyncChunk( - question: query.Question, - index: query.Index, - filters: query.Filters, - minRelevance: query.MinRelevance, - context: contextProvider.GetContext(), - cancellationToken: cancellationToken - ).ConfigureAwait(false); - - await foreach (var answer in answerStream.ConfigureAwait(false)) + if (query.Stream) + { + httpContext.Response.ContentType = "text/event-stream; charset=utf-8"; + await foreach (var answer in answerStream.ConfigureAwait(false)) + { + string json = answer.ToJson(true); + await httpContext.Response.WriteAsync($"{SSE.DataPrefix}{json}\n\n", cancellationToken).ConfigureAwait(false); + await httpContext.Response.Body.FlushAsync(cancellationToken).ConfigureAwait(false); + } + } + else { - string json = JsonSerializer.Serialize(answer); - await httpContext.Response.WriteAsync($"data: {json}\n\n", cancellationToken).ConfigureAwait(false); - await httpContext.Response.Body.FlushAsync(cancellationToken).ConfigureAwait(false); + httpContext.Response.ContentType = "application/json; charset=utf-8"; + MemoryAnswer answer = await answerStream.FirstAsync(cancellationToken: cancellationToken).ConfigureAwait(false); + string json = answer.ToJson(false); + await httpContext.Response.WriteAsync(json, cancellationToken).ConfigureAwait(false); } } - catch (Exception ex) + catch (Exception e) + { + log.LogError(e, "An error occurred while preparing the response"); + + // Attempt to set the status code, in case the output hasn't started yet + httpContext.Response.StatusCode = StatusCodes.Status503ServiceUnavailable; + + var json = query.Stream + ? JsonSerializer.Serialize(new MemoryAnswer + { + StreamState = StreamStates.Error, + Question = query.Question, + NoResult = true, + NoResultReason = $"Error: {e.Message} [{e.GetType().FullName}]" + }) + : JsonSerializer.Serialize(new ProblemDetails + { + Status = StatusCodes.Status503ServiceUnavailable, + Title = "Service Unavailable", + Detail = $"{e.Message} [{e.GetType().FullName}]" + }); + + await httpContext.Response.WriteAsync(query.Stream ? $"{SSE.DataPrefix}{json}\n\n" : json, cancellationToken).ConfigureAwait(false); + } + + if (query.Stream && sseSendDoneMessage) { - log.LogError(ex, "Error occurred while streaming resp"); - string errorJson = JsonSerializer.Serialize(new { error = "An error occurred while processing the request." }); - await httpContext.Response.WriteAsync($"data: {errorJson}\n\n", cancellationToken).ConfigureAwait(false); - await httpContext.Response.Body.FlushAsync(cancellationToken).ConfigureAwait(false); + await httpContext.Response.WriteAsync($"{SSE.DoneMessage}\n\n", cancellationToken: cancellationToken).ConfigureAwait(false); } + + await httpContext.Response.Body.FlushAsync(cancellationToken).ConfigureAwait(false); }) - .Produces(StatusCodes.Status204NoContent) + .Produces(StatusCodes.Status200OK) .Produces(StatusCodes.Status401Unauthorized) - .Produces(StatusCodes.Status403Forbidden); - - if (authFilter != null) { route.AddEndpointFilter(authFilter); } + .Produces(StatusCodes.Status403Forbidden) + .Produces(StatusCodes.Status503ServiceUnavailable); return route; } diff --git a/service/Service/appsettings.json b/service/Service/appsettings.json index 506f52c2d..db88ad35c 100644 --- a/service/Service/appsettings.json +++ b/service/Service/appsettings.json @@ -48,6 +48,8 @@ // If not set the solution defaults to 30,000,000 bytes (~28.6 MB) // Note: this applies only to KM HTTP service. "MaxUploadSizeMb": null, + // Whether to send a [DONE] message at the end of SSE streams. + "SendSSEDoneMessage": true, // Whether to run the asynchronous pipeline handlers // Use these booleans to deploy the web service and the handlers on same/different VMs "RunHandlers": true, diff --git a/service/tests/Abstractions.UnitTests/Http/SSETest.cs b/service/tests/Abstractions.UnitTests/Http/SSETest.cs new file mode 100644 index 000000000..ffa463737 --- /dev/null +++ b/service/tests/Abstractions.UnitTests/Http/SSETest.cs @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text; +using Microsoft.KernelMemory; +using Microsoft.KernelMemory.HTTP; + +namespace Microsoft.KM.Abstractions.UnitTests.Http; + +public class SSETest +{ + [Theory] + [InlineData(null)] + [InlineData("")] + [InlineData(" ")] + [InlineData(" \n")] + public void ItParsesEmptyStrings(string input) + { + Assert.Null(SSE.ParseMessage(input)); + } + + [Fact] + public void ItParsesSingleLineMessage() + { + // Arrange + var message = """ + data: { "question": "q" } + """; + + // Act + var x = SSE.ParseMessage(message); + + // Assert + Assert.NotNull(x); + Assert.Equal("q", x.Question); + } + + [Fact] + public void ItParsesSingleLineMessageWithSeparator() + { + // Arrange + var message = """ + data: { "question": "q" } + + + """; + + // Act + var x = SSE.ParseMessage(message); + + // Assert + Assert.NotNull(x); + Assert.Equal("q", x.Question); + } + + [Fact] + public void ItParsesMultiLineMessage() + { + // Arrange + var message = """ + data: { "question": "q" + data: , "noResultReason": "abc" + data: } + """; + + // Act + var x = SSE.ParseMessage(message); + + // Assert + Assert.NotNull(x); + Assert.Equal("q", x.Question); + Assert.Equal("abc", x.NoResultReason); + } + + [Theory] + [InlineData("data: [DONE]")] + [InlineData("data: [DONE]\n")] + [InlineData("data: [DONE]\n\n")] + public async Task ItParsesEmptyStreams(string input) + { + // Arrange + using var stream = new MemoryStream(Encoding.UTF8.GetBytes(input)); + + // Act + var result = SSE.ParseStreamAsync(stream); + + // Assert + var messages = new List(); + await foreach (var message in result) + { + messages.Add(message); + } + + Assert.Equal(0, messages.Count); + } + + [Theory] + [InlineData("data: { \"question\": \"qq\" }")] + [InlineData("data: { \"question\": \"qq\" }\n")] + [InlineData("data: { \"question\": \"qq\" }\n\n")] + [InlineData("data: { \"question\": \"qq\" }\n\ndata: [DONE]")] + [InlineData("data: { \"question\": \"qq\" }\n\ndata: [DONE]\n")] + [InlineData("data: { \"question\": \"qq\" }\n\ndata: [DONE]\n\n")] + public async Task ItParsesStreamsWithASingleMessage(string input) + { + // Arrange + using var stream = new MemoryStream(Encoding.UTF8.GetBytes(input)); + + // Act + var result = SSE.ParseStreamAsync(stream); + + // Assert + var messages = new List(); + await foreach (var message in result) + { + messages.Add(message); + } + + Assert.Equal(1, messages.Count); + Assert.NotNull(messages[0]); + Assert.Equal("qq", messages[0].Question); + } + + [Theory] + [InlineData("data: { \"question\": \"qq\" }\n\ndata: { \"question\": \"kk\" }\n\n")] + [InlineData("data: { \"question\": \"qq\" }\n\ndata: { \"question\": \"kk\" }\n\ndata: [DONE]")] + [InlineData("data: { \"question\": \"qq\" }\n\ndata: { \"question\": \"kk\" }\n\ndata: [DONE]\n")] + [InlineData("data: { \"question\": \"qq\" }\n\ndata: { \"question\": \"kk\" }\n\ndata: [DONE]\n\n")] + public async Task ItParsesStreamsWithMultipleMessage(string input) + { + // Arrange + using var stream = new MemoryStream(Encoding.UTF8.GetBytes(input)); + + // Act + var result = SSE.ParseStreamAsync(stream); + + // Assert + var messages = new List(); + await foreach (var message in result) + { + messages.Add(message); + } + + Assert.Equal(2, messages.Count); + Assert.NotNull(messages[0]); + Assert.Equal("qq", messages[0].Question); + Assert.NotNull(messages[1]); + Assert.Equal("kk", messages[1].Question); + } +}