fix: correct types in withContextAwareness

erik-balfe committed Oct 26, 2024
1 parent 57d73f9 commit 3a2788f
Showing 3 changed files with 41 additions and 57 deletions.
23 changes: 5 additions & 18 deletions packages/llamaindex/src/agent/anthropic.ts
@@ -1,21 +1,8 @@
-import {
-  AnthropicAgent,
-  type AnthropicAgentParams,
-} from "@llamaindex/anthropic";
-import {
-  withContextAwareness,
-  type ContextAwareConfig,
-} from "./contextAwareMixin.js";
+import { AnthropicAgent } from "@llamaindex/anthropic";
+import { withContextAwareness } from "./contextAwareMixin.js";

-export type AnthropicContextAwareAgentParams = AnthropicAgentParams &
-  ContextAwareConfig;
-
-export class AnthropicContextAwareAgent extends withContextAwareness(
-  AnthropicAgent,
-) {
-  constructor(params: AnthropicContextAwareAgentParams) {
-    super(params);
-  }
-}
+const ContextAwareAnthropicAgent = withContextAwareness(AnthropicAgent);
+export type { ContextAwareConfig } from "./contextAwareMixin.js";
+export { ContextAwareAnthropicAgent as AnthropicContextAwareAgent };

 export * from "@llamaindex/anthropic";
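
For orientation, here is a minimal usage sketch of the re-exported AnthropicContextAwareAgent (not part of this commit). It assumes the class is re-exported from the llamaindex package root; the retriever and the empty tools array are illustrative placeholders, with the retriever normally coming from an existing index, e.g. index.asRetriever().

```ts
import { AnthropicContextAwareAgent } from "llamaindex";
import type { BaseRetriever } from "@llamaindex/core/retriever";

// Assumed to be built elsewhere, e.g. from a VectorStoreIndex via index.asRetriever().
declare const contextRetriever: BaseRetriever;

// The mixin-produced class accepts the usual AnthropicAgentParams plus the
// contextRetriever field required by ContextAwareConfig.
const agent = new AnthropicContextAwareAgent({
  tools: [], // illustrative: no tools in this sketch
  contextRetriever,
});
```
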
57 changes: 31 additions & 26 deletions packages/llamaindex/src/agent/contextAwareMixin.ts
@@ -1,62 +1,67 @@
-import { type AgentRunner } from "@llamaindex/core/agent";
+import {
+  AnthropicAgent,
+  type AnthropicAgentParams,
+} from "@llamaindex/anthropic";
 import type {
   NonStreamingChatEngineParams,
   StreamingChatEngineParams,
 } from "@llamaindex/core/chat-engine";
-import type { ChatMessage, LLM, MessageContent } from "@llamaindex/core/llms";
+import type { MessageContent } from "@llamaindex/core/llms";
 import type { BaseRetriever } from "@llamaindex/core/retriever";
-import type { NodeWithScore } from "@llamaindex/core/schema";
 import { EngineResponse, MetadataMode } from "@llamaindex/core/schema";
-
-type Constructor<T = {}> = new (...args: any[]) => T;
+import { OpenAIAgent, type OpenAIAgentParams } from "@llamaindex/openai";

 export interface ContextAwareConfig {
   contextRetriever: BaseRetriever;
 }

-export interface ContextAwareAgentRunner extends AgentRunner<LLM> {
+export interface ContextAwareState {
   contextRetriever: BaseRetriever;
   retrievedContext: string | null;
-  retrieveContext(query: MessageContent): Promise<string>;
-  injectContext(context: string): Promise<void>;
 }

+export type SupportedAgent = typeof OpenAIAgent | typeof AnthropicAgent;
+export type AgentParams<T> = T extends typeof OpenAIAgent
+  ? OpenAIAgentParams
+  : T extends typeof AnthropicAgent
+    ? AnthropicAgentParams
+    : never;
+
 /**
  * ContextAwareAgentRunner enhances the base AgentRunner with the ability to retrieve and inject relevant context
  * for each query. This allows the agent to access and utilize appropriate information from a given index or retriever,
  * providing more informed and context-specific responses to user queries.
  */
-export function withContextAwareness<T extends Constructor<AgentRunner<LLM>>>(
-  BaseClass: T,
-) {
-  return class extends BaseClass implements ContextAwareAgentRunner {
-    contextRetriever: BaseRetriever;
-    retrievedContext: string | null = null;
-
-    constructor(...args: any[]) {
-      super(...args);
-      const config = args[args.length - 1] as ContextAwareConfig;
-      this.contextRetriever = config.contextRetriever;
-    }
+export function withContextAwareness<T extends SupportedAgent>(Base: T) {
+  return class ContextAwareAgent extends Base {
+    public readonly contextRetriever: BaseRetriever;
+    public retrievedContext: string | null = null;
+    public declare chatHistory: T extends typeof OpenAIAgent
+      ? OpenAIAgent["chatHistory"]
+      : T extends typeof AnthropicAgent
+        ? AnthropicAgent["chatHistory"]
+        : never;

-    createStore(): object {
-      return {};
+    constructor(params: AgentParams<T> & ContextAwareConfig) {
+      super(params);
+      this.contextRetriever = params.contextRetriever;
     }

     async retrieveContext(query: MessageContent): Promise<string> {
       const nodes = await this.contextRetriever.retrieve({ query });
       return nodes
-        .map((node: NodeWithScore) => node.node.getContent(MetadataMode.NONE))
+        .map((node) => node.node.getContent(MetadataMode.NONE))
         .join("\n");
     }

     async injectContext(context: string): Promise<void> {
-      const chatHistory = (this as any).chatHistory as ChatMessage[];
-      const systemMessage = chatHistory.find((msg) => msg.role === "system");
+      const systemMessage = this.chatHistory.find(
+        (msg) => msg.role === "system",
+      );
       if (systemMessage) {
         systemMessage.content = `${context}\n\n${systemMessage.content}`;
       } else {
-        chatHistory.unshift({ role: "system", content: context });
+        this.chatHistory.unshift({ role: "system", content: context });
       }
     }

(remaining lines of contextAwareMixin.ts not shown)
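
To make the typing change concrete, here is a small type-level sketch (an illustration, not code from the commit) of how SupportedAgent and AgentParams<T> are meant to resolve; the Expect helper is defined here purely for the example.

```ts
import { AnthropicAgent, type AnthropicAgentParams } from "@llamaindex/anthropic";
import { OpenAIAgent, type OpenAIAgentParams } from "@llamaindex/openai";
import type { AgentParams } from "./contextAwareMixin.js";

// AgentParams<T> picks the matching params type for each supported base agent.
type ResolvedOpenAI = AgentParams<typeof OpenAIAgent>; // = OpenAIAgentParams
type ResolvedAnthropic = AgentParams<typeof AnthropicAgent>; // = AnthropicAgentParams

// Compile-time equality checks, for illustration only.
type Expect<A, B> = A extends B ? (B extends A ? true : never) : never;
type _check1 = Expect<ResolvedOpenAI, OpenAIAgentParams>; // true
type _check2 = Expect<ResolvedAnthropic, AnthropicAgentParams>; // true
```

This is also why the mixin uses declare on chatHistory: the field only narrows the inherited chatHistory type to the concrete base agent's, without emitting a second class field at runtime.
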
18 changes: 5 additions & 13 deletions packages/llamaindex/src/agent/openai.ts
@@ -1,16 +1,8 @@
-import { OpenAIAgent, type OpenAIAgentParams } from "@llamaindex/openai";
-import {
-  withContextAwareness,
-  type ContextAwareConfig,
-} from "./contextAwareMixin.js";
+import { OpenAIAgent } from "@llamaindex/openai";
+import { withContextAwareness } from "./contextAwareMixin.js";

-export type OpenAIContextAwareAgentParams = OpenAIAgentParams &
-  ContextAwareConfig;
-
-export class OpenAIContextAwareAgent extends withContextAwareness(OpenAIAgent) {
-  constructor(params: OpenAIContextAwareAgentParams) {
-    super(params);
-  }
-}
+const ContextAwareOpenAIAgent = withContextAwareness(OpenAIAgent);
+export type { ContextAwareConfig } from "./contextAwareMixin.js";
+export { ContextAwareOpenAIAgent as OpenAIContextAwareAgent };

 export * from "@llamaindex/openai";
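
Finally, a hedged end-to-end sketch for the OpenAI variant. Whether the agent wires retrieval into chat() automatically is not visible in the truncated mixin diff above, so this sketch calls retrieveContext and injectContext explicitly; the package-root import, the query string, and the empty tools array are illustrative assumptions.

```ts
import { OpenAIContextAwareAgent } from "llamaindex";
import type { BaseRetriever } from "@llamaindex/core/retriever";

// Assumed to be built elsewhere, e.g. index.asRetriever().
declare const contextRetriever: BaseRetriever;

async function main() {
  const agent = new OpenAIContextAwareAgent({
    tools: [], // illustrative
    contextRetriever,
  });

  const query = "Summarize the retry policy described in the docs.";

  // Pull relevant node content and prepend it to the system message,
  // mirroring the mixin's retrieveContext/injectContext pair.
  const context = await agent.retrieveContext(query);
  await agent.injectContext(context);

  const response = await agent.chat({ message: query });
  console.log(response);
}

main().catch(console.error);
```
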
