Skip to content

Commit

Permalink
feat(OpenAI Node): Use v2 assistants API and add support for memory (#…
Browse files Browse the repository at this point in the history
…9406)

Signed-off-by: Oleg Ivaniv <me@olegivaniv.com>
  • Loading branch information
OlegIvaniv authored May 16, 2024
1 parent 40bce7f commit ce3eb12
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,24 @@ const properties: INodeProperties[] = [
type: 'collection',
default: {},
options: [
{
displayName: 'Output Randomness (Temperature)',
name: 'temperature',
default: 1,
typeOptions: { maxValue: 1, minValue: 0, numberPrecision: 1 },
description:
'Controls randomness: Lowering results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. We generally recommend altering this or temperature but not both.',
type: 'number',
},
{
displayName: 'Output Randomness (Top P)',
name: 'topP',
default: 1,
typeOptions: { maxValue: 1, minValue: 0, numberPrecision: 1 },
description:
'An alternative to sampling with temperature, controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered. We generally recommend altering this or temperature but not both.',
type: 'number',
},
{
displayName: 'Fail if Assistant Already Exists',
name: 'failIfExists',
Expand Down Expand Up @@ -176,7 +194,7 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
do {
const response = (await apiRequest.call(this, 'GET', '/assistants', {
headers: {
'OpenAI-Beta': 'assistants=v1',
'OpenAI-Beta': 'assistants=v2',
},
qs: {
limit: 100,
Expand Down Expand Up @@ -219,7 +237,6 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
name,
description: assistantDescription,
instructions,
file_ids,
};

const tools = [];
Expand All @@ -228,12 +245,28 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
tools.push({
type: 'code_interpreter',
});
body.tool_resources = {
...((body.tool_resources as object) ?? {}),
code_interpreter: {
file_ids,
},
};
}

if (knowledgeRetrieval) {
tools.push({
type: 'retrieval',
type: 'file_search',
});
body.tool_resources = {
...((body.tool_resources as object) ?? {}),
file_search: {
vector_stores: [
{
file_ids,
},
],
},
};
}

if (tools.length) {
Expand All @@ -243,7 +276,7 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
const response = await apiRequest.call(this, 'POST', '/assistants', {
body,
headers: {
'OpenAI-Beta': 'assistants=v1',
'OpenAI-Beta': 'assistants=v2',
},
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode

const response = await apiRequest.call(this, 'DELETE', `/assistants/${assistantId}`, {
headers: {
'OpenAI-Beta': 'assistants=v1',
'OpenAI-Beta': 'assistants=v2',
},
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
do {
const response = await apiRequest.call(this, 'GET', '/assistants', {
headers: {
'OpenAI-Beta': 'assistants=v1',
'OpenAI-Beta': 'assistants=v2',
},
qs: {
limit: 100,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,17 @@ import { OpenAIAssistantRunnable } from 'langchain/experimental/openai_assistant
import type { OpenAIToolType } from 'langchain/dist/experimental/openai_assistant/schema';
import { OpenAI as OpenAIClient } from 'openai';

import { NodeOperationError, updateDisplayOptions } from 'n8n-workflow';
import type { IExecuteFunctions, INodeExecutionData, INodeProperties } from 'n8n-workflow';

import { NodeConnectionType, NodeOperationError, updateDisplayOptions } from 'n8n-workflow';
import type {
IDataObject,
IExecuteFunctions,
INodeExecutionData,
INodeProperties,
} from 'n8n-workflow';

import type { BufferWindowMemory } from 'langchain/memory';
import omit from 'lodash/omit';
import type { BaseMessage } from '@langchain/core/messages';
import { formatToOpenAIAssistantTool } from '../../helpers/utils';
import { assistantRLC } from '../descriptions';

Expand Down Expand Up @@ -110,6 +118,12 @@ const displayOptions = {
};

export const description = updateDisplayOptions(displayOptions, properties);
const mapChatMessageToThreadMessage = (
message: BaseMessage,
): OpenAIClient.Beta.Threads.ThreadCreateParams.Message => ({
role: message._getType() === 'ai' ? 'assistant' : 'user',
content: message.content.toString(),
});

export async function execute(this: IExecuteFunctions, i: number): Promise<INodeExecutionData[]> {
const credentials = await this.getCredentials('openAiApi');
Expand Down Expand Up @@ -182,11 +196,47 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
tools: tools ?? [],
});

const response = await agentExecutor.withConfig(getTracingConfig(this)).invoke({
const memory = (await this.getInputConnectionData(NodeConnectionType.AiMemory, 0)) as
| BufferWindowMemory
| undefined;

const chainValues: IDataObject = {
content: input,
signal: this.getExecutionCancelSignal(),
timeout: options.timeout ?? 10000,
});
};
let thread: OpenAIClient.Beta.Threads.Thread;
if (memory) {
const chatMessages = await memory.chatHistory.getMessages();

// Construct a new thread from the chat history to map the memory
if (chatMessages.length) {
const first32Messages = chatMessages.slice(0, 32);
// There is a undocumented limit of 32 messages per thread when creating a thread with messages
const mappedMessages: OpenAIClient.Beta.Threads.ThreadCreateParams.Message[] =
first32Messages.map(mapChatMessageToThreadMessage);

thread = await client.beta.threads.create({ messages: mappedMessages });
const overLimitMessages = chatMessages.slice(32).map(mapChatMessageToThreadMessage);

// Send the remaining messages that exceed the limit of 32 sequentially
for (const message of overLimitMessages) {
await client.beta.threads.messages.create(thread.id, message);
}

chainValues.threadId = thread.id;
}
}

const response = await agentExecutor.withConfig(getTracingConfig(this)).invoke(chainValues);
if (memory) {
await memory.saveContext({ input }, { output: response.output });

if (response.threadId && response.runId) {
const threadRun = await client.beta.threads.runs.retrieve(response.threadId, response.runId);
response.usage = threadRun.usage;
}
}

if (
options.preserveOriginalTools !== false &&
Expand All @@ -197,6 +247,6 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
tools: assistantTools,
});
}

return [{ json: response, pairedItem: { item: i } }];
const filteredResponse = omit(response, ['signal', 'timeout']);
return [{ json: filteredResponse, pairedItem: { item: i } }];
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,25 @@ const properties: INodeProperties[] = [
default: false,
description: 'Whether to remove all custom tools (functions) from the assistant',
},

{
displayName: 'Output Randomness (Temperature)',
name: 'temperature',
default: 1,
typeOptions: { maxValue: 1, minValue: 0, numberPrecision: 1 },
description:
'Controls randomness: Lowering results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. We generally recommend altering this or temperature but not both.',
type: 'number',
},
{
displayName: 'Output Randomness (Top P)',
name: 'topP',
default: 1,
typeOptions: { maxValue: 1, minValue: 0, numberPrecision: 1 },
description:
'An alternative to sampling with temperature, controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered. We generally recommend altering this or temperature but not both.',
type: 'number',
},
],
},
];
Expand All @@ -109,6 +128,8 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
knowledgeRetrieval,
file_ids,
removeCustomTools,
temperature,
topP,
} = options;

const assistantDescription = options.description as string;
Expand All @@ -128,7 +149,19 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
);
}

body.file_ids = files;
body.tool_resources = {
...((body.tool_resources as object) ?? {}),
code_interpreter: {
file_ids,
},
file_search: {
vector_stores: [
{
file_ids,
},
],
},
};
}

if (modelId) {
Expand All @@ -147,11 +180,19 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
body.instructions = instructions;
}

if (temperature) {
body.temperature = temperature;
}

if (topP) {
body.topP = topP;
}

let tools =
((
await apiRequest.call(this, 'GET', `/assistants/${assistantId}`, {
headers: {
'OpenAI-Beta': 'assistants=v1',
'OpenAI-Beta': 'assistants=v2',
},
})
).tools as IDataObject[]) || [];
Expand All @@ -166,14 +207,14 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
tools = tools.filter((tool) => tool.type !== 'code_interpreter');
}

if (knowledgeRetrieval && !tools.find((tool) => tool.type === 'retrieval')) {
if (knowledgeRetrieval && !tools.find((tool) => tool.type === 'file_search')) {
tools.push({
type: 'retrieval',
type: 'file_search',
});
}

if (knowledgeRetrieval === false && tools.find((tool) => tool.type === 'retrieval')) {
tools = tools.filter((tool) => tool.type !== 'retrieval');
if (knowledgeRetrieval === false && tools.find((tool) => tool.type === 'file_search')) {
tools = tools.filter((tool) => tool.type !== 'file_search');
}

if (removeCustomTools) {
Expand All @@ -185,7 +226,7 @@ export async function execute(this: IExecuteFunctions, i: number): Promise<INode
const response = await apiRequest.call(this, 'POST', `/assistants/${assistantId}`, {
body,
headers: {
'OpenAI-Beta': 'assistants=v1',
'OpenAI-Beta': 'assistants=v2',
},
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ const configureNodeInputs = (resource: string, operation: string, hideTools: str
if (resource === 'assistant' && operation === 'message') {
return [
{ type: NodeConnectionType.Main },
{ type: NodeConnectionType.AiMemory, displayName: 'Memory', maxConnections: 1 },
{ type: NodeConnectionType.AiTool, displayName: 'Tools' },
];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ export async function assistantSearch(
): Promise<INodeListSearchResult> {
const { data, has_more, last_id } = await apiRequest.call(this, 'GET', '/assistants', {
headers: {
'OpenAI-Beta': 'assistants=v1',
'OpenAI-Beta': 'assistants=v2',
},
qs: {
limit: 100,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,24 @@ describe('OpenAi, Assistant resource', () => {
expect(transport.apiRequest).toHaveBeenCalledWith('POST', '/assistants', {
body: {
description: 'description',
file_ids: [],
instructions: 'some instructions',
model: 'gpt-model',
name: 'name',
tools: [{ type: 'code_interpreter' }, { type: 'retrieval' }],
tool_resources: {
code_interpreter: {
file_ids: [],
},
file_search: {
vector_stores: [
{
file_ids: [],
},
],
},
},
tools: [{ type: 'code_interpreter' }, { type: 'file_search' }],
},
headers: { 'OpenAI-Beta': 'assistants=v1' },
headers: { 'OpenAI-Beta': 'assistants=v2' },
});
});

Expand Down Expand Up @@ -124,7 +135,7 @@ describe('OpenAi, Assistant resource', () => {
);

expect(transport.apiRequest).toHaveBeenCalledWith('DELETE', '/assistants/assistant-id', {
headers: { 'OpenAI-Beta': 'assistants=v1' },
headers: { 'OpenAI-Beta': 'assistants=v2' },
});
});

Expand Down Expand Up @@ -185,17 +196,28 @@ describe('OpenAi, Assistant resource', () => {

expect(transport.apiRequest).toHaveBeenCalledTimes(2);
expect(transport.apiRequest).toHaveBeenCalledWith('GET', '/assistants/assistant-id', {
headers: { 'OpenAI-Beta': 'assistants=v1' },
headers: { 'OpenAI-Beta': 'assistants=v2' },
});
expect(transport.apiRequest).toHaveBeenCalledWith('POST', '/assistants/assistant-id', {
body: {
file_ids: [],
instructions: 'some instructions',
model: 'gpt-model',
name: 'name',
tools: [{ type: 'existing_tool' }, { type: 'code_interpreter' }, { type: 'retrieval' }],
tool_resources: {
code_interpreter: {
file_ids: [],
},
file_search: {
vector_stores: [
{
file_ids: [],
},
],
},
},
tools: [{ type: 'existing_tool' }, { type: 'code_interpreter' }, { type: 'file_search' }],
},
headers: { 'OpenAI-Beta': 'assistants=v1' },
headers: { 'OpenAI-Beta': 'assistants=v2' },
});
});
});
Expand Down

0 comments on commit ce3eb12

Please sign in to comment.