From f2a77b490ac0c56ca55795824ff22a038232763b Mon Sep 17 00:00:00 2001 From: Devis Lucato Date: Sat, 23 Nov 2024 12:39:44 -0800 Subject: [PATCH] Refactor search client --- .../MemoryStorage/MemoryRecordExtensions.cs | 10 +- service/Core/Search/SearchClient.cs | 361 +++++++++--------- service/Core/Search/SearchClientResult.cs | 93 +++++ 3 files changed, 268 insertions(+), 196 deletions(-) create mode 100644 service/Core/Search/SearchClientResult.cs diff --git a/service/Core/MemoryStorage/MemoryRecordExtensions.cs b/service/Core/MemoryStorage/MemoryRecordExtensions.cs index 349f7cd42..4bdfb2084 100644 --- a/service/Core/MemoryStorage/MemoryRecordExtensions.cs +++ b/service/Core/MemoryStorage/MemoryRecordExtensions.cs @@ -72,14 +72,14 @@ public static string GetFileContentType(this MemoryRecord record, ILogger? log = /// public static string GetWebPageUrl(this MemoryRecord record, string indexName, ILogger? log = null) { - var fileDownloadUrl = Constants.HttpDownloadEndpointWithParams + var webPageUrl = record.GetPayloadValue(Constants.ReservedPayloadUrlField, log)?.ToString(); + + if (!string.IsNullOrWhiteSpace(webPageUrl)) { return webPageUrl; } + + return Constants.HttpDownloadEndpointWithParams .Replace(Constants.HttpIndexPlaceholder, indexName, StringComparison.Ordinal) .Replace(Constants.HttpDocumentIdPlaceholder, record.GetDocumentId(), StringComparison.Ordinal) .Replace(Constants.HttpFilenamePlaceholder, record.GetFileName(), StringComparison.Ordinal); - - var webPageUrl = record.GetPayloadValue(Constants.ReservedPayloadUrlField, log)?.ToString(); - - return string.IsNullOrWhiteSpace(webPageUrl) ? fileDownloadUrl : webPageUrl; } /// diff --git a/service/Core/Search/SearchClient.cs b/service/Core/Search/SearchClient.cs index a3bebaa51..5e71c569d 100644 --- a/service/Core/Search/SearchClient.cs +++ b/service/Core/Search/SearchClient.cs @@ -76,115 +76,39 @@ public async Task SearchAsync( { if (limit <= 0) { limit = this._config.MaxMatchesCount; } - var result = new SearchResult - { - Query = query, - Results = [] - }; + var result = SearchClientResult.SearchResultInstance(query, limit); if (string.IsNullOrWhiteSpace(query) && (filters == null || filters.Count == 0)) { this._log.LogWarning("No query or filters provided"); - return result; + return result.SearchResult; } +#pragma warning disable CA2254 + this._log.LogTrace(string.IsNullOrEmpty(query) + ? $"Fetching relevant memories by similarity, min relevance {minRelevance}" + : "Fetching relevant memories by filtering only, no vector search"); +#pragma warning restore CA2254 - var list = new List<(MemoryRecord memory, double relevance)>(); - if (!string.IsNullOrEmpty(query)) - { - this._log.LogTrace("Fetching relevant memories by similarity, min relevance {0}", minRelevance); - IAsyncEnumerable<(MemoryRecord, double)> matches = this._memoryDb.GetSimilarListAsync( - index: index, - text: query, - filters: filters, - minRelevance: minRelevance, - limit: limit, - withEmbeddings: false, - cancellationToken: cancellationToken); - - // Memories are sorted by relevance, starting from the most relevant - await foreach ((MemoryRecord memory, double relevance) in matches.ConfigureAwait(false)) - { - list.Add((memory, relevance)); - } - } - else - { - this._log.LogTrace("Fetching relevant memories by filtering"); - IAsyncEnumerable matches = this._memoryDb.GetListAsync( - index: index, - filters: filters, - limit: limit, - withEmbeddings: false, - cancellationToken: cancellationToken); - - await foreach (MemoryRecord memory in matches.ConfigureAwait(false)) - { - list.Add((memory, float.MinValue)); - } - } + IAsyncEnumerable<(MemoryRecord, double)> matches = string.IsNullOrEmpty(query) + ? this._memoryDb.GetListAsync(index, filters, limit, false, cancellationToken).Select(memoryRecord => (memoryRecord, double.MinValue)) + : this._memoryDb.GetSimilarListAsync(index, text: query, filters, minRelevance, limit, false, cancellationToken); - // Memories are sorted by relevance, starting from the most relevant - foreach ((MemoryRecord memory, double relevance) in list) + await foreach ((MemoryRecord memoryRecord, double recordRelevance) in matches.ConfigureAwait(false).WithCancellation(cancellationToken)) { - // Note: a document can be composed by multiple files - string documentId = memory.GetDocumentId(this._log); - - // Identify the file in case there are multiple files - string fileId = memory.GetFileId(this._log); - - // Note: this is not a URL and perhaps could be dropped. For now it acts as a unique identifier. See also SourceUrl. - string linkToFile = $"{index}/{documentId}/{fileId}"; - - var partitionText = memory.GetPartitionText(this._log).Trim(); - if (string.IsNullOrEmpty(partitionText)) - { - this._log.LogError("The document partition is empty, doc: {0}", memory.Id); - continue; - } - - // Relevance is `float.MinValue` when search uses only filters and no embeddings (see code above) - if (relevance > float.MinValue) { this._log.LogTrace("Adding result with relevance {0}", relevance); } - - // If the file is already in the list of citations, only add the partition - var citation = result.Results.FirstOrDefault(x => x.Link == linkToFile); - if (citation == null) - { - citation = new Citation(); - result.Results.Add(citation); - } + result.State = SearchState.Continue; + result = this.ProcessMemoryRecord(result, index, memoryRecord, recordRelevance); - // Add the partition to the list of citations - citation.Index = index; - citation.DocumentId = documentId; - citation.FileId = fileId; - citation.Link = linkToFile; - citation.SourceContentType = memory.GetFileContentType(this._log); - citation.SourceName = memory.GetFileName(this._log); - citation.SourceUrl = memory.GetWebPageUrl(index); + if (result.State == SearchState.SkipRecord) { continue; } - citation.Partitions.Add(new Citation.Partition - { - Text = partitionText, - Relevance = (float)relevance, - PartitionNumber = memory.GetPartitionNumber(this._log), - SectionNumber = memory.GetSectionNumber(), - LastUpdate = memory.GetLastUpdate(), - Tags = memory.Tags, - }); - - // In cases where a buggy storage connector is returning too many records - if (result.Results.Count >= this._config.MaxMatchesCount) - { - break; - } + if (result.State == SearchState.Stop) { break; } } - if (result.Results.Count == 0) + if (result.SearchResult.Results.Count == 0) { this._log.LogDebug("No memories found"); } - return result; + return result.SearchResult; } /// @@ -198,38 +122,28 @@ public async Task AskAsync( { string emptyAnswer = context.GetCustomEmptyAnswerTextOrDefault(this._config.EmptyAnswer); string answerPrompt = context.GetCustomRagPromptOrDefault(this._answerPrompt); - string factTemplate = context.GetCustomRagFactTemplateOrDefault(this._config.FactTemplate); int limit = context.GetCustomRagMaxMatchesCountOrDefault(this._config.MaxMatchesCount); - if (!factTemplate.EndsWith('\n')) { factTemplate += "\n"; } + var maxTokens = this._config.MaxAskPromptSize > 0 + ? this._config.MaxAskPromptSize + : this._textGenerator.MaxTokenTotal; - var noAnswerFound = new MemoryAnswer - { - Question = question, - NoResult = true, - Result = emptyAnswer, - }; + SearchClientResult result = SearchClientResult.AskResultInstance( + question: question, + emptyAnswer: emptyAnswer, + maxGroundingFacts: limit, + tokensAvailable: maxTokens + - this._textGenerator.CountTokens(answerPrompt) + - this._textGenerator.CountTokens(question) + - this._config.AnswerTokens + ); if (string.IsNullOrEmpty(question)) { this._log.LogWarning("No question provided"); - noAnswerFound.NoResultReason = "No question provided"; - return noAnswerFound; + return result.AskResult; } - var facts = new StringBuilder(); - var maxTokens = this._config.MaxAskPromptSize > 0 - ? this._config.MaxAskPromptSize - : this._textGenerator.MaxTokenTotal; - var tokensAvailable = maxTokens - - this._textGenerator.CountTokens(answerPrompt) - - this._textGenerator.CountTokens(question) - - this._config.AnswerTokens; - - var factsUsedCount = 0; - var factsAvailableCount = 0; - var answer = noAnswerFound; - this._log.LogTrace("Fetching relevant memories"); IAsyncEnumerable<(MemoryRecord, double)> matches = this._memoryDb.GetSimilarListAsync( index: index, @@ -240,107 +154,170 @@ public async Task AskAsync( withEmbeddings: false, cancellationToken: cancellationToken); + string factTemplate = context.GetCustomRagFactTemplateOrDefault(this._config.FactTemplate); + if (!factTemplate.EndsWith('\n')) { factTemplate += "\n"; } + // Memories are sorted by relevance, starting from the most relevant - await foreach ((MemoryRecord memory, double relevance) in matches.ConfigureAwait(false)) + await foreach ((MemoryRecord memoryRecord, double recordRelevance) in matches.ConfigureAwait(false)) { - // Note: a document can be composed by multiple files - string documentId = memory.GetDocumentId(this._log); + result.State = SearchState.Continue; + result = this.ProcessMemoryRecord(result, index, memoryRecord, recordRelevance, factTemplate); - // Identify the file in case there are multiple files - string fileId = memory.GetFileId(this._log); + if (result.State == SearchState.SkipRecord) { continue; } - // Note: this is not a URL and perhaps could be dropped. For now it acts as a unique identifier. See also SourceUrl. - string linkToFile = $"{index}/{documentId}/{fileId}"; + if (result.State == SearchState.Stop) { break; } + } - string fileName = memory.GetFileName(this._log); + return await this.GenerateAnswerAsync(question, result, minRelevance, context, cancellationToken).ConfigureAwait(false); + } - string webPageUrl = memory.GetWebPageUrl(index); + /// + /// Process memory records for ASK and SEARCH calls + /// + /// Current state of the result + /// Memory record, e.g. text chunk + metadata + /// Memory record relevance + /// Memory index name + /// How to render the record when preparing an LLM prompt + /// Updated search result state + private SearchClientResult ProcessMemoryRecord( + SearchClientResult result, string index, MemoryRecord record, double recordRelevance, string? factTemplate = null) + { + var partitionText = record.GetPartitionText(this._log).Trim(); + if (string.IsNullOrEmpty(partitionText)) + { + this._log.LogError("The document partition is empty, doc: {0}", record.Id); + return result.SkipRecord(); + } - var partitionText = memory.GetPartitionText(this._log).Trim(); - if (string.IsNullOrEmpty(partitionText)) - { - this._log.LogError("The document partition is empty, doc: {0}", memory.Id); - continue; - } + // Note: a document can be composed by multiple files + string documentId = record.GetDocumentId(this._log); - factsAvailableCount++; + // Identify the file in case there are multiple files + string fileId = record.GetFileId(this._log); - var fact = PromptUtils.RenderFactTemplate( - template: factTemplate, + // Note: this is not a URL and perhaps could be dropped. For now it acts as a unique identifier. See also SourceUrl. + string linkToFile = $"{index}/{documentId}/{fileId}"; + + // Note: this is "content.url" when importing web pages + string fileName = record.GetFileName(this._log); + + // Link to the web page (if a web page) or link to KM web endpoint to download the file + string fileDownloadUrl = record.GetWebPageUrl(index); + + // Name of the file to show to the LLM, avoiding "content.url" + string fileNameForLLM = (fileName == "content.url" ? fileDownloadUrl : fileName); + + if (result.Mode == SearchMode.SearchMode) + { + // Relevance is `float.MinValue` when search uses only filters + if (recordRelevance > float.MinValue) { this._log.LogTrace("Adding result with relevance {0}", recordRelevance); } + } + else if (result.Mode == SearchMode.AskMode) + { + result.FactsAvailableCount++; + + string fact = PromptUtils.RenderFactTemplate( + template: factTemplate!, factContent: partitionText, - source: (fileName == "content.url" ? webPageUrl : fileName), - relevance: relevance.ToString("P1", CultureInfo.CurrentCulture), - recordId: memory.Id, - tags: memory.Tags, - metadata: memory.Payload); + source: fileNameForLLM, + relevance: recordRelevance.ToString("P1", CultureInfo.CurrentCulture), + recordId: record.Id, + tags: record.Tags, + metadata: record.Payload); // Use the partition/chunk only if there's room for it - var size = this._textGenerator.CountTokens(fact); - if (size >= tokensAvailable) + int factSizeInTokens = this._textGenerator.CountTokens(fact); + if (factSizeInTokens >= result.TokensAvailable) { // Stop after reaching the max number of tokens - break; + return result.Stop(); } - factsUsedCount++; - this._log.LogTrace("Adding text {0} with relevance {1}", factsUsedCount, relevance); + result.Facts.Append(fact); + result.FactsUsedCount++; + result.TokensAvailable -= factSizeInTokens; - facts.Append(fact); - tokensAvailable -= size; + // Relevance is cosine similarity when not using hybrid search + this._log.LogTrace("Adding content #{0} with relevance {1}", result.FactsUsedCount, recordRelevance); + } - // If the file is already in the list of citations, only add the partition - var citation = answer.RelevantSources.FirstOrDefault(x => x.Link == linkToFile); + Citation? citation; + if (result.Mode == SearchMode.SearchMode) + { + citation = result.SearchResult.Results.FirstOrDefault(x => x.Link == linkToFile); if (citation == null) { citation = new Citation(); - answer.RelevantSources.Add(citation); + result.SearchResult.Results.Add(citation); } - - // Add the partition to the list of citations - citation.Index = index; - citation.DocumentId = documentId; - citation.FileId = fileId; - citation.Link = linkToFile; - citation.SourceContentType = memory.GetFileContentType(this._log); - citation.SourceName = fileName; - citation.SourceUrl = memory.GetWebPageUrl(index); - - citation.Partitions.Add(new Citation.Partition - { - Text = partitionText, - Relevance = (float)relevance, - PartitionNumber = memory.GetPartitionNumber(this._log), - SectionNumber = memory.GetSectionNumber(), - LastUpdate = memory.GetLastUpdate(), - Tags = memory.Tags, - }); - - // In cases where a buggy storage connector is returning too many records - if (factsUsedCount >= this._config.MaxMatchesCount) + } + else if (result.Mode == SearchMode.AskMode) + { + // If the file is already in the list of citations, only add the partition + citation = result.AskResult.RelevantSources.FirstOrDefault(x => x.Link == linkToFile); + if (citation == null) { - break; + citation = new Citation(); + result.AskResult.RelevantSources.Add(citation); } } + else + { + throw new ArgumentOutOfRangeException(nameof(result.Mode)); + } + + citation.Index = index; + citation.DocumentId = documentId; + citation.FileId = fileId; + citation.Link = linkToFile; + citation.SourceContentType = record.GetFileContentType(this._log); + citation.SourceName = fileName; + citation.SourceUrl = fileDownloadUrl; + citation.Partitions.Add(new Citation.Partition + { + Text = partitionText, + Relevance = (float)recordRelevance, + PartitionNumber = record.GetPartitionNumber(this._log), + SectionNumber = record.GetSectionNumber(), + LastUpdate = record.GetLastUpdate(), + Tags = record.Tags, + }); + + // Stop when reaching the max number of results or facts. This acts also as + // a protection against storage connectors disregarding 'limit' and returning too many records. + if ((result.Mode == SearchMode.SearchMode && result.SearchResult.Results.Count >= result.MaxRecordCount) + || (result.Mode == SearchMode.AskMode && result.FactsUsedCount >= result.MaxRecordCount)) + { + return result.Stop(); + } - if (factsAvailableCount > 0 && factsUsedCount == 0) + return result; + } + + private async Task GenerateAnswerAsync( + string question, SearchClientResult result, double minRelevance, IContext? context, CancellationToken cancellationToken) + { + if (result.FactsAvailableCount > 0 && result.FactsUsedCount == 0) { this._log.LogError("Unable to inject memories in the prompt, not enough tokens available"); - noAnswerFound.NoResultReason = "Unable to use memories"; - return noAnswerFound; + result.AskResult.NoResultReason = "Unable to use memories"; + return result.AskResult; } - if (factsUsedCount == 0) + if (result.FactsUsedCount == 0) { this._log.LogWarning("No memories available (min relevance: {0})", minRelevance); - noAnswerFound.NoResultReason = "No memories available"; - return noAnswerFound; + result.AskResult.NoResultReason = "No memories available"; + return result.AskResult; } + // Collect the LLM output var text = new StringBuilder(); var charsGenerated = 0; var watch = new Stopwatch(); watch.Restart(); - await foreach (var x in this.GenerateAnswer(question, facts.ToString(), context, cancellationToken).ConfigureAwait(false)) + await foreach (var x in this.GenerateAnswerTokensAsync(question, result.Facts.ToString(), context, cancellationToken).ConfigureAwait(false)) { text.Append(x); @@ -353,12 +330,13 @@ public async Task AskAsync( watch.Stop(); - answer.Result = text.ToString(); - this._log.LogSensitive("Answer: {0}", answer.Result); - answer.NoResult = ValueIsEquivalentTo(answer.Result, this._config.EmptyAnswer); - if (answer.NoResult) + // Finalize the answer, checking if it's empty + result.AskResult.Result = text.ToString(); + this._log.LogSensitive("Answer: {0}", result.AskResult.Result); + result.AskResult.NoResult = ValueIsEquivalentTo(result.AskResult.Result, this._config.EmptyAnswer); + if (result.AskResult.NoResult) { - answer.NoResultReason = "No relevant memories found"; + result.AskResult.NoResultReason = "No relevant memories found"; this._log.LogTrace("Answer generated in {0} msecs. No relevant memories found", watch.ElapsedMilliseconds); } else @@ -366,22 +344,23 @@ public async Task AskAsync( this._log.LogTrace("Answer generated in {0} msecs", watch.ElapsedMilliseconds); } + // Validate the LLM output if (this._contentModeration != null && this._config.UseContentModeration) { - var isSafe = await this._contentModeration.IsSafeAsync(answer.Result, cancellationToken).ConfigureAwait(false); + var isSafe = await this._contentModeration.IsSafeAsync(result.AskResult.Result, cancellationToken).ConfigureAwait(false); if (!isSafe) { this._log.LogWarning("Unsafe answer detected. Returning error message instead."); - this._log.LogSensitive("Unsafe answer: {0}", answer.Result); - answer.NoResultReason = "Content moderation failure"; - answer.Result = this._config.ModeratedAnswer; + this._log.LogSensitive("Unsafe answer: {0}", result.AskResult.Result); + result.AskResult.NoResultReason = "Content moderation failure"; + result.AskResult.Result = this._config.ModeratedAnswer; } } - return answer; + return result.AskResult; } - private IAsyncEnumerable GenerateAnswer(string question, string facts, IContext? context, CancellationToken token) + private IAsyncEnumerable GenerateAnswerTokensAsync(string question, string facts, IContext? context, CancellationToken cancellationToken) { string prompt = context.GetCustomRagPromptOrDefault(this._answerPrompt); int maxTokens = context.GetCustomRagMaxTokensOrDefault(this._config.AnswerTokens); @@ -415,7 +394,7 @@ private IAsyncEnumerable GenerateAnswer(string question, string facts, I this._log.LogSensitive("Prompt: {0}", prompt); } - return this._textGenerator.GenerateTextAsync(prompt, options, token); + return this._textGenerator.GenerateTextAsync(prompt, options, cancellationToken); } private static bool ValueIsEquivalentTo(string value, string target) diff --git a/service/Core/Search/SearchClientResult.cs b/service/Core/Search/SearchClientResult.cs new file mode 100644 index 000000000..605055970 --- /dev/null +++ b/service/Core/Search/SearchClientResult.cs @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text; + +namespace Microsoft.KernelMemory.Search; + +internal enum SearchMode +{ + SearchMode = 0, + AskMode = 1, +} + +internal enum SearchState +{ + Continue = 0, + SkipRecord = 1, + Stop = 2 +} + +internal class SearchClientResult +{ + public SearchMode Mode { get; private init; } + public SearchState State { get; set; } + + // Use by in Search and Ask mode + public MemoryAnswer AskResult { get; private init; } = new(); + public int MaxRecordCount { get; private init; } + + // Use by Ask mode + public SearchResult SearchResult { get; private init; } = new(); + public StringBuilder Facts { get; } = new(); + public int FactsAvailableCount { get; set; } + public int FactsUsedCount { get; set; } + public int TokensAvailable { get; set; } + + /// + /// Create new instance in Ask mode + /// + public static SearchClientResult AskResultInstance(string question, string emptyAnswer, int maxGroundingFacts, int tokensAvailable) + { + return new SearchClientResult + { + Mode = SearchMode.AskMode, + TokensAvailable = tokensAvailable, + MaxRecordCount = maxGroundingFacts, + AskResult = new MemoryAnswer + { + Question = question, + NoResult = true, + NoResultReason = "No question provided", + Result = emptyAnswer, + } + }; + } + + /// + /// Create new instance in Search mode + /// + public static SearchClientResult SearchResultInstance(string query, int maxSearchResults) + { + return new SearchClientResult + { + Mode = SearchMode.SearchMode, + MaxRecordCount = maxSearchResults, + SearchResult = new SearchResult + { + Query = query, + Results = [] + } + }; + } + + /// + /// Tell search client to skip the current memory record + /// + public SearchClientResult SkipRecord() + { + this.State = SearchState.SkipRecord; + return this; + } + + /// + /// Tell search client to stop processing records and return a final result + /// + public SearchClientResult Stop() + { + this.State = SearchState.Stop; + return this; + } + + // Force factory methods + private SearchClientResult() { } +}