-
Notifications
You must be signed in to change notification settings - Fork 4.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Enhance HuggingFace Embeddings Inference component (#3758)
# Enhance HuggingFaceInferenceAPIEmbeddings component

## Initial enhancements
- Update display name and description for clarity
- Add API URL validation method
- Implement local URL detection and dummy API key for local deployments
- Improve error handling for API key and URL requirements
- Update documentation link

## API key info update
- Updated the info in API key input

## Refactor and improvements
- Update base class from LCModelComponent to LCEmbeddingsModel
- Rename 'api_url' to 'inference_endpoint' for clarity
- Improve error messages and validation for inference endpoint
- Update documentation link
- Enhance comments and code formatting
- Loading branch information
1 parent
128a0e9
commit 8275306
Showing 1 changed file with 64 additions and 12 deletions.
There are no files selected for viewing
76 changes: 64 additions & 12 deletions
76
src/backend/base/langflow/components/embeddings/HuggingFaceInferenceAPIEmbeddings.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,84 @@ | ||
from urllib.parse import urlparse | ||
|
||
import requests | ||
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings | ||
from pydantic.v1.types import SecretStr | ||
|
||
from langflow.base.models.model import LCModelComponent | ||
from langflow.base.embeddings.model import LCEmbeddingsModel | ||
from langflow.field_typing import Embeddings | ||
from langflow.io import MessageTextInput, Output, SecretStrInput | ||
|
||
|
||
class HuggingFaceInferenceAPIEmbeddingsComponent(LCModelComponent): | ||
display_name = "HuggingFace Embeddings" | ||
description = "Generate embeddings using Hugging Face Inference API models." | ||
documentation = "https://github.com/huggingface/text-embeddings-inference" | ||
class HuggingFaceInferenceAPIEmbeddingsComponent(LCEmbeddingsModel): | ||
display_name = "HuggingFace Embeddings Inference" | ||
description = "Generate embeddings using HuggingFace Text Embeddings Inference (TEI)" | ||
documentation = "https://huggingface.co/docs/text-embeddings-inference/index" | ||
icon = "HuggingFace" | ||
name = "HuggingFaceInferenceAPIEmbeddings" | ||
|
||
inputs = [ | ||
SecretStrInput(name="api_key", display_name="API Key"), | ||
MessageTextInput(name="api_url", display_name="API URL", advanced=True, value="http://localhost:8080"), | ||
MessageTextInput(name="model_name", display_name="Model Name", value="BAAI/bge-large-en-v1.5"), | ||
SecretStrInput( | ||
name="api_key", | ||
display_name="API Key", | ||
advanced=True, | ||
info="Required for non-local inference endpoints. Local inference does not require an API Key.", | ||
), | ||
MessageTextInput( | ||
name="inference_endpoint", | ||
display_name="Inference Endpoint", | ||
required=True, | ||
value="http://localhost:8080", | ||
info="Custom inference endpoint URL.", | ||
), | ||
MessageTextInput( | ||
name="model_name", | ||
display_name="Model Name", | ||
value="BAAI/bge-large-en-v1.5", | ||
info="The name of the model to use for text embeddings.", | ||
), | ||
] | ||
|
||
outputs = [ | ||
Output(display_name="Embeddings", name="embeddings", method="build_embeddings"), | ||
] | ||
|
||
def validate_inference_endpoint(self, inference_endpoint: str) -> bool: | ||
parsed_url = urlparse(inference_endpoint) | ||
if not all([parsed_url.scheme, parsed_url.netloc]): | ||
raise ValueError( | ||
f"Invalid inference endpoint format: '{self.inference_endpoint}'. Please ensure the URL includes both a scheme (e.g., 'http://' or 'https://') and a domain name. Example: 'http://localhost:8080' or 'https://api.example.com'" | ||
) | ||
|
||
try: | ||
response = requests.get(f"{inference_endpoint}/health", timeout=5) | ||
except requests.RequestException: | ||
raise ValueError( | ||
f"Inference endpoint '{inference_endpoint}' is not responding. Please ensure the URL is correct and the service is running." | ||
) | ||
|
||
if response.status_code != 200: | ||
raise ValueError(f"HuggingFace health check failed: {response.status_code}") | ||
# returning True to solve linting error | ||
return True | ||
|
||
def build_embeddings(self) -> Embeddings: | ||
if not self.api_key: | ||
raise ValueError("API Key is required") | ||
if not self.inference_endpoint: | ||
raise ValueError("Inference endpoint is required") | ||
|
||
self.validate_inference_endpoint(self.inference_endpoint) | ||
|
||
# Check if the inference endpoint is local | ||
is_local_url = self.inference_endpoint.startswith(("http://localhost", "http://127.0.0.1")) | ||
|
||
api_key = SecretStr(self.api_key) | ||
# Use a dummy key for local URLs if no key is provided. | ||
# Refer https://python.langchain.com/v0.2/api_reference/community/embeddings/langchain_community.embeddings.huggingface.HuggingFaceInferenceAPIEmbeddings.html | ||
if not self.api_key and is_local_url: | ||
api_key = SecretStr("DummyAPIKeyForLocalDeployment") | ||
elif not self.api_key: | ||
raise ValueError("API Key is required for non-local inference endpoints") | ||
else: | ||
api_key = SecretStr(self.api_key) | ||
|
||
return HuggingFaceInferenceAPIEmbeddings(api_key=api_key, api_url=self.api_url, model_name=self.model_name) | ||
return HuggingFaceInferenceAPIEmbeddings( | ||
api_key=api_key, api_url=self.inference_endpoint, model_name=self.model_name | ||
) |