[Obs AI Assistant] Add uuid to knowledge base entries to avoid overwriting accidentally
sorenlouv committed Aug 22, 2024
1 parent 8d4704f commit e627b85
Showing 11 changed files with 87 additions and 93 deletions.
@@ -79,9 +79,9 @@ export type ConversationUpdateRequest = ConversationRequestBase & {

export interface KnowledgeBaseEntry {
'@timestamp': string;
id: string;
id: string; // this is a unique ID generated by the client
doc_id: string; // this is the human readable ID generated by the LLM and used by the LLM to update existing entries
text: string;
doc_id: string;
confidence: 'low' | 'medium' | 'high';
is_correction: boolean;
type?: 'user_instruction' | 'contextual';
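For illustration, a minimal sketch of how the two identifiers are meant to be used after this change; the standalone interface and the field values below are hypothetical, only the id/doc_id split comes from the diff:

import { v4 } from 'uuid';

interface EntryIds {
  id: string; // unique, generated with uuid; stable storage id for the Elasticsearch document
  doc_id: string; // human readable; the LLM reuses it when it wants to update an existing entry
}

const ids: EntryIds = {
  id: v4(),
  doc_id: 'user_prefers_dark_mode', // hypothetical value an LLM might pick
};

console.log(ids);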
@@ -52,9 +52,9 @@ const schema: RootSchema<RecallRanking> = {
},
};

export const RecallRankingEventType = 'observability_ai_assistant_recall_ranking';
export const recallRankingEventType = 'observability_ai_assistant_recall_ranking';

export const recallRankingEvent: EventTypeOpts<RecallRanking> = {
eventType: RecallRankingEventType,
eventType: recallRankingEventType,
schema,
};
@@ -63,15 +63,14 @@ export function registerSummarizationFunction({
},
},
(
{ arguments: { id, text, is_correction: isCorrection, confidence, public: isPublic } },
{ arguments: { id: docId, text, is_correction: isCorrection, confidence, public: isPublic } },
signal
) => {
return client
.addKnowledgeBaseEntry({
entry: {
doc_id: id,
doc_id: docId,
role: KnowledgeBaseEntryRole.AssistantSummarization,
id,
text,
is_correction: isCorrection,
type: KnowledgeBaseType.Contextual,
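A hedged sketch of the argument mapping the summarization function now performs: the id the LLM supplies is stored only as the human readable doc_id, and the storage uuid is assigned later by the knowledge base service. The helper and the argument values are hypothetical:

interface SummarizeArgs {
  id: string; // human readable identifier chosen by the LLM
  text: string;
  is_correction: boolean;
  confidence: 'low' | 'medium' | 'high';
  public: boolean;
}

// Hypothetical helper mirroring the destructuring shown in the diff above.
function toEntry({ id: docId, text, is_correction: isCorrection, confidence, public: isPublic }: SummarizeArgs) {
  return {
    doc_id: docId, // no separate `id` field anymore; the uuid is resolved server side
    text,
    is_correction: isCorrection,
    confidence,
    public: isPublic,
  };
}

console.log(
  toEntry({
    id: 'user_prefers_dark_mode',
    text: 'The user prefers dark mode dashboards.',
    is_correction: false,
    confidence: 'high',
    public: false,
  })
);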
@@ -139,7 +139,6 @@ const functionSummariseRoute = createObservabilityAIAssistantServerRoute({
return client.addKnowledgeBaseEntry({
entry: {
confidence,
id,
doc_id: id,
is_correction: isCorrection,
type,
@@ -109,7 +109,6 @@ const saveKnowledgeBaseUserInstruction = createObservabilityAIAssistantServerRoute({
const { id, text, public: isPublic } = resources.params.body;
return client.addKnowledgeBaseEntry({
entry: {
id,
doc_id: id,
text,
public: isPublic,
@@ -195,7 +194,6 @@ const saveKnowledgeBaseEntry = createObservabilityAIAssistantServerRoute({

return client.addKnowledgeBaseEntry({
entry: {
id,
text,
doc_id: id,
confidence: confidence ?? 'high',
@@ -734,7 +734,7 @@ export class ObservabilityAIAssistantClient {
addKnowledgeBaseEntry = async ({
entry,
}: {
entry: Omit<KnowledgeBaseEntry, '@timestamp'>;
entry: Omit<KnowledgeBaseEntry, '@timestamp' | 'id'>;
}): Promise<void> => {
return this.dependencies.knowledgeBaseService.addEntry({
namespace: this.dependencies.namespace,
@@ -14,6 +14,7 @@ import pRetry from 'p-retry';
import { map, orderBy } from 'lodash';
import { encode } from 'gpt-tokenizer';
import { MlTrainedModelDeploymentNodesStats } from '@elastic/elasticsearch/lib/api/types';
import { v4 } from 'uuid';
import {
INDEX_QUEUED_DOCUMENTS_TASK_ID,
INDEX_QUEUED_DOCUMENTS_TASK_TYPE,
@@ -38,6 +39,7 @@ interface Dependencies {

export interface RecalledEntry {
id: string;
doc_id: string;
text: string;
score: number | null;
is_correction?: boolean;
@@ -368,13 +370,13 @@ export class KnowledgeBaseService {
};

const response = await this.dependencies.esClient.asInternalUser.search<
Pick<KnowledgeBaseEntry, 'text' | 'is_correction' | 'labels'>
Pick<KnowledgeBaseEntry, 'text' | 'is_correction' | 'labels' | 'doc_id'>
>({
index: [resourceNames.aliases.kb],
query: esQuery,
size: 20,
_source: {
includes: ['text', 'is_correction', 'labels'],
includes: ['text', 'is_correction', 'labels', 'doc_id'],
},
});

@@ -598,27 +600,68 @@ export class KnowledgeBaseService {
return res.hits.hits[0]?._source?.doc_id;
};

getUuidFromHumanReadableId = async ({
docId,
user,
namespace,
}: {
docId: string;
user?: { name: string; id?: string };
namespace?: string;
}) => {
const query = {
bool: {
filter: [
{ term: { doc_id: docId } },

// exclude user instructions
{ bool: { must_not: { term: { type: KnowledgeBaseType.UserInstruction } } } },

// restrict access to user's own entries
...getAccessQuery({ user, namespace }),
],
},
};

const response = await this.dependencies.esClient.asInternalUser.search<KnowledgeBaseEntry>({
size: 1,
index: resourceNames.aliases.kb,
query,
_source: false,
});

const id = response.hits.hits[0]?._id ?? v4();

return id;
};

addEntry = async ({
entry: { id, ...document },
entry,
user,
namespace,
}: {
entry: Omit<KnowledgeBaseEntry, '@timestamp'>;
entry: Omit<KnowledgeBaseEntry, '@timestamp' | 'id'>;
user?: { name: string; id?: string };
namespace?: string;
}): Promise<void> => {
let id = '';

// for now we want to limit the number of user instructions to 1 per user
if (document.type === KnowledgeBaseType.UserInstruction) {
if (entry.type === KnowledgeBaseType.UserInstruction) {
const existingId = await this.getExistingUserInstructionId({
isPublic: document.public,
isPublic: entry.public,
user,
namespace,
});

if (existingId) {
id = existingId;
document.doc_id = existingId;
entry.doc_id = existingId;
}

// override previous id if it exists
} else {
id = await this.getUuidFromHumanReadableId({ docId: entry.doc_id, user, namespace });
}

try {
@@ -627,7 +670,7 @@
id,
document: {
'@timestamp': new Date().toISOString(),
...document,
...entry,
user,
namespace,
},
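To make the new overwrite semantics concrete, here is a small self-contained sketch that mimics what getUuidFromHumanReadableId and addEntry do together, with an in-memory map standing in for the Elasticsearch lookup (user instructions and access filtering are left out; all names and values below are illustrative):

import { v4 } from 'uuid';

// doc_id -> stored uuid, standing in for the search against the knowledge base index.
const storedIdsByDocId = new Map<string, string>();

function resolveStorageId(docId: string): string {
  const existing = storedIdsByDocId.get(docId);
  if (existing) {
    return existing; // same human readable id: intentionally overwrite the existing document
  }
  const id = v4(); // unseen doc_id: a fresh uuid, so no unrelated entry is overwritten
  storedIdsByDocId.set(docId, id);
  return id;
}

// Summarizing twice under the same doc_id updates one document...
const first = resolveStorageId('nodejs_deployment_runbook');
const second = resolveStorageId('nodejs_deployment_runbook');
console.log(first === second); // true

// ...while a different doc_id gets its own uuid and its own document.
console.log(resolveStorageId('java_deployment_runbook') === first); // false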
@@ -7,13 +7,18 @@

import type { Logger } from '@kbn/logging';
import { AnalyticsServiceStart } from '@kbn/core/server';
import { scoreSuggestions } from './score_suggestions';
import type { Message } from '../../../common';
import type { ObservabilityAIAssistantClient } from '../../service/client';
import type { FunctionCallChatFunction } from '../../service/types';
import { retrieveSuggestions } from './retrieve_suggestions';
import { scoreSuggestions } from './score_suggestions';
import type { RetrievedSuggestion } from './types';
import { RecallRanking, RecallRankingEventType } from '../../analytics/recall_ranking';
import { RecallRanking, recallRankingEventType } from '../../analytics/recall_ranking';

export interface RecalledSuggestion {
id: string;
docId: string;
text: string;
score: number | null;
}

export async function recallAndScore({
recall,
@@ -34,19 +39,24 @@
logger: Logger;
signal: AbortSignal;
}): Promise<{
relevantDocuments?: RetrievedSuggestion[];
relevantDocuments?: RecalledSuggestion[];
scores?: Array<{ id: string; score: number }>;
suggestions: RetrievedSuggestion[];
suggestions: RecalledSuggestion[];
}> {
const queries = [
{ text: userPrompt, boost: 3 },
{ text: context, boost: 1 },
].filter((query) => query.text.trim());

const suggestions = await retrieveSuggestions({
recall,
queries,
});
const { entries: recalledEntries } = await recall({ queries });
const suggestions: RecalledSuggestion[] = recalledEntries.map(
({ id, doc_id: docId, text, score }) => ({
id,
docId,
text,
score,
})
);

if (!suggestions.length) {
return {
@@ -67,7 +77,7 @@
chat,
});

analytics.reportEvent<RecallRanking>(RecallRankingEventType, {
analytics.reportEvent<RecallRanking>(recallRankingEventType, {
prompt: queries.map((query) => query.text).join('\n\n'),
scoredDocuments: suggestions.map((suggestion) => {
const llmScore = scores.find((score) => score.id === suggestion.id);
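A brief sketch of the new mapping from recalled knowledge base entries to RecalledSuggestion objects, so that both the storage uuid and the human readable docId travel through scoring; the sample entries are invented:

interface RecalledSuggestion {
  id: string; // storage uuid, used for scoring and analytics
  docId: string; // human readable id, later shown to the LLM as the document title
  text: string;
  score: number | null;
}

const recalledEntries = [
  { id: '0b1f2a3c-0000-4000-8000-000000000001', doc_id: 'slo_definitions', text: 'SLOs are defined as...', score: 0.82 },
  { id: '0b1f2a3c-0000-4000-8000-000000000002', doc_id: 'oncall_rotation', text: 'The rotation is weekly...', score: null },
];

const suggestions: RecalledSuggestion[] = recalledEntries.map(
  ({ id, doc_id: docId, text, score }) => ({ id, docId, text, score })
);

console.log(suggestions);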

This file was deleted.

@@ -5,16 +5,14 @@
* 2.0.
*/
import * as t from 'io-ts';
import { omit } from 'lodash';
import { Logger } from '@kbn/logging';
import dedent from 'dedent';
import { lastValueFrom } from 'rxjs';
import { decodeOrThrow, jsonRt } from '@kbn/io-ts-utils';
import { concatenateChatCompletionChunks, Message, MessageRole } from '../../../common';
import type { FunctionCallChatFunction } from '../../service/types';
import type { RetrievedSuggestion } from './types';
import { parseSuggestionScores } from './parse_suggestion_scores';
import { ShortIdTable } from '../../../common/utils/short_id_table';
import { RecalledSuggestion } from './recall_and_score';

const scoreFunctionRequestRt = t.type({
message: t.type({
@@ -38,25 +36,17 @@ export async function scoreSuggestions({
signal,
logger,
}: {
suggestions: RetrievedSuggestion[];
suggestions: RecalledSuggestion[];
messages: Message[];
userPrompt: string;
context: string;
chat: FunctionCallChatFunction;
signal: AbortSignal;
logger: Logger;
}): Promise<{
relevantDocuments: RetrievedSuggestion[];
relevantDocuments: RecalledSuggestion[];
scores: Array<{ id: string; score: number }>;
}> {
const shortIdTable = new ShortIdTable();

const suggestionsWithShortId = suggestions.map((suggestion) => ({
...omit(suggestion, 'score', 'id'), // To not bias the LLM
originalId: suggestion.id,
shortId: shortIdTable.take(suggestion.id),
}));

const newUserMessageContent =
dedent(`Given the following question, score the documents that are relevant to the question. on a scale from 0 to 7,
0 being completely irrelevant, and 7 being extremely relevant. Information is relevant to the question if it helps in
@@ -76,10 +66,7 @@
Documents:
${JSON.stringify(
suggestionsWithShortId.map((suggestion) => ({
id: suggestion.shortId,
content: suggestion.text,
})),
suggestions.map(({ id, docId: title, text }) => ({ id, title, text })),
null,
2
)}`);
@@ -127,15 +114,7 @@
scoreFunctionRequest.message.function_call.arguments
);

const scores = parseSuggestionScores(scoresAsString).map(({ id, score }) => {
const originalSuggestion = suggestionsWithShortId.find(
(suggestion) => suggestion.shortId === id
);
return {
originalId: originalSuggestion?.originalId,
score,
};
});
const scores = parseSuggestionScores(scoresAsString);

if (scores.length === 0) {
// seemingly invalid or no scores, return all
@@ -145,11 +124,11 @@
const suggestionIds = suggestions.map((document) => document.id);

const relevantDocumentIds = scores
.filter((document) => suggestionIds.includes(document.originalId ?? '')) // Remove hallucinated documents
.filter((document) => suggestionIds.includes(document.id ?? '')) // Remove hallucinated documents
.filter((document) => document.score > 4)
.sort((a, b) => b.score - a.score)
.slice(0, 5)
.map((document) => document.originalId);
.map((document) => document.id);

const relevantDocuments = suggestions.filter((suggestion) =>
relevantDocumentIds.includes(suggestion.id)
@@ -159,6 +138,6 @@

return {
relevantDocuments,
scores: scores.map((score) => ({ id: score.originalId!, score: score.score })),
scores: scores.map((score) => ({ id: score.id, score: score.score })),
};
}
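The filtering step above is easier to follow with concrete numbers. A self-contained sketch with invented ids and scores; only the thresholds (score above 4, at most five documents, drop unknown ids) come from the code above:

interface SuggestionScore {
  id: string;
  score: number;
}

const suggestionIds = ['a', 'b', 'c', 'd']; // uuids in practice; shortened here

const scores: SuggestionScore[] = [
  { id: 'a', score: 7 },
  { id: 'b', score: 3 }, // dropped: score of 4 or lower
  { id: 'z', score: 6 }, // dropped: not a recalled suggestion id (hallucinated by the LLM)
  { id: 'c', score: 5 },
  { id: 'd', score: 6 },
];

const relevantDocumentIds = scores
  .filter((doc) => suggestionIds.includes(doc.id)) // remove hallucinated documents
  .filter((doc) => doc.score > 4) // keep only clearly relevant documents
  .sort((a, b) => b.score - a.score)
  .slice(0, 5) // cap at five documents
  .map((doc) => doc.id);

console.log(relevantDocumentIds); // ['a', 'd', 'c']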

This file was deleted.
