Skip to content

Commit

Permalink
feat: AssistantRuntime.newThread (#335)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yonom authored Jun 27, 2024
1 parent 6f03b93 commit 62e9f19
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 24 deletions.
6 changes: 6 additions & 0 deletions .changeset/brown-wasps-lick.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
"@assistant-ui/react-ai-sdk": patch
"@assistant-ui/react": patch
---

feat: AssistantRuntime newThread
38 changes: 33 additions & 5 deletions packages/react-ai-sdk/src/rsc/VercelRSCRuntime.tsx
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"use client";

import {
type AppendMessage,
type AssistantRuntime,
INTERNAL,
type AppendMessage,
type ReactThreadRuntime,
type ThreadMessage,
type Unsubscribe,
Expand All @@ -13,13 +12,42 @@ import type { VercelRSCAdapter } from "./VercelRSCAdapter";
import type { VercelRSCMessage } from "./VercelRSCMessage";
import { useVercelRSCSync } from "./useVercelRSCSync";

const { ProxyConfigProvider } = INTERNAL;
const { ProxyConfigProvider, BaseAssistantRuntime } = INTERNAL;

const EMPTY_BRANCHES: readonly string[] = Object.freeze([]);

export class VercelRSCRuntime<T extends WeakKey = VercelRSCMessage>
export class VercelRSCRuntime<
T extends WeakKey = VercelRSCMessage,
> extends BaseAssistantRuntime<VercelRSCThreadRuntime<T>> {
constructor(adapter: VercelRSCAdapter<T>) {
super(new VercelRSCThreadRuntime(adapter));
}

public set adapter(adapter: VercelRSCAdapter<T>) {
this.thread.adapter = adapter;
}

public onAdapterUpdated() {
return this.thread.onAdapterUpdated();
}

public registerModelConfigProvider() {
// no-op
return () => {};
}

public newThread() {
this.thread = new VercelRSCThreadRuntime(this.thread.adapter);
}

public switchToThread() {
throw new Error("VercelRSCRuntime does not yet support switching threads");
}
}

class VercelRSCThreadRuntime<T extends WeakKey = VercelRSCMessage>
extends ProxyConfigProvider
implements AssistantRuntime, ReactThreadRuntime
implements ReactThreadRuntime
{
private useAdapter: UseBoundStore<StoreApi<{ adapter: VercelRSCAdapter<T> }>>;

Expand Down
40 changes: 32 additions & 8 deletions packages/react-ai-sdk/src/ui/VercelAIRuntime.tsx
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import type {
AssistantRuntime,
ReactThreadRuntime,
Unsubscribe,
} from "@assistant-ui/react";
import type { ReactThreadRuntime, Unsubscribe } from "@assistant-ui/react";
import type { AppendMessage, ThreadMessage } from "@assistant-ui/react";
import { INTERNAL } from "@assistant-ui/react";
import type { Message } from "ai";
Expand All @@ -13,15 +9,43 @@ import { sliceMessagesUntil } from "./utils/sliceMessagesUntil";
import { useVercelAIComposerSync } from "./utils/useVercelAIComposerSync";
import { useVercelAIThreadSync } from "./utils/useVercelAIThreadSync";

const { ProxyConfigProvider, MessageRepository } = INTERNAL;
const { ProxyConfigProvider, MessageRepository, BaseAssistantRuntime } =
INTERNAL;

const hasUpcomingMessage = (isRunning: boolean, messages: ThreadMessage[]) => {
return isRunning && messages[messages.length - 1]?.role !== "assistant";
};

export class VercelAIRuntime
export class VercelAIRuntime extends BaseAssistantRuntime<VercelAIThreadRuntime> {
constructor(vercel: VercelHelpers) {
super(new VercelAIThreadRuntime(vercel));
}

public set vercel(vercel: VercelHelpers) {
this.thread.vercel = vercel;
}

public onVercelUpdated() {
return this.thread.onVercelUpdated();
}

public registerModelConfigProvider() {
// no-op
return () => {};
}

public newThread() {
this.thread = new VercelAIThreadRuntime(this.thread.vercel);
}

public switchToThread() {
throw new Error("VercelAIRuntime does not yet support switching threads");
}
}

class VercelAIThreadRuntime
extends ProxyConfigProvider
implements AssistantRuntime, ReactThreadRuntime
implements ReactThreadRuntime
{
private _subscriptions = new Set<() => void>();
private repository = new MessageRepository();
Expand Down
1 change: 1 addition & 0 deletions packages/react/src/internal.ts
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
export { ProxyConfigProvider } from "./utils/ProxyConfigProvider";
export { MessageRepository } from "./runtime/utils/MessageRepository";
export { BaseAssistantRuntime } from "./runtime/core/BaseAssistantRuntime";
3 changes: 3 additions & 0 deletions packages/react/src/runtime/core/AssistantRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,8 @@ import type { Unsubscribe } from "../../types/Unsubscribe";
import type { ThreadRuntime } from "./ThreadRuntime";

export type AssistantRuntime = ThreadRuntime & {
newThread: () => void;
switchToThread: (threadId: string) => void;

registerModelConfigProvider: (provider: ModelConfigProvider) => Unsubscribe;
};
53 changes: 53 additions & 0 deletions packages/react/src/runtime/core/BaseAssistantRuntime.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import type { AppendMessage } from "../../types/AssistantTypes";
import { type ModelConfigProvider } from "../../types/ModelConfigTypes";
import type { Unsubscribe } from "../../types/Unsubscribe";
import type { AssistantRuntime } from "./AssistantRuntime";
import { ThreadRuntime } from "./ThreadRuntime";

export abstract class BaseAssistantRuntime<TThreadRuntime extends ThreadRuntime>
implements AssistantRuntime
{
constructor(protected thread: TThreadRuntime) {}

public abstract registerModelConfigProvider(
provider: ModelConfigProvider,
): Unsubscribe;
public abstract newThread(): void;
public abstract switchToThread(threadId: string): void;

public get messages() {
return this.thread.messages;
}

public get isRunning() {
return this.thread.isRunning;
}

public getBranches(messageId: string): readonly string[] {
return this.thread.getBranches(messageId);
}

public switchToBranch(branchId: string): void {
return this.thread.switchToBranch(branchId);
}

public append(message: AppendMessage): void {
return this.thread.append(message);
}

public startRun(parentId: string | null): void {
return this.thread.startRun(parentId);
}

public cancelRun(): void {
return this.thread.cancelRun();
}

public addToolResult(toolCallId: string, result: any) {
return this.thread.addToolResult(toolCallId, result);
}

public subscribe(callback: () => void): Unsubscribe {
return this.thread.subscribe(callback);
}
}
48 changes: 38 additions & 10 deletions packages/react/src/runtime/local/LocalRuntime.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,44 @@ import {
mergeModelConfigs,
} from "../../types/ModelConfigTypes";
import type { Unsubscribe } from "../../types/Unsubscribe";
import type { AssistantRuntime } from "../core/AssistantRuntime";
import { ThreadRuntime } from "../core";
import { MessageRepository } from "../utils/MessageRepository";
import { generateId } from "../utils/idUtils";
import { BaseAssistantRuntime } from "../core/BaseAssistantRuntime";
import type { ChatModelAdapter, ChatModelRunResult } from "./ChatModelAdapter";

export class LocalRuntime implements AssistantRuntime {
export class LocalRuntime extends BaseAssistantRuntime<LocalThreadRuntime> {
private readonly _configProviders: Set<ModelConfigProvider>;

constructor(adapter: ChatModelAdapter) {
const configProviders = new Set<ModelConfigProvider>();
super(new LocalThreadRuntime(configProviders, adapter));
this._configProviders = configProviders;
}

public set adapter(adapter: ChatModelAdapter) {
this.thread.adapter = adapter;
}

registerModelConfigProvider(provider: ModelConfigProvider) {
this._configProviders.add(provider);
return () => this._configProviders.delete(provider);
}

public newThread() {
return (this.thread = new LocalThreadRuntime(
this._configProviders,
this.thread.adapter,
));
}

public switchToThread() {
throw new Error("LocalRuntime does not yet support switching threads");
}
}

class LocalThreadRuntime implements ThreadRuntime {
private _subscriptions = new Set<() => void>();
private _configProviders = new Set<ModelConfigProvider>();

private abortController: AbortController | null = null;
private repository = new MessageRepository();
Expand All @@ -27,7 +57,10 @@ export class LocalRuntime implements AssistantRuntime {
return this.abortController != null;
}

constructor(public adapter: ChatModelAdapter) {}
constructor(
private _configProviders: Set<ModelConfigProvider>,
public adapter: ChatModelAdapter,
) {}

public getBranches(messageId: string): string[] {
return this.repository.getBranches(messageId);
Expand Down Expand Up @@ -117,12 +150,7 @@ export class LocalRuntime implements AssistantRuntime {
return () => this._subscriptions.delete(callback);
}

registerModelConfigProvider(provider: ModelConfigProvider) {
this._configProviders.add(provider);
return () => this._configProviders.delete(provider);
}

addToolResult() {
throw new Error("LocalRuntime does not yet support tool results");
throw new Error("LocalRuntime does not yet support adding tool results");
}
}
2 changes: 1 addition & 1 deletion packages/react/src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ export type {
ToolCallContentPartComponent,
} from "./ContentPartComponentTypes";

export type { ModelConfig } from "./ModelConfigTypes";
export type { ModelConfig, ModelConfigProvider } from "./ModelConfigTypes";

export type { Unsubscribe } from "./Unsubscribe";

0 comments on commit 62e9f19

Please sign in to comment.