Address review comments
Signed-off-by: Jonas Helming <jhelming@eclipsesource.com>
JonasHelming committed Nov 11, 2024
1 parent 6bfaa2d commit caa370a
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions packages/ai-hugging-face/src/node/huggingface-language-model.ts
@@ -49,31 +49,26 @@ function toRoleLabel(actor: MessageActor): string {
 }
 
 export class HuggingFaceModel implements LanguageModel {
-    private hfInference: HfInference;
 
     /**
      * @param id the unique id for this language model. It will be used to identify the model in the UI.
      * @param model the model id as it is used by the Hugging Face API
      * @param apiKey function to retrieve the API key for Hugging Face
      */
     constructor(public readonly id: string, public model: string, public apiKey: () => string | undefined) {
-        const token = this.apiKey();
-        if (!token) {
-            throw new Error('Please provide a Hugging Face API token.');
-        }
-        this.hfInference = new HfInference(token);
     }
 
     async request(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
+        const hfInference = this.initializeHfInference();
         if (this.isStreamingSupported(this.model)) {
-            return this.handleStreamingRequest(request, cancellationToken);
+            return this.handleStreamingRequest(hfInference, request, cancellationToken);
         } else {
-            return this.handleNonStreamingRequest(request);
+            return this.handleNonStreamingRequest(hfInference, request);
         }
     }
 
-    protected async handleNonStreamingRequest(request: LanguageModelRequest): Promise<LanguageModelTextResponse> {
-        const response = await this.hfInference.textGeneration({
+    protected async handleNonStreamingRequest(hfInference: HfInference, request: LanguageModelRequest): Promise<LanguageModelTextResponse> {
+        const response = await hfInference.textGeneration({
             model: this.model,
             inputs: toHuggingFacePrompt(request.messages),
             parameters: {
@@ -92,8 +87,8 @@ export class HuggingFaceModel implements LanguageModel {
         };
     }
 
-    protected async handleStreamingRequest(request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
-        const stream = this.hfInference.textGenerationStream({
+    protected async handleStreamingRequest(hfInference: HfInference, request: LanguageModelRequest, cancellationToken?: CancellationToken): Promise<LanguageModelResponse> {
+        const stream = hfInference.textGenerationStream({
             model: this.model,
             inputs: toHuggingFacePrompt(request.messages),
             parameters: {
@@ -123,4 +118,12 @@ export class HuggingFaceModel implements LanguageModel {
         // Assuming all models support streaming for now; can be refined if needed
         return true;
     }
+
+    private initializeHfInference(): HfInference {
+        const token = this.apiKey();
+        if (!token) {
+            throw new Error('Please provide a Hugging Face API token.');
+        }
+        return new HfInference(token);
+    }
 }
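For context, after this change the HfInference client is created lazily on every request instead of once in the constructor, so the model can be constructed before an API token is configured. A minimal usage sketch under that assumption; the id, model name, and environment variable below are illustrative and not part of the commit:

// Hypothetical wiring: construction succeeds even if no token is set yet,
// because initializeHfInference() only reads apiKey() when request() runs.
const model = new HuggingFaceModel(
    'huggingface/mistralai/Mistral-7B-Instruct-v0.3', // unique id shown in the UI
    'mistralai/Mistral-7B-Instruct-v0.3',             // model id as used by the Hugging Face API
    () => process.env.HUGGINGFACE_API_KEY             // token resolved on every request
);
// If apiKey() still returns undefined when request() is called,
// initializeHfInference() throws 'Please provide a Hugging Face API token.'.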
