Skip to content

Commit

Permalink
Add polling to handle queued + multi-step requests.
Browse files Browse the repository at this point in the history
  • Loading branch information
lublagg committed Dec 11, 2024
1 parent b3ccd54 commit 3391f6d
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 87 deletions.
2 changes: 1 addition & 1 deletion src/app-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"keyboardShortcut": "ctrl+?"
},
"assistant": {
"assistantId": "asst_Af8jrKYOFP4MxA9nse61yFBq",
"assistantId": "asst_xmAX5oxByssXrkBymMbcsVEm",
"instructions": "You are DAVAI, a Data Analysis through Voice and Artificial Intelligence partner. You are an intermediary for a user who is blind who wants to interact with data tables in a data analysis app named CODAP.",
"modelName": "gpt-4o-mini",
"useExisting": true
Expand Down
2 changes: 1 addition & 1 deletion src/hooks/use-chat-transcript-store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export const useChatTranscriptStore = () => {
messages: [
{
speaker: DAVAI_SPEAKER,
content: GREETING,
messageContent: {content: GREETING},
timestamp: timeStamp(),
},
],
Expand Down
118 changes: 75 additions & 43 deletions src/models/assistant-model.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import { types, flow, Instance } from "mobx-state-tree";
import { Message } from "openai/resources/beta/threads/messages";
import { getAttributeList, getDataContext } from "../utils/codap-api-helpers";
import { codapInterface } from "@concord-consortium/codap-plugin-api";
import { DAVAI_SPEAKER, DEBUG_SPEAKER } from "../constants";
import { createGraph } from "../utils/codap-utils";
import { formatMessage } from "../utils/utils";
import { getTools, initLlmConnection } from "../utils/llm-utils";
import { ChatTranscriptModel } from "./chat-transcript-model";
Expand Down Expand Up @@ -102,67 +101,100 @@ export const AssistantModel = types
assistant_id: self.assistant.id,
});

// Wait for run completion and handle responses
let runState = yield self.apiConnection.beta.threads.runs.retrieve(self.thread.id, run.id);
while (runState.status !== "completed" && runState.status !== "requires_action") {
runState = yield self.apiConnection.beta.threads.runs.retrieve(self.thread.id, run.id);
}

if (runState.status === "requires_action") {
self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "User request requires action", content: formatMessage(runState)});
yield handleRequiredAction(runState, run.id);
}

const messages = yield self.apiConnection.beta.threads.messages.list(self.thread.id);
self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Updated thread messages list", content: formatMessage(messages)});

const lastMessageForRun = messages.data.filter(
(msg: Message) => msg.run_id === run.id && msg.role === "assistant"
).pop();

const lastMessageContent = lastMessageForRun?.content[0]?.text?.value;
if (lastMessageContent) {
self.transcriptStore.addMessage(DAVAI_SPEAKER, {content: lastMessageContent});
} else {
self.transcriptStore.addMessage(DAVAI_SPEAKER, {content: "I'm sorry, I don't have a response for that."});
self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "No content in last message", content: formatMessage(lastMessageForRun)});
}

yield pollRunState(run.id);
} catch (err) {
console.error("Failed to complete run:", err);
self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Failed to complete run", content: formatMessage(err)});
self.transcriptStore.addMessage(DEBUG_SPEAKER, {
description: "Failed to complete run",
content: formatMessage(err),
});
}
});

const pollRunState: (currentRunId: string) => Promise<any> = flow(function* (currentRunId) {
let runState = yield self.apiConnection.beta.threads.runs.retrieve(self.thread.id, currentRunId);
self.transcriptStore.addMessage(DEBUG_SPEAKER, {
description: "Run state status",
content: formatMessage(runState.status),
});

const errorStates = ["failed", "cancelled", "incomplete"];

while (runState.status !== "completed" && runState.status !== "requires_action" && !errorStates.includes(runState.status)) {
yield new Promise((resolve) => setTimeout(resolve, 2000));
runState = yield self.apiConnection.beta.threads.runs.retrieve(self.thread.id, currentRunId);
self.transcriptStore.addMessage(DEBUG_SPEAKER, {
description: "Run state status",
content: formatMessage(runState.status),
});
}

if (errorStates.includes(runState.status)) {
self.transcriptStore.addMessage(DEBUG_SPEAKER, {
description: "Run failed",
content: formatMessage(runState),
});
self.transcriptStore.addMessage(DAVAI_SPEAKER, {
content: "I'm sorry, I encountered an error. Please try again.",
});
}

if (runState.status === "requires_action") {
self.transcriptStore.addMessage(DEBUG_SPEAKER, {
description: "Run requires action",
content: formatMessage(runState),
});
yield handleRequiredAction(runState, currentRunId);
yield pollRunState(currentRunId);
}

if (runState.status === "completed") {
const messages = yield self.apiConnection.beta.threads.messages.list(self.thread.id);

const lastMessageForRun = messages.data
.filter((msg: Message) => msg.run_id === currentRunId && msg.role === "assistant")
.pop();

self.transcriptStore.addMessage(DEBUG_SPEAKER, {
description: "Run completed, assistant response",
content: formatMessage(lastMessageForRun),
});

const lastMessageContent = lastMessageForRun?.content[0]?.text?.value;
if (lastMessageContent) {
self.transcriptStore.addMessage(DAVAI_SPEAKER, { content: lastMessageContent });
} else {
self.transcriptStore.addMessage(DAVAI_SPEAKER, {
content: "I'm sorry, I don't have a response for that.",
});
}
}
});

const handleRequiredAction = flow(function* (runState, runId) {
try {
const toolOutputs = runState.required_action?.submit_tool_outputs.tool_calls
? yield Promise.all(
runState.required_action.submit_tool_outputs.tool_calls.map(flow(function* (toolCall: any) {
if (toolCall.function.name === "get_attributes") {
const { dataset } = JSON.parse(toolCall.function.arguments);
const rootCollection = (yield getDataContext(dataset)).values.collections[0];
const attributeListRes = yield getAttributeList(dataset, rootCollection.name);
const { requestMessage, ...codapResponse } = attributeListRes;
self.transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Request sent to CODAP", content: formatMessage(requestMessage) });
self.transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Response from CODAP", content: formatMessage(codapResponse) });
return { tool_call_id: toolCall.id, output: JSON.stringify(attributeListRes) };
if (toolCall.function.name === "create_request") {
const { action, resource, values } = JSON.parse(toolCall.function.arguments);
const request = { action, resource, values };
self.transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Request sent to CODAP", content: formatMessage(request) });
const res = yield codapInterface.sendRequest(request);
self.transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Response from CODAP", content: formatMessage(res) });
return { tool_call_id: toolCall.id, output: JSON.stringify(res) };
} else {
const { dataset, name, xAttribute, yAttribute } = JSON.parse(toolCall.function.arguments);
const { requestMessage, ...codapResponse} = yield createGraph(dataset, name, xAttribute, yAttribute);
self.transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Request sent to CODAP", content: formatMessage(requestMessage) });
self.transcriptStore.addMessage(DEBUG_SPEAKER, { description: "Response from CODAP", content: formatMessage(codapResponse) });
return { tool_call_id: toolCall.id, output: "Graph created." };
return { tool_call_id: toolCall.id, output: "Tool call not recognized." };
}
})
))
: [];

self.transcriptStore.addMessage(DEBUG_SPEAKER, {description: "Tool outputs", content: formatMessage(toolOutputs)});
if (toolOutputs) {
yield self.apiConnection.beta.threads.runs.submitToolOutputs(
self.thread.id, runId, { tool_outputs: toolOutputs }
);

}
} catch (err) {
console.error(err);
Expand Down
2 changes: 1 addition & 1 deletion src/models/chat-transcript-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { MessageContent } from "../types";

const MessageModel = types.model("MessageModel", {
speaker: types.string,
messageContent: types.frozen(),
messageContent: types.frozen<MessageContent>(),
timestamp: types.string,
});

Expand Down
55 changes: 14 additions & 41 deletions src/utils/openai-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,60 +15,33 @@ export const openAiTools: AssistantTool[] = [
{
type: "function",
function: {
name: "get_attributes",
description: "Get a list of all attributes in a dataset",
strict: true,
name: "create_request",
description: "Create a request to get data from CODAP",
strict: false,
parameters: {
type: "object",
properties: {
dataset: {
action: {
type: "string",
description: "The specified dataset containing attributes"
}
},
additionalProperties: false,
required: [
"dataset"
]
}
}
},
{
type: "function",
function: {
name: "create_graph",
description: "Create a graph tile in CODAP",
strict: true,
parameters: {
type: "object",
properties: {
dataset: {
type: "string",
description: "The name of the dataset to which the attributes belong"
},
name: {
type: "string",
description: "A name for the graph"
},
xAttribute: {
type: "string",
description: "The x-axis attribute"
description: "The action to perform"
},
yAttribute: {
resource: {
type: "string",
description: "The y-axis attribute"
description: "The resource to act upon"
},
values: {
type: "object",
description: "The values to pass to the action"
}
},
additionalProperties: false,
required: [
"dataset",
"name",
"xAttribute",
"yAttribute"
"action",
"resource"
]
}
}
}
},
];

export const requestThreadDeletion = async (threadId: string): Promise<Response> => {
Expand Down

0 comments on commit 3391f6d

Please sign in to comment.