-
Notifications
You must be signed in to change notification settings - Fork 8.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(Text Classifier Node): Add Text Classifier Node (#9997)
Co-authored-by: oleg <me@olegivaniv.com>
- Loading branch information
1 parent
4a3b97c
commit 28ca7d6
Showing
3 changed files
with
226 additions
and
0 deletions.
There are no files selected for viewing
223 changes: 223 additions & 0 deletions
223
packages/@n8n/nodes-langchain/nodes/chains/TextClassifier/TextClassifier.node.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
import type { | ||
IDataObject, | ||
IExecuteFunctions, | ||
INodeExecutionData, | ||
INodeParameters, | ||
INodeType, | ||
INodeTypeDescription, | ||
} from 'n8n-workflow'; | ||
|
||
import { NodeConnectionType } from 'n8n-workflow'; | ||
|
||
import type { BaseLanguageModel } from '@langchain/core/language_models/base'; | ||
import { HumanMessage } from '@langchain/core/messages'; | ||
import { SystemMessagePromptTemplate, ChatPromptTemplate } from '@langchain/core/prompts'; | ||
import { StructuredOutputParser } from 'langchain/output_parsers'; | ||
import { z } from 'zod'; | ||
import { getTracingConfig } from '../../../utils/tracing'; | ||
|
||
const SYSTEM_PROMPT_TEMPLATE = | ||
"Please classify the text provided by the user into one of the following categories: {categories}, and use the provided formatting instructions below. Don't explain, and only output the json."; | ||
|
||
const configuredOutputs = (parameters: INodeParameters) => { | ||
const categories = ((parameters.categories as IDataObject)?.categories as IDataObject[]) ?? []; | ||
const fallback = (parameters.options as IDataObject)?.fallback as boolean; | ||
const ret = categories.map((cat) => { | ||
return { type: NodeConnectionType.Main, displayName: cat.category }; | ||
}); | ||
if (fallback) ret.push({ type: NodeConnectionType.Main, displayName: 'Other' }); | ||
return ret; | ||
}; | ||
|
||
export class TextClassifier implements INodeType { | ||
description: INodeTypeDescription = { | ||
displayName: 'Text Classifier', | ||
name: 'textClassifier', | ||
icon: 'fa:tags', | ||
group: ['transform'], | ||
version: 1, | ||
description: 'Classify your text into distinct categories', | ||
codex: { | ||
categories: ['AI'], | ||
subcategories: { | ||
AI: ['Chains', 'Root Nodes'], | ||
}, | ||
resources: { | ||
primaryDocumentation: [ | ||
{ | ||
url: 'https://docs.n8n.io/integrations/builtin/cluster-nodes/root-nodes/n8n-nodes-langchain.chainllm/', | ||
}, | ||
], | ||
}, | ||
}, | ||
defaults: { | ||
name: 'Text Classifier', | ||
}, | ||
inputs: [ | ||
{ displayName: '', type: NodeConnectionType.Main }, | ||
{ | ||
displayName: 'Model', | ||
maxConnections: 1, | ||
type: NodeConnectionType.AiLanguageModel, | ||
required: true, | ||
}, | ||
], | ||
outputs: `={{(${configuredOutputs})($parameter)}}`, | ||
properties: [ | ||
{ | ||
displayName: 'Text to Classify', | ||
name: 'inputText', | ||
type: 'string', | ||
required: true, | ||
default: '', | ||
description: 'Use an expression to reference data in previous nodes or enter static text', | ||
typeOptions: { | ||
rows: 2, | ||
}, | ||
}, | ||
{ | ||
displayName: 'Categories', | ||
name: 'categories', | ||
placeholder: 'Add Category', | ||
type: 'fixedCollection', | ||
default: {}, | ||
typeOptions: { | ||
multipleValues: true, | ||
}, | ||
options: [ | ||
{ | ||
name: 'categories', | ||
displayName: 'Categories', | ||
values: [ | ||
{ | ||
displayName: 'Category', | ||
name: 'category', | ||
type: 'string', | ||
default: '', | ||
description: 'Category to add', | ||
required: true, | ||
}, | ||
{ | ||
displayName: 'Description', | ||
name: 'description', | ||
type: 'string', | ||
default: '', | ||
description: "Describe your category if it's not obvious", | ||
}, | ||
], | ||
}, | ||
], | ||
}, | ||
{ | ||
displayName: 'Options', | ||
name: 'options', | ||
type: 'collection', | ||
default: {}, | ||
placeholder: 'Add Option', | ||
options: [ | ||
{ | ||
displayName: 'Allow Multiple Classes To Be True', | ||
name: 'multiClass', | ||
type: 'boolean', | ||
default: false, | ||
}, | ||
{ | ||
displayName: 'Add Fallback Option', | ||
name: 'fallback', | ||
type: 'boolean', | ||
default: false, | ||
description: 'Whether to add a "fallback" option if no other categories match', | ||
}, | ||
{ | ||
displayName: 'System Prompt Template', | ||
name: 'systemPromptTemplate', | ||
type: 'string', | ||
default: SYSTEM_PROMPT_TEMPLATE, | ||
description: 'String to use directly as the system prompt template', | ||
typeOptions: { | ||
rows: 6, | ||
}, | ||
}, | ||
], | ||
}, | ||
], | ||
}; | ||
|
||
async execute(this: IExecuteFunctions): Promise<INodeExecutionData[][]> { | ||
const items = this.getInputData(); | ||
|
||
const llm = (await this.getInputConnectionData( | ||
NodeConnectionType.AiLanguageModel, | ||
0, | ||
)) as BaseLanguageModel; | ||
|
||
const categories = this.getNodeParameter('categories.categories', 0) as Array<{ | ||
category: string; | ||
description: string; | ||
}>; | ||
|
||
const options = this.getNodeParameter('options', 0, {}) as { | ||
multiClass: boolean; | ||
fallback: boolean; | ||
systemPromptTemplate?: string; | ||
}; | ||
const multiClass = options?.multiClass ?? false; | ||
const fallback = options?.fallback ?? false; | ||
|
||
const schemaEntries = categories.map((cat) => [ | ||
cat.category, | ||
z | ||
.boolean() | ||
.describe( | ||
`Should be true if the input has category "${cat.category}" (description: ${cat.description})`, | ||
), | ||
]); | ||
if (fallback) | ||
schemaEntries.push([ | ||
'fallback', | ||
z.boolean().describe('Should be true if none of the other categories apply'), | ||
]); | ||
const schema = z.object(Object.fromEntries(schemaEntries)); | ||
|
||
const parser = StructuredOutputParser.fromZodSchema(schema); | ||
|
||
const multiClassPrompt = multiClass | ||
? 'Categories are not mutually exclusive, and multiple can be true' | ||
: 'Categories are mutually exclusive, and only one can be true'; | ||
const fallbackPrompt = fallback | ||
? 'If no categories apply, select the "fallback" option.' | ||
: 'One of the options must always be true.'; | ||
|
||
const systemPromptTemplate = SystemMessagePromptTemplate.fromTemplate( | ||
`${options.systemPromptTemplate ?? SYSTEM_PROMPT_TEMPLATE} | ||
{format_instructions} | ||
${multiClassPrompt} | ||
${fallbackPrompt}`, | ||
); | ||
|
||
const returnData: INodeExecutionData[][] = Array.from( | ||
{ length: categories.length + (fallback ? 1 : 0) }, | ||
(_) => [], | ||
); | ||
for (let itemIdx = 0; itemIdx < items.length; itemIdx++) { | ||
const input = this.getNodeParameter('inputText', itemIdx) as string; | ||
const inputPrompt = new HumanMessage(input); | ||
const messages = [ | ||
await systemPromptTemplate.format({ | ||
categories: categories.map((cat) => cat.category).join(', '), | ||
format_instructions: parser.getFormatInstructions(), | ||
}), | ||
inputPrompt, | ||
]; | ||
const prompt = ChatPromptTemplate.fromMessages(messages); | ||
const chain = prompt.pipe(llm).pipe(parser).withConfig(getTracingConfig(this)); | ||
|
||
const output = await chain.invoke(messages); | ||
categories.forEach((cat, idx) => { | ||
if (output[cat.category]) returnData[idx].push(items[itemIdx]); | ||
}); | ||
if (fallback && output.fallback) returnData[returnData.length - 1].push(items[itemIdx]); | ||
} | ||
return returnData; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters