Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community[patch]: add detailed paragraph and example for BaichuanTextEmbeddings #22031

Merged
merged 5 commits into from
Jun 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 33 additions & 23 deletions libs/community/langchain_community/embeddings/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from requests import RequestException

BAICHUAN_API_URL: str = "http://api.baichuan-ai.com/v1/embeddings"

Expand All @@ -22,11 +23,23 @@
# NOTE!! BaichuanTextEmbeddings only supports Chinese text embedding.
# Multi-language support is coming soon.
class BaichuanTextEmbeddings(BaseModel, Embeddings):
"""Baichuan Text Embedding models."""
"""Baichuan Text Embedding models.

To use, you should set the environment variable ``BAICHUAN_API_KEY`` to
your API key or pass it as a named parameter to the constructor.

Example:
.. code-block:: python

from langchain_community.embeddings import BaichuanTextEmbeddings

baichuan = BaichuanTextEmbeddings(baichuan_api_key="my-api-key")
"""

session: Any #: :meta private:
model_name: str = "Baichuan-Text-Embedding"
baichuan_api_key: Optional[SecretStr] = None
"""Automatically inferred from env var `BAICHUAN_API_KEY` if not provided."""

@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
Expand Down Expand Up @@ -65,29 +78,26 @@ def _embed(self, texts: List[str]) -> Optional[List[List[float]]]:
A list of lists of floats representing the embeddings, or None if an
error occurs.
"""
try:
response = self.session.post(
BAICHUAN_API_URL, json={"input": texts, "model": self.model_name}
response = self.session.post(
BAICHUAN_API_URL, json={"input": texts, "model": self.model_name}
)
# Raise an exception if the response status code is in the 400-599 range
response.raise_for_status()
# Check if the response status code indicates success
if response.status_code == 200:
resp = response.json()
embeddings = resp.get("data", [])
# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e.get("index", 0))
# Return just the embeddings
return [result.get("embedding", []) for result in sorted_embeddings]
else:
# Log error or handle unsuccessful response appropriately
# Handle 100 <= status_code < 400, excluding 200
raise RequestException(
f"Error: Received status code {response.status_code} from "
"`BaichuanEmbedding` API"
)
# Check if the response status code indicates success
if response.status_code == 200:
resp = response.json()
embeddings = resp.get("data", [])
# Sort resulting embeddings by index
sorted_embeddings = sorted(embeddings, key=lambda e: e.get("index", 0))
# Return just the embeddings
return [result.get("embedding", []) for result in sorted_embeddings]
else:
# Log error or handle unsuccessful response appropriately
print( # noqa: T201
f"Error: Received status code {response.status_code} from "
"embedding API"
)
return None
except Exception as e:
# Log the exception or handle it as needed
print(f"Exception occurred while trying to get embeddings: {str(e)}") # noqa: T201
return None

def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]: # type: ignore[override]
"""Public method to get embeddings for a list of documents.
Expand Down
Loading