Skip to content

Commit

Permalink
feat: Enhance HuggingFaceEndpointsComponent with additional parameters (
Browse files Browse the repository at this point in the history
#3846)

* Update HuggingFaceInferenceAPIEmbeddings.py

update to use inference api from hugging face

* Enhance HuggingFaceEndpointsComponent with additional parameters

- Add FloatInput for top_p, typical_p, temperature, and repetition_penalty
- Update create_huggingface_endpoint and build_model methods to include new parameters
- Set default values and info for new inputs

TODO: Need to update the package import
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint

since it is deprecated.

* Updated HuggingFaceModel Solving Lint Error

Updated HuggingFaceModel Solving Lint Error

* Update HuggingFaceModel.py

Added Inference Endpoint as an input from user to support custom inference endpoints

* Update HuggingFaceModel.py

paper references removed
  • Loading branch information
edwinjosechittilappilly authored Sep 21, 2024
1 parent aa25783 commit 1caba1c
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from urllib.parse import urlparse
from tenacity import retry, stop_after_attempt, wait_fixed

import requests
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
Expand Down Expand Up @@ -27,7 +28,7 @@ class HuggingFaceInferenceAPIEmbeddingsComponent(LCEmbeddingsModel):
name="inference_endpoint",
display_name="Inference Endpoint",
required=True,
value="http://localhost:8080",
value="https://api-inference.huggingface.co/models/",
info="Custom inference endpoint URL.",
),
MessageTextInput(
Expand Down Expand Up @@ -61,24 +62,32 @@ def validate_inference_endpoint(self, inference_endpoint: str) -> bool:
# returning True to solve linting error
return True

def build_embeddings(self) -> Embeddings:
if not self.inference_endpoint:
raise ValueError("Inference endpoint is required")
def get_api_url(self) -> str:
    """Return the URL the embeddings client should call.

    When the configured ``inference_endpoint`` contains the substring
    "huggingface" (case-insensitive), ``model_name`` is appended to it as a
    path segment. Any other endpoint is returned unchanged.

    Returns:
        The endpoint URL, with the model name appended for Hugging Face
        endpoints.
    """
    endpoint = self.inference_endpoint
    if "huggingface" not in endpoint.lower():
        return endpoint
    # Join with exactly one "/" so that an endpoint entered without a
    # trailing slash does not get fused with the model name.
    return f"{endpoint.rstrip('/')}/{self.model_name}"

self.validate_inference_endpoint(self.inference_endpoint)
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def create_huggingface_embeddings(
    self,
    api_key: SecretStr,
    api_url: str,
    model_name: str,
) -> HuggingFaceInferenceAPIEmbeddings:
    """Construct a ``HuggingFaceInferenceAPIEmbeddings`` client.

    The call is wrapped by ``tenacity.retry``: up to three attempts, with a
    fixed two-second wait between attempts.

    Args:
        api_key: Secret API key forwarded to the embeddings client.
        api_url: URL forwarded to the embeddings client.
        model_name: Model name forwarded to the embeddings client.

    Returns:
        The constructed embeddings client.
    """
    client_kwargs = {
        "api_key": api_key,
        "api_url": api_url,
        "model_name": model_name,
    }
    return HuggingFaceInferenceAPIEmbeddings(**client_kwargs)

# Check if the inference endpoint is local
is_local_url = self.inference_endpoint.startswith(("http://localhost", "http://127.0.0.1"))
def build_embeddings(self) -> Embeddings:
api_url = self.get_api_url()

is_local_url = api_url.startswith(("http://localhost", "http://127.0.0.1"))

# Use a dummy key for local URLs if no key is provided.
# Refer https://python.langchain.com/v0.2/api_reference/community/embeddings/langchain_community.embeddings.huggingface.HuggingFaceInferenceAPIEmbeddings.html
if not self.api_key and is_local_url:
self.validate_inference_endpoint(api_url)
api_key = SecretStr("DummyAPIKeyForLocalDeployment")
elif not self.api_key:
raise ValueError("API Key is required for non-local inference endpoints")
else:
api_key = SecretStr(self.api_key)

return HuggingFaceInferenceAPIEmbeddings(
api_key=api_key, api_url=self.inference_endpoint, model_name=self.model_name
)
try:
return self.create_huggingface_embeddings(api_key, api_url, self.model_name)
except Exception as e:
raise ValueError("Could not connect to HuggingFace Inference API.") from e
100 changes: 92 additions & 8 deletions src/backend/base/langflow/components/models/HuggingFaceModel.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from tenacity import retry, stop_after_attempt, wait_fixed
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint

# TODO: langchain_community.llms.huggingface_endpoint is deprecated. Need to migrate to langchain_huggingface, but it depends on langchain_core 0.3.0
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.io import DictInput, DropdownInput, SecretStrInput, StrInput, IntInput
from langflow.io import DictInput, DropdownInput, SecretStrInput, StrInput, IntInput, FloatInput
from typing import Any, Dict, Optional


class HuggingFaceEndpointsComponent(LCModelComponent):
Expand All @@ -18,22 +20,81 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
display_name="Model ID",
value="openai-community/gpt2",
),
StrInput(
name="inference_endpoint",
display_name="Inference Endpoint",
value="https://api-inference.huggingface.co/models/",
info="Custom inference endpoint URL.",
),
DropdownInput(
name="task",
display_name="Task",
options=["text2text-generation", "text-generation", "summarization", "translation"],
value="text-generation",
advanced=True,
info="The task to call the model with. Should be a task that returns `generated_text` or `summary_text`.",
),
SecretStrInput(name="huggingfacehub_api_token", display_name="API Token", password=True),
DictInput(name="model_kwargs", display_name="Model Keyword Arguments", advanced=True),
IntInput(name="retry_attempts", display_name="Retry Attempts", value=1, advanced=True),
IntInput(
name="max_new_tokens", display_name="Max New Tokens", value=512, info="Maximum number of generated tokens"
),
IntInput(
name="top_k",
display_name="Top K",
advanced=True,
info="The number of highest probability vocabulary tokens to keep for top-k-filtering",
),
FloatInput(
name="top_p",
display_name="Top P",
value=0.95,
advanced=True,
info="If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation",
),
FloatInput(
name="typical_p",
display_name="Typical P",
value=0.95,
advanced=True,
info="Typical Decoding mass.",
),
FloatInput(
name="temperature",
display_name="Temperature",
value=0.8,
advanced=True,
info="The value used to module the logits distribution",
),
FloatInput(
name="repetition_penalty",
display_name="Repetition Penalty",
advanced=True,
info="The parameter for repetition penalty. 1.0 means no penalty.",
),
]

