Skip to content

Commit

Permalink
Improve RAG with configurable search prompt params via the API
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris committed Nov 6, 2024
1 parent d5d3d10 commit d3c4bc8
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 18 deletions.
4 changes: 3 additions & 1 deletion src/api/rest/v1/llm/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { PromptSearchService } from '../../../../services/assistants/search'
import { Utils } from "../../../../utils";
import { HotLoadProgress } from "../../../../services/data";
import { DataService } from "../../../../services/data";
import { PromptSearchServiceConfig } from "../../../../services/assistants/interfaces";
const _ = require('lodash')

export interface LLMConfig {
Expand Down Expand Up @@ -79,6 +80,7 @@ export class LLMController {
const { context, account } = await Utils.getNetworkConnectionFromRequest(req)
const did = await account.did()
const prompt = req.body.prompt
const promptConfig: PromptSearchServiceConfig = req.body.promptConfig

const {
customEndpoint,
Expand All @@ -89,7 +91,7 @@ export class LLMController {
const llm = getLLM(llmProvider, llmModel, customEndpoint)

const promptService = new PromptSearchService(did, context)
const promptResult = await promptService.prompt(prompt, llm)
const promptResult = await promptService.prompt(prompt, llm, promptConfig)

return res.json(promptResult)
} catch (error) {
Expand Down
7 changes: 2 additions & 5 deletions src/services/assistants/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ export interface PromptSearchServiceDataTypes {
}

export interface PromptSearchServiceConfig {
searchType?: PromptSearchType
maxContextLength?: number
dataTypes?: PromptSearchServiceDataTypes
promptSearchConfig?: PromptSearchLLMResponseOptional
}

export interface PromptSearchLLMResponseOptional extends Partial<PromptSearchLLMResponse> {}
promptSearchConfig?: PromptSearchLLMResponse
}
24 changes: 16 additions & 8 deletions src/services/assistants/search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,22 @@ const DEFAULT_PROMPT_SEARCH_SERVICE_CONFIG: PromptSearchServiceConfig = {

export class PromptSearchService extends VeridaService {

public async prompt(prompt: string, llm: LLM, config: PromptSearchServiceConfig = {}): Promise<{
public async prompt(prompt: string, llm: LLM, config?: PromptSearchServiceConfig): Promise<{
result: string,
duration: number,
process: PromptSearchLLMResponse
}> {
const start = Date.now()

config = _.merge({}, DEFAULT_PROMPT_SEARCH_SERVICE_CONFIG, config)

const start = Date.now()
const promptSearch = new PromptSearch(llm)
const promptSearchResult = await promptSearch.search(prompt, undefined, config.promptSearchConfig)
let promptSearchResult
if (config?.promptSearchConfig) {
promptSearchResult = config.promptSearchConfig
} else {
const promptSearch = new PromptSearch(llm)
promptSearchResult = await promptSearch.search(prompt)
}

console.log(promptSearchResult)

Expand Down Expand Up @@ -86,8 +92,10 @@ export class PromptSearchService extends VeridaService {
console.log(`Searching by timeframe: ${maxDatetime} ${sort}`)
if (promptSearchResult.databases.indexOf(SearchType.EMAILS) !== -1) {
emails = await searchService.schemaByDateRange<SchemaEmail>(SearchType.EMAILS, maxDatetime, sort, config.dataTypes.emails.limit*3)
const emailShortlist = new EmailShortlist(llm)
emails = await emailShortlist.shortlist(prompt, emails, config.dataTypes.emails.limit)
if (emails.length > config.dataTypes.emails.limit) {
const emailShortlist = new EmailShortlist(llm)
emails = await emailShortlist.shortlist(prompt, emails, config.dataTypes.emails.limit)
}
}
if (promptSearchResult.databases.indexOf(SearchType.FILES) !== -1) {
files = await searchService.schemaByDateRange<SchemaFile>(SearchType.FILES, maxDatetime, sort, config.dataTypes.files.limit)
Expand Down Expand Up @@ -170,9 +178,9 @@ export class PromptSearchService extends VeridaService {
// console.log(finalPrompt)

const finalResponse = await llm.prompt(finalPrompt, undefined, false)
const duration = Date.now() - start
const duration = ((Date.now() - start) / 1000.0)

console.log(contextString)
// console.log(contextString)

return {
result: finalResponse.choices[0].message.content!,
Expand Down
6 changes: 2 additions & 4 deletions src/services/tools/promptSearch.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
const _ = require('lodash')
import { KeywordSearchTimeframe } from "../../helpers/interfaces";
import { PromptSearchLLMResponseOptional } from "../assistants/interfaces";
import { LLM } from "../llm"
import { SearchType } from "../search";

Expand Down Expand Up @@ -56,12 +55,11 @@ export class PromptSearch {
this.llm = llm
}

public async search(userPrompt: string, retries = 3, defaultResponse?: PromptSearchLLMResponseOptional): Promise<PromptSearchLLMResponse> {
public async search(userPrompt: string, retries = 3): Promise<PromptSearchLLMResponse> {
const response = await this.llm.prompt(userPrompt, systemPrompt)

try {
const searchResponse = <PromptSearchLLMResponse> JSON.parse(response.choices[0].message.content!)
return _.merge({}, searchResponse, defaultResponse ? defaultResponse : {})
return <PromptSearchLLMResponse> JSON.parse(response.choices[0].message.content!)
} catch (err: any) {
if (retries === 0) {
throw new Error(`No user data query available`)
Expand Down

0 comments on commit d3c4bc8

Please sign in to comment.