Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "huggingface[patch]: make HuggingFaceEndpoint serializable (#2… #27032

Merged
merged 1 commit into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,14 @@
import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional

from huggingface_hub import ( # type: ignore[import-untyped]
AsyncInferenceClient,
InferenceClient,
login,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.utils import get_pydantic_field_names, secret_from_env
from pydantic import ConfigDict, Field, SecretStr, model_validator
from langchain_core.utils import from_env, get_pydantic_field_names
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -78,12 +73,10 @@ class HuggingFaceEndpoint(LLM):
should be pass as env variable in `HF_INFERENCE_ENDPOINT`"""
repo_id: Optional[str] = None
"""Repo to use. If endpoint_url is not specified then this needs to given"""
huggingfacehub_api_token: Optional[SecretStr] = Field(
default_factory=secret_from_env(
["HUGGINGFACEHUB_API_TOKEN", "HF_TOKEN"], default=None
)
huggingfacehub_api_token: Optional[str] = Field(
default_factory=from_env("HUGGINGFACEHUB_API_TOKEN", default=None)
)
max_new_tokens: int = Field(default=512, alias="max_tokens")
max_new_tokens: int = 512
"""Maximum number of generated tokens"""
top_k: Optional[int] = None
"""The number of highest probability vocabulary tokens to keep for
Expand Down Expand Up @@ -123,15 +116,14 @@ class HuggingFaceEndpoint(LLM):
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `call` not explicitly specified"""
model: str
client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
task: Optional[str] = None
"""Task to call the model with.
Should be a task that returns `generated_text` or `summary_text`."""

model_config = ConfigDict(
extra="forbid",
populate_by_name=True,
)

@model_validator(mode="before")
Expand Down Expand Up @@ -197,23 +189,36 @@ def build_extra(cls, values: Dict[str, Any]) -> Any:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that package is installed and that the API token is valid."""
if self.huggingfacehub_api_token is not None:
try:
from huggingface_hub import login # type: ignore[import]

except ImportError:
raise ImportError(
"Could not import huggingface_hub python package. "
"Please install it with `pip install huggingface_hub`."
)

huggingfacehub_api_token = self.huggingfacehub_api_token or os.getenv(
"HF_TOKEN"
)

if huggingfacehub_api_token is not None:
try:
login(token=self.huggingfacehub_api_token.get_secret_value())
login(token=huggingfacehub_api_token)
except Exception as e:
raise ValueError(
"Could not authenticate with huggingface_hub. "
"Please check your API token."
) from e

from huggingface_hub import AsyncInferenceClient, InferenceClient

# Instantiate clients with supported kwargs
sync_supported_kwargs = set(inspect.signature(InferenceClient).parameters)
self.client = InferenceClient(
model=self.model,
timeout=self.timeout,
token=self.huggingfacehub_api_token.get_secret_value()
if self.huggingfacehub_api_token
else None,
token=huggingfacehub_api_token,
**{
key: value
for key, value in self.server_kwargs.items()
Expand All @@ -225,9 +230,7 @@ def validate_environment(self) -> Self:
self.async_client = AsyncInferenceClient(
model=self.model,
timeout=self.timeout,
token=self.huggingfacehub_api_token.get_secret_value()
if self.huggingfacehub_api_token
else None,
token=huggingfacehub_api_token,
**{
key: value
for key, value in self.server_kwargs.items()
Expand Down Expand Up @@ -423,15 +426,3 @@ async def _astream(
# break if stop sequence found
if stop_seq_found:
break

@classmethod
def is_lc_serializable(cls) -> bool:
return True

@classmethod
def get_lc_namespace(cls) -> list[str]:
return ["langchain_huggingface", "llms"]

@property
def lc_secrets(self) -> dict[str, str]:
return {"huggingfacehub_api_token": "HUGGINGFACEHUB_API_TOKEN"}
Loading
Loading