def get_api_url(self) -> str:
    """Return the URL the endpoint client should call.

    When the configured ``inference_endpoint`` contains the substring
    "huggingface" (case-insensitive), ``model_id`` is appended to it as a
    path segment. Any other endpoint is returned unchanged.

    Returns:
        The endpoint URL, with the model ID appended for Hugging Face
        endpoints.
    """
    endpoint = self.inference_endpoint
    if "huggingface" not in endpoint.lower():
        return endpoint
    # Join with exactly one "/" so that an endpoint entered without a
    # trailing slash does not get fused with the model ID.
    return f"{endpoint.rstrip('/')}/{self.model_id}"

def create_huggingface_endpoint(
self, model_id: str, task: str, huggingfacehub_api_token: str, model_kwargs: dict
self,
model_id: str,
task: Optional[str],
huggingfacehub_api_token: Optional[str],
model_kwargs: Dict[str, Any],
max_new_tokens: int,
top_k: Optional[int],
top_p: float,
typical_p: Optional[float],
temperature: Optional[float],
repetition_penalty: Optional[float],
) -> HuggingFaceEndpoint:
retry_attempts = self.retry_attempts # Access the retry attempts input
endpoint_url = f"https://api-inference.huggingface.co/models/{model_id}"
retry_attempts = self.retry_attempts
endpoint_url = self.get_api_url()

@retry(stop=stop_after_attempt(retry_attempts), wait=wait_fixed(2))
def _attempt_create():
Expand All @@ -42,18 +103,41 @@ def _attempt_create():
task=task,
huggingfacehub_api_token=huggingfacehub_api_token,
model_kwargs=model_kwargs,
max_new_tokens=max_new_tokens,
top_k=top_k,
top_p=top_p,
typical_p=self.typical_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
)

return _attempt_create()

def build_model(self) -> LanguageModel: # type: ignore[type-var]
def build_model(self) -> LanguageModel:
model_id = self.model_id
task = self.task
task = self.task or None
huggingfacehub_api_token = self.huggingfacehub_api_token
model_kwargs = self.model_kwargs or {}
max_new_tokens = self.max_new_tokens
top_k = self.top_k or None
top_p = self.top_p
typical_p = self.typical_p or None
temperature = self.temperature or 0.8
repetition_penalty = self.repetition_penalty or None

try:
llm = self.create_huggingface_endpoint(model_id, task, huggingfacehub_api_token, model_kwargs)
llm = self.create_huggingface_endpoint(
model_id=model_id,
task=task,
huggingfacehub_api_token=huggingfacehub_api_token,
model_kwargs=model_kwargs,
max_new_tokens=max_new_tokens,
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
)
except Exception as e:
raise ValueError("Could not connect to HuggingFace Endpoints API.") from e

Expand Down

0 comments on commit 1caba1c

Please sign in to comment.