Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enhance HuggingFaceEndpointsComponent with additional parameters #3846

Merged
merged 12 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from urllib.parse import urlparse
from tenacity import retry, stop_after_attempt, wait_fixed

import requests
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
Expand Down Expand Up @@ -27,7 +28,7 @@ class HuggingFaceInferenceAPIEmbeddingsComponent(LCEmbeddingsModel):
name="inference_endpoint",
display_name="Inference Endpoint",
required=True,
value="http://localhost:8080",
value="https://api-inference.huggingface.co/models/",
info="Custom inference endpoint URL.",
),
MessageTextInput(
Expand Down Expand Up @@ -61,24 +62,32 @@ def validate_inference_endpoint(self, inference_endpoint: str) -> bool:
# returning True to solve linting error
return True

def build_embeddings(self) -> Embeddings:
if not self.inference_endpoint:
raise ValueError("Inference endpoint is required")
def get_api_url(self) -> str:
if "huggingface" in self.inference_endpoint.lower():
return f"{self.inference_endpoint}{self.model_name}"
else:
return self.inference_endpoint

self.validate_inference_endpoint(self.inference_endpoint)
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def create_huggingface_embeddings(
self, api_key: SecretStr, api_url: str, model_name: str
) -> HuggingFaceInferenceAPIEmbeddings:
return HuggingFaceInferenceAPIEmbeddings(api_key=api_key, api_url=api_url, model_name=model_name)

# Check if the inference endpoint is local
is_local_url = self.inference_endpoint.startswith(("http://localhost", "http://127.0.0.1"))
def build_embeddings(self) -> Embeddings:
api_url = self.get_api_url()

is_local_url = api_url.startswith(("http://localhost", "http://127.0.0.1"))

# Use a dummy key for local URLs if no key is provided.
# Refer https://python.langchain.com/v0.2/api_reference/community/embeddings/langchain_community.embeddings.huggingface.HuggingFaceInferenceAPIEmbeddings.html
if not self.api_key and is_local_url:
self.validate_inference_endpoint(api_url)
api_key = SecretStr("DummyAPIKeyForLocalDeployment")
elif not self.api_key:
raise ValueError("API Key is required for non-local inference endpoints")
else:
api_key = SecretStr(self.api_key)

return HuggingFaceInferenceAPIEmbeddings(
api_key=api_key, api_url=self.inference_endpoint, model_name=self.model_name
)
try:
return self.create_huggingface_embeddings(api_key, api_url, self.model_name)
except Exception as e:
raise ValueError("Could not connect to HuggingFace Inference API.") from e
100 changes: 92 additions & 8 deletions src/backend/base/langflow/components/models/HuggingFaceModel.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from tenacity import retry, stop_after_attempt, wait_fixed
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint

# TODO: langchain_community.llms.huggingface_endpoint is depreciated. Need to update to langchain_huggingface, but have dependency with langchain_core 0.3.0
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.io import DictInput, DropdownInput, SecretStrInput, StrInput, IntInput
from langflow.io import DictInput, DropdownInput, SecretStrInput, StrInput, IntInput, FloatInput
from typing import Any, Dict, Optional


class HuggingFaceEndpointsComponent(LCModelComponent):
Expand All @@ -18,22 +20,81 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
display_name="Model ID",
value="openai-community/gpt2",
),
StrInput(
name="inference_endpoint",
display_name="Inference Endpoint",
value="https://api-inference.huggingface.co/models/",
info="Custom inference endpoint URL.",
),
DropdownInput(
name="task",
display_name="Task",
options=["text2text-generation", "text-generation", "summarization", "translation"],
value="text-generation",
advanced=True,
info="The task to call the model with. Should be a task that returns `generated_text` or `summary_text`.",
),
SecretStrInput(name="huggingfacehub_api_token", display_name="API Token", password=True),
DictInput(name="model_kwargs", display_name="Model Keyword Arguments", advanced=True),
IntInput(name="retry_attempts", display_name="Retry Attempts", value=1, advanced=True),
IntInput(
name="max_new_tokens", display_name="Max New Tokens", value=512, info="Maximum number of generated tokens"
),
IntInput(
name="top_k",
display_name="Top K",
advanced=True,
info="The number of highest probability vocabulary tokens to keep for top-k-filtering",
),
FloatInput(
name="top_p",
display_name="Top P",
value=0.95,
advanced=True,
info="If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation",
),
FloatInput(
name="typical_p",
display_name="Typical P",
value=0.95,
advanced=True,
info="Typical Decoding mass. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information",
),
FloatInput(
name="temperature",
display_name="Temperature",
value=0.8,
advanced=True,
info="The value used to module the logits distribution",
),
FloatInput(
name="repetition_penalty",
display_name="Repetition Penalty",
advanced=True,
info="The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details",
edwinjosechittilappilly marked this conversation as resolved.
Show resolved Hide resolved
),
]

def get_api_url(self) -> str:
if "huggingface" in self.inference_endpoint.lower():
return f"{self.inference_endpoint}{self.model_id}"
else:
return self.inference_endpoint

def create_huggingface_endpoint(
self, model_id: str, task: str, huggingfacehub_api_token: str, model_kwargs: dict
self,
model_id: str,
task: Optional[str],
huggingfacehub_api_token: Optional[str],
model_kwargs: Dict[str, Any],
max_new_tokens: int,
top_k: Optional[int],
top_p: float,
typical_p: Optional[float],
temperature: Optional[float],
repetition_penalty: Optional[float],
) -> HuggingFaceEndpoint:
retry_attempts = self.retry_attempts # Access the retry attempts input
endpoint_url = f"https://api-inference.huggingface.co/models/{model_id}"
retry_attempts = self.retry_attempts
endpoint_url = self.get_api_url()

@retry(stop=stop_after_attempt(retry_attempts), wait=wait_fixed(2))
def _attempt_create():
Expand All @@ -42,18 +103,41 @@ def _attempt_create():
task=task,
huggingfacehub_api_token=huggingfacehub_api_token,
model_kwargs=model_kwargs,
max_new_tokens=max_new_tokens,
top_k=top_k,
top_p=top_p,
typical_p=self.typical_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
)

return _attempt_create()

def build_model(self) -> LanguageModel: # type: ignore[type-var]
def build_model(self) -> LanguageModel:
model_id = self.model_id
task = self.task
task = self.task or None
huggingfacehub_api_token = self.huggingfacehub_api_token
model_kwargs = self.model_kwargs or {}
max_new_tokens = self.max_new_tokens
top_k = self.top_k or None
top_p = self.top_p
typical_p = self.typical_p or None
temperature = self.temperature or 0.8
repetition_penalty = self.repetition_penalty or None

try:
llm = self.create_huggingface_endpoint(model_id, task, huggingfacehub_api_token, model_kwargs)
llm = self.create_huggingface_endpoint(
model_id=model_id,
task=task,
huggingfacehub_api_token=huggingfacehub_api_token,
model_kwargs=model_kwargs,
max_new_tokens=max_new_tokens,
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
)
except Exception as e:
raise ValueError("Could not connect to HuggingFace Endpoints API.") from e

Expand Down
Loading