Bugfix/HF custom endpoint #2811

Merged: 1 commit, Jul 16, 2024
@@ -18,7 +18,7 @@ class ChatHuggingFace_ChatModels implements INode {
constructor() {
this.label = 'ChatHuggingFace'
this.name = 'chatHuggingFace'
- this.version = 2.0
+ this.version = 3.0
this.type = 'ChatHuggingFace'
this.icon = 'HuggingFace.svg'
this.category = 'Chat Models'
@@ -96,6 +96,16 @@ class ChatHuggingFace_ChatModels implements INode {
description: 'Frequency Penalty parameter may not apply to certain model. Please check available model parameters',
optional: true,
additionalParams: true
},
{
label: 'Stop Sequence',
name: 'stop',
type: 'string',
rows: 4,
placeholder: 'AI assistant:',
description: 'Sets the stop sequences to use. Use comma to separate different sequences.',
optional: true,
additionalParams: true
}
]
}
@@ -109,6 +119,7 @@ class ChatHuggingFace_ChatModels implements INode {
const frequencyPenalty = nodeData.inputs?.frequencyPenalty as string
const endpoint = nodeData.inputs?.endpoint as string
const cache = nodeData.inputs?.cache as BaseCache
const stop = nodeData.inputs?.stop as string

const credentialData = await getCredentialData(nodeData.credential ?? '', options)
const huggingFaceApiKey = getCredentialParam('huggingFaceApiKey', credentialData, nodeData)
@@ -123,7 +134,11 @@ class ChatHuggingFace_ChatModels implements INode {
if (topP) obj.topP = parseFloat(topP)
if (hfTopK) obj.topK = parseFloat(hfTopK)
if (frequencyPenalty) obj.frequencyPenalty = parseFloat(frequencyPenalty)
- if (endpoint) obj.endpoint = endpoint
+ if (endpoint) obj.endpointUrl = endpoint
+ if (stop) {
+ const stopSequences = stop.split(',')
+ obj.stopSequences = stopSequences
+ }

const huggingFace = new HuggingFaceInference(obj)
if (cache) huggingFace.cache = cache
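Taken together, the node-level changes mean a custom Hugging Face endpoint is now forwarded under the `endpointUrl` key and the comma-separated Stop Sequence string becomes a `stopSequences` array. A minimal standalone sketch of that mapping in TypeScript (the input values are illustrative stand-ins, not part of the PR):

// Illustrative stand-ins for what the node reads from nodeData.inputs
const endpoint = 'https://my-endpoint.endpoints.huggingface.cloud'
const stop = 'AI assistant:,Human:'

const obj: { endpointUrl?: string; stopSequences?: string[] } = {}

// Custom endpoints now land on `endpointUrl` instead of `endpoint`
if (endpoint) obj.endpointUrl = endpoint

// The comma-separated stop string is split into individual stop sequences
if (stop) {
    obj.stopSequences = stop.split(',')
}

console.log(obj)
// { endpointUrl: 'https://my-endpoint...', stopSequences: [ 'AI assistant:', 'Human:' ] }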
packages/components/nodes/chatmodels/ChatHuggingFace/core.ts (91 changes: 59 additions, 32 deletions)
@@ -1,32 +1,19 @@
import { LLM, BaseLLMParams } from '@langchain/core/language_models/llms'
import { getEnvironmentVariable } from '../../../src/utils'
import { GenerationChunk } from '@langchain/core/outputs'
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'

export interface HFInput {
/** Model to use */
model: string

/** Sampling temperature to use */
temperature?: number

/**
* Maximum number of tokens to generate in the completion.
*/
maxTokens?: number

/** Total probability mass of tokens to consider at each step */
+ stopSequences?: string[]
topP?: number

/** Integer to define the top tokens considered within the sample operation to create new text. */
topK?: number

/** Penalizes repeated tokens according to frequency */
frequencyPenalty?: number

/** API key to use. */
apiKey?: string

/** Private endpoint to use. */
- endpoint?: string
+ endpointUrl?: string
+ includeCredentials?: string | boolean
}
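With these interface changes, callers configure the private endpoint through `endpointUrl` (formerly `endpoint`) and can pass `stopSequences` and `includeCredentials` directly. A hedged construction sketch, using placeholder model, token, and URL values:

const llm = new HuggingFaceInference({
    model: 'mistralai/Mistral-7B-Instruct-v0.2',                    // placeholder model id
    apiKey: 'hf_xxx',                                               // placeholder token
    endpointUrl: 'https://my-endpoint.endpoints.huggingface.cloud', // formerly `endpoint`
    stopSequences: ['AI assistant:'],
    includeCredentials: false
})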

export class HuggingFaceInference extends LLM implements HFInput {
@@ -40,6 +27,8 @@ export class HuggingFaceInference extends LLM implements HFInput {

temperature: number | undefined = undefined

stopSequences: string[] | undefined = undefined

maxTokens: number | undefined = undefined

topP: number | undefined = undefined
@@ -50,19 +39,23 @@ export class HuggingFaceInference extends LLM implements HFInput {

apiKey: string | undefined = undefined

- endpoint: string | undefined = undefined
+ endpointUrl: string | undefined = undefined

includeCredentials: string | boolean | undefined = undefined

constructor(fields?: Partial<HFInput> & BaseLLMParams) {
super(fields ?? {})

this.model = fields?.model ?? this.model
this.temperature = fields?.temperature ?? this.temperature
this.maxTokens = fields?.maxTokens ?? this.maxTokens
this.stopSequences = fields?.stopSequences ?? this.stopSequences
this.topP = fields?.topP ?? this.topP
this.topK = fields?.topK ?? this.topK
this.frequencyPenalty = fields?.frequencyPenalty ?? this.frequencyPenalty
- this.endpoint = fields?.endpoint ?? ''
this.apiKey = fields?.apiKey ?? getEnvironmentVariable('HUGGINGFACEHUB_API_KEY')
+ this.endpointUrl = fields?.endpointUrl
+ this.includeCredentials = fields?.includeCredentials
if (!this.apiKey) {
throw new Error(
'Please set an API key for HuggingFace Hub in the environment variable HUGGINGFACEHUB_API_KEY or in the apiKey field of the HuggingFaceInference constructor.'
@@ -74,31 +67,65 @@ export class HuggingFaceInference extends LLM implements HFInput {
return 'hf'
}

- /** @ignore */
- async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
- const { HfInference } = await HuggingFaceInference.imports()
- const hf = new HfInference(this.apiKey)
- const obj: any = {
+ invocationParams(options?: this['ParsedCallOptions']) {
+ return {
model: this.model,
parameters: {
// make it behave similar to openai, returning only the generated text
return_full_text: false,
temperature: this.temperature,
max_new_tokens: this.maxTokens,
stop: options?.stop ?? this.stopSequences,
top_p: this.topP,
top_k: this.topK,
repetition_penalty: this.frequencyPenalty
},
- inputs: prompt
}
}
- if (this.endpoint) {
- hf.endpoint(this.endpoint)
- } else {
- obj.model = this.model
- }

async *_streamResponseChunks(
prompt: string,
options: this['ParsedCallOptions'],
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<GenerationChunk> {
const hfi = await this._prepareHFInference()
const stream = await this.caller.call(async () =>
hfi.textGenerationStream({
...this.invocationParams(options),
inputs: prompt
})
)
for await (const chunk of stream) {
const token = chunk.token.text
yield new GenerationChunk({ text: token, generationInfo: chunk })
await runManager?.handleLLMNewToken(token ?? '')

// stream is done
if (chunk.generated_text)
yield new GenerationChunk({
text: '',
generationInfo: { finished: true }
})
}
- const res = await this.caller.callWithOptions({ signal: options.signal }, hf.textGeneration.bind(hf), obj)
}
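`_streamResponseChunks` is the hook LangChain's `.stream()` uses for `LLM` subclasses, so token-by-token output should be consumable roughly as below (a sketch; `llm` is assumed to be the instance from the construction example above):

const stream = await llm.stream('Q: What is Hugging Face?\nA:')
for await (const chunk of stream) {
    // For LLM subclasses each chunk arrives as a plain string token
    process.stdout.write(chunk)
}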

/** @ignore */
async _call(prompt: string, options: this['ParsedCallOptions']): Promise<string> {
const hfi = await this._prepareHFInference()
const args = { ...this.invocationParams(options), inputs: prompt }
const res = await this.caller.callWithOptions({ signal: options.signal }, hfi.textGeneration.bind(hfi), args)
return res.generated_text
}
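For non-streaming calls, `_call` now simply merges `invocationParams(options)` with the prompt before handing it to `textGeneration`. A hedged usage sketch with the same assumed `llm` instance:

// invoke() routes through _call, which sends { ...invocationParams(options), inputs: prompt }
// to hfi.textGeneration and returns res.generated_text
const text = await llm.invoke('Q: What is Hugging Face?\nA:', { stop: ['\n\n'] })
console.log(text)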

/** @ignore */
private async _prepareHFInference() {
const { HfInference } = await HuggingFaceInference.imports()
const hfi = new HfInference(this.apiKey, {
includeCredentials: this.includeCredentials
})
return this.endpointUrl ? hfi.endpoint(this.endpointUrl) : hfi
}
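`_prepareHFInference` is where `endpointUrl` takes effect: with a URL it returns a client pinned to the dedicated endpoint via `hfi.endpoint(...)`, otherwise the plain serverless `HfInference` client is used and routing falls back to the `model` name. The same branching in isolation (placeholder token and URL; this mirrors, rather than extends, the PR's logic):

import { HfInference } from '@huggingface/inference'

const apiKey = 'hf_xxx'                                                // placeholder
const endpointUrl = 'https://my-endpoint.endpoints.huggingface.cloud'  // placeholder; leave '' for serverless

const hfi = new HfInference(apiKey, { includeCredentials: false })
// With a dedicated endpoint, requests go straight to that URL;
// without one, the serverless Inference API resolves the model by name.
const client = endpointUrl ? hfi.endpoint(endpointUrl) : hfi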

/** @ignore */
static async imports(): Promise<{
HfInference: typeof import('@huggingface/inference').HfInference