Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enhance HuggingFace Embeddings Inference component #3758

Merged
merged 6 commits into from
Sep 12, 2024
Original file line number Diff line number Diff line change
@@ -1,32 +1,84 @@
from urllib.parse import urlparse

import requests
from langchain_community.embeddings.huggingface import HuggingFaceInferenceAPIEmbeddings
from pydantic.v1.types import SecretStr

from langflow.base.models.model import LCModelComponent
from langflow.base.embeddings.model import LCEmbeddingsModel
from langflow.field_typing import Embeddings
from langflow.io import MessageTextInput, Output, SecretStrInput


class HuggingFaceInferenceAPIEmbeddingsComponent(LCModelComponent):
display_name = "HuggingFace Embeddings"
description = "Generate embeddings using Hugging Face Inference API models."
documentation = "https://github.com/huggingface/text-embeddings-inference"
class HuggingFaceInferenceAPIEmbeddingsComponent(LCEmbeddingsModel):
display_name = "HuggingFace Embeddings Inference"
description = "Generate embeddings using HuggingFace Text Embeddings Inference (TEI)"
documentation = "https://huggingface.co/docs/text-embeddings-inference/index"
icon = "HuggingFace"
name = "HuggingFaceInferenceAPIEmbeddings"

inputs = [
SecretStrInput(name="api_key", display_name="API Key"),
MessageTextInput(name="api_url", display_name="API URL", advanced=True, value="http://localhost:8080"),
MessageTextInput(name="model_name", display_name="Model Name", value="BAAI/bge-large-en-v1.5"),
SecretStrInput(
name="api_key",
display_name="API Key",
advanced=True,
info="Required for non-local inference endpoints. Local inference does not require an API Key.",
),
MessageTextInput(
name="inference_endpoint",
display_name="Inference Endpoint",
required=True,
value="http://localhost:8080",
info="Custom inference endpoint URL.",
),
MessageTextInput(
name="model_name",
display_name="Model Name",
value="BAAI/bge-large-en-v1.5",
info="The name of the model to use for text embeddings.",
),
]

outputs = [
Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
]

def validate_inference_endpoint(self, inference_endpoint: str) -> bool:
parsed_url = urlparse(inference_endpoint)
if not all([parsed_url.scheme, parsed_url.netloc]):
raise ValueError(
f"Invalid inference endpoint format: '{self.inference_endpoint}'. Please ensure the URL includes both a scheme (e.g., 'http://' or 'https://') and a domain name. Example: 'http://localhost:8080' or 'https://api.example.com'"
)

try:
response = requests.get(f"{inference_endpoint}/health", timeout=5)
except requests.RequestException:
raise ValueError(
f"Inference endpoint '{inference_endpoint}' is not responding. Please ensure the URL is correct and the service is running."
)

if response.status_code != 200:
raise ValueError(f"HuggingFace health check failed: {response.status_code}")
# returning True to solve linting error
return True

def build_embeddings(self) -> Embeddings:
if not self.api_key:
raise ValueError("API Key is required")
if not self.inference_endpoint:
raise ValueError("Inference endpoint is required")

self.validate_inference_endpoint(self.inference_endpoint)

# Check if the inference endpoint is local
is_local_url = self.inference_endpoint.startswith(("http://localhost", "http://127.0.0.1"))

api_key = SecretStr(self.api_key)
# Use a dummy key for local URLs if no key is provided.
# Refer https://python.langchain.com/v0.2/api_reference/community/embeddings/langchain_community.embeddings.huggingface.HuggingFaceInferenceAPIEmbeddings.html
if not self.api_key and is_local_url:
api_key = SecretStr("DummyAPIKeyForLocalDeployment")
edwinjosechittilappilly marked this conversation as resolved.
Show resolved Hide resolved
elif not self.api_key:
raise ValueError("API Key is required for non-local inference endpoints")
else:
api_key = SecretStr(self.api_key)

return HuggingFaceInferenceAPIEmbeddings(api_key=api_key, api_url=self.api_url, model_name=self.model_name)
return HuggingFaceInferenceAPIEmbeddings(
api_key=api_key, api_url=self.inference_endpoint, model_name=self.model_name
)
Loading