Skip to content

Commit

Permalink
refactor: ThreadList (#1236)
Browse files Browse the repository at this point in the history
* refactor: ThreadList

* Update packages/react/src/runtimes/external-store/ExternalStoreRuntimeCore.tsx

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* cr fixes

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
Yonom and coderabbitai[bot] authored Dec 7, 2024
1 parent 70ea4a8 commit d33a293
Show file tree
Hide file tree
Showing 23 changed files with 392 additions and 409 deletions.
35 changes: 20 additions & 15 deletions packages/react-langgraph/src/useLangGraphRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { useLangGraphMessages } from "./useLangGraphMessages";
import { SimpleImageAttachmentAdapter } from "@assistant-ui/react";
import { AttachmentAdapter } from "@assistant-ui/react";
import { AppendMessage } from "@assistant-ui/react";
import { ExternalStoreAdapter } from "@assistant-ui/react";

const getPendingToolCalls = (messages: LangChainMessage[]) => {
const pendingToolCalls = new Map<string, LangChainToolCall>();
Expand Down Expand Up @@ -117,26 +118,30 @@ export const useLangGraphRuntime = ({
if (unstable_allowImageAttachments)
attachments = new SimpleImageAttachmentAdapter();

const threadList: NonNullable<
ExternalStoreAdapter["adapters"]
>["threadList"] = {
threadId,
onSwitchToNewThread: !onSwitchToNewThread
? undefined
: async () => {
await onSwitchToNewThread();
setMessages([]);
},
onSwitchToThread: !onSwitchToThread
? undefined
: async (threadId) => {
const { messages } = await onSwitchToThread(threadId);
setMessages(messages);
},
};

return useExternalStoreRuntime({
isRunning,
messages: threadMessages,
adapters: {
attachments,
threadList: {
threadId,
onSwitchToNewThread: !onSwitchToNewThread
? undefined
: async () => {
await onSwitchToNewThread();
setMessages([]);
},
onSwitchToThread: !onSwitchToThread
? undefined
: async (threadId) => {
const { messages } = await onSwitchToThread(threadId);
setMessages(messages);
},
},
threadList,
},
onNew: (msg) => {
const cancellations =
Expand Down
55 changes: 28 additions & 27 deletions packages/react-playground/src/lib/playground-runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import {
ThreadSuggestion,
ThreadRuntime,
AssistantRuntime,
ThreadMetadataRuntimeCore,
} from "@assistant-ui/react";
import { LanguageModelV1FunctionTool } from "@ai-sdk/provider";
import { useMemo, useState } from "react";
Expand All @@ -37,7 +36,6 @@ const {
DefaultThreadComposerRuntimeCore,
AssistantRuntimeImpl,
ThreadRuntimeImpl,
LocalThreadMetadataRuntimeCore,
} = INTERNAL;

const makeModelConfigStore = () =>
Expand All @@ -51,9 +49,7 @@ const makeModelConfigStore = () =>
config: {},
}));

type PlaygroundThreadFactory = (
threadId: string,
) => PlaygroundThreadRuntimeCore;
type PlaygroundThreadFactory = () => PlaygroundThreadRuntimeCore;

const EMPTY_ARRAY = [] as never[];

Expand All @@ -62,36 +58,46 @@ class PlaygroundThreadListRuntimeCore
{
private _mainThread: PlaygroundThreadRuntimeCore;

public get mainThread() {
return this._mainThread;
public get mainThreadId() {
return "default";
}

public get newThread() {
public get newThreadId() {
return undefined;
}

public get threads() {
public get threadIds() {
return EMPTY_ARRAY;
}

public get archivedThreads() {
public get archivedThreadIds() {
return EMPTY_ARRAY;
}

constructor(private threadFactory: PlaygroundThreadFactory) {
this._mainThread = this.threadFactory(generateId());
this._mainThread = this.threadFactory();
}

getThreadMetadataById(): never {
throw new Error("Method not implemented.");
public getMainThreadRuntimeCore() {
return this._mainThread;
}

public getItemById(id: string) {
if (id !== "default") return undefined;

return {
threadId: "default",
state: "regular",
runtime: this._mainThread,
} as const;
}

public switchToThread(): Promise<void> {
throw new Error("Method not implemented.");
}

public async switchToNewThread(): Promise<void> {
this._mainThread = this.threadFactory(generateId());
this._mainThread = this.threadFactory();
this.notifySubscribers();
}

Expand Down Expand Up @@ -129,10 +135,9 @@ class PlaygroundRuntimeCore extends BaseAssistantRuntimeCore {
) {
super();

this.threadList = new PlaygroundThreadListRuntimeCore((threadId) => {
this.threadList = new PlaygroundThreadListRuntimeCore(() => {
const thread = new PlaygroundThreadRuntimeCore(
this._proxyConfigProvider,
threadId,
fromCoreMessages(initialMessages),
adapter,
);
Expand Down Expand Up @@ -162,10 +167,8 @@ const CAPABILITIES = Object.freeze({
const EMPTY_BRANCHES: readonly string[] = Object.freeze([]);

export class PlaygroundThreadRuntimeCore implements INTERNAL.ThreadRuntimeCore {
private _metadata: ThreadMetadataRuntimeCore;

public get metadata() {
return this._metadata;
return { isMain: true, threadId: "default", state: "regular" } as const;
}

private _subscriptions = new Set<() => void>();
Expand Down Expand Up @@ -195,11 +198,9 @@ export class PlaygroundThreadRuntimeCore implements INTERNAL.ThreadRuntimeCore {

constructor(
configProvider: ModelConfigProvider,
threadId: string,
private _messages: ThreadMessage[],
public readonly adapter: ChatModelAdapter,
) {
this._metadata = new LocalThreadMetadataRuntimeCore(threadId);
this.configProvider.registerModelConfigProvider(configProvider);
this.configProvider.registerModelConfigProvider({
getModelConfig: () => this.useModelConfig.getState(),
Expand Down Expand Up @@ -625,8 +626,11 @@ class PlaygroundThreadRuntimeImpl
extends ThreadRuntimeImpl
implements PlaygroundThreadRuntime
{
constructor(private binding: INTERNAL.ThreadRuntimeCoreBinding) {
super(binding);
constructor(
private binding: INTERNAL.ThreadRuntimeCoreBinding,
threadListItemBinding: INTERNAL.ThreadListItemRuntimeBinding,
) {
super(binding, threadListItemBinding);
}

private _getState() {
Expand Down Expand Up @@ -688,10 +692,7 @@ class PlaygroundRuntimeImpl
public static override create(_core: PlaygroundRuntimeCore) {
return new PlaygroundRuntimeImpl(
_core,
AssistantRuntimeImpl.createMainThreadRuntime(
_core,
PlaygroundThreadRuntimeImpl,
),
PlaygroundThreadRuntimeImpl,
) as PlaygroundRuntime;
}
}
Expand Down
46 changes: 21 additions & 25 deletions packages/react/src/api/AssistantRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { AssistantRuntimeCore } from "../runtimes/core/AssistantRuntimeCore";
import { NestedSubscriptionSubject } from "./subscribable/NestedSubscriptionSubject";
import { ModelConfigProvider } from "../types/ModelConfigTypes";
import {
ThreadListItemRuntimeBinding,
ThreadRuntime,
ThreadRuntimeCoreBinding,
ThreadRuntimeImpl,
Expand Down Expand Up @@ -46,12 +47,27 @@ export class AssistantRuntimeImpl
AssistantRuntime
{
public readonly threadList;
public readonly _thread: ThreadRuntime;

protected constructor(
private readonly _core: AssistantRuntimeCore,
private readonly _thread: ThreadRuntime,
runtimeFactory: new (
binding: ThreadRuntimeCoreBinding,
threadListItemBinding: ThreadListItemRuntimeBinding,
) => ThreadRuntime = ThreadRuntimeImpl,
) {
this.threadList = new ThreadListRuntimeImpl(_core.threadList);
this._thread = new runtimeFactory(
new NestedSubscriptionSubject({
path: {
ref: "threads.main",
threadSelector: { type: "main" },
},
getState: () => _core.threadList.getMainThreadRuntimeCore(),
subscribe: (callback) => _core.threadList.subscribe(callback),
}),
this.threadList.mainThreadListItem, // TODO capture "main" threadListItem from context around useLocalRuntime / useExternalStoreRuntime
);
}

public get thread() {
Expand All @@ -70,33 +86,13 @@ export class AssistantRuntimeImpl
return this._core.registerModelConfigProvider(provider);
}

protected static createMainThreadRuntime(
_core: AssistantRuntimeCore,
CustomThreadRuntime: new (
binding: ThreadRuntimeCoreBinding,
) => ThreadRuntime = ThreadRuntimeImpl,
) {
return new CustomThreadRuntime(
new NestedSubscriptionSubject({
path: {
ref: "threads.main",
threadSelector: { type: "main" },
},
getState: () => _core.threadList.mainThread,
subscribe: (callback) => _core.threadList.subscribe(callback),
}),
);
}

public static create(
_core: AssistantRuntimeCore,
CustomThreadRuntime: new (
runtimeFactory: new (
binding: ThreadRuntimeCoreBinding,
threadListItemBinding: ThreadListItemRuntimeBinding,
) => ThreadRuntime = ThreadRuntimeImpl,
) {
return new AssistantRuntimeImpl(
_core,
AssistantRuntimeImpl.createMainThreadRuntime(_core, CustomThreadRuntime),
) as AssistantRuntime;
): AssistantRuntime {
return new AssistantRuntimeImpl(_core, runtimeFactory);
}
}
43 changes: 36 additions & 7 deletions packages/react/src/api/ThreadListItemRuntime.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import { ThreadMetadata } from "../runtimes/core/ThreadRuntimeCore";
import { Unsubscribe } from "../types";
import { ThreadListItemRuntimePath } from "./RuntimePathTypes";
import { SubscribableWithState } from "./subscribable/Subscribable";
import { ThreadListRuntimeCoreBinding } from "./ThreadListRuntime";

export type ThreadListItemState = ThreadMetadata & {
export type ThreadListItemEventType = "switched-to" | "switched-away";

export type ThreadListItemState = {
readonly isMain: boolean;

readonly id: string;

/**
* @deprecated This field was renamed to `id`. This field will be removed in 0.8.0.
*/
readonly threadId: string;

readonly state: "archived" | "regular" | "new" | "deleted";
readonly title?: string | undefined;
};

export type ThreadListItemRuntime = {
Expand All @@ -19,6 +30,11 @@ export type ThreadListItemRuntime = {
delete(): Promise<void>;

subscribe(callback: () => void): Unsubscribe;

unstable_on(
event: ThreadListItemEventType,
callback: () => void,
): Unsubscribe;
};

export type ThreadListItemStateBinding = SubscribableWithState<
Expand All @@ -42,31 +58,44 @@ export class ThreadListItemRuntimeImpl implements ThreadListItemRuntime {

public switchTo(): Promise<void> {
const state = this._core.getState();
return this._threadListBinding.switchToThread(state.threadId);
return this._threadListBinding.switchToThread(state.id);
}

public rename(newTitle: string): Promise<void> {
const state = this._core.getState();

return this._threadListBinding.rename(state.threadId, newTitle);
return this._threadListBinding.rename(state.id, newTitle);
}

public archive(): Promise<void> {
const state = this._core.getState();

return this._threadListBinding.archive(state.threadId);
return this._threadListBinding.archive(state.id);
}

public unarchive(): Promise<void> {
const state = this._core.getState();

return this._threadListBinding.unarchive(state.threadId);
return this._threadListBinding.unarchive(state.id);
}

public delete(): Promise<void> {
const state = this._core.getState();

return this._threadListBinding.delete(state.threadId);
return this._threadListBinding.delete(state.id);
}

public unstable_on(event: ThreadListItemEventType, callback: () => void) {
let prevIsMain = this._core.getState().isMain;
return this.subscribe(() => {
const newIsMain = this._core.getState().isMain;
if (prevIsMain === newIsMain) return;
prevIsMain = newIsMain;

if (event === "switched-to" && !newIsMain) return;
if (event === "switched-away" && newIsMain) return;
callback();
});
}

public subscribe(callback: () => void): Unsubscribe {
Expand Down
Loading

0 comments on commit d33a293

Please sign in to comment.