Commit 537f862

mistral: catch GatedRepoError, release 0.1.3 (#20802)
#20618

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
2 people authored and hinthornw committed Apr 26, 2024
1 parent: ddff748 · commit: 537f862
Showing 3 changed files with 176 additions and 156 deletions.
libs/partners/mistralai/langchain_mistralai/embeddings.py (28 changes: 24 additions & 4 deletions)
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+import warnings
 from typing import Dict, Iterable, List, Optional
 
 import httpx
@@ -19,6 +20,13 @@
 MAX_TOKENS = 16_000
 
 
+class DummyTokenizer:
+    """Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)"""
+
+    def encode_batch(self, texts: List[str]) -> List[List[str]]:
+        return [list(text) for text in texts]
+
+
 class MistralAIEmbeddings(BaseModel, Embeddings):
     """MistralAI embedding models.
@@ -83,9 +91,18 @@ def validate_environment(cls, values: Dict) -> Dict:
             timeout=values["timeout"],
         )
         if values["tokenizer"] is None:
-            values["tokenizer"] = Tokenizer.from_pretrained(
-                "mistralai/Mixtral-8x7B-v0.1"
-            )
+            try:
+                values["tokenizer"] = Tokenizer.from_pretrained(
+                    "mistralai/Mixtral-8x7B-v0.1"
+                )
+            except IOError:  # huggingface_hub GatedRepoError
+                warnings.warn(
+                    "Could not download mistral tokenizer from Huggingface for "
+                    "calculating batch sizes. Set a Huggingface token via the "
+                    "HF_TOKEN environment variable to download the real tokenizer. "
+                    "Falling back to a dummy tokenizer that uses `len()`."
+                )
+                values["tokenizer"] = DummyTokenizer()
         return values
 
     def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
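
A hedged usage sketch (not from this commit; the placeholder API key and the environment handling are illustrative): if the gated Mixtral tokenizer repo cannot be downloaded, constructing the embeddings object now warns and falls back instead of raising. Setting a valid HF_TOKEN restores the real tokenizer.

# Sketch, assuming no Huggingface credentials are available on the machine.
import os
import warnings

from langchain_mistralai import MistralAIEmbeddings

os.environ.pop("HF_TOKEN", None)  # simulate missing Huggingface auth

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    embedder = MistralAIEmbeddings(mistral_api_key="placeholder-key")
    # If the tokenizer download failed, the fallback warning was recorded
    # and embedder.tokenizer is a DummyTokenizer instance.
    print([str(w.message) for w in caught])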
@@ -100,7 +117,10 @@ def _get_batches(self, texts: List[str]) -> Iterable[List[str]]:
 
         for text, text_tokens in zip(texts, text_token_lengths):
             if batch_tokens + text_tokens > MAX_TOKENS:
-                yield batch
+                if len(batch) > 0:
+                    # edge case where first batch exceeds max tokens
+                    # should not yield an empty batch.
+                    yield batch
                 batch = [text]
                 batch_tokens = text_tokens
             else:
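
Why the new guard matters, as a standalone sketch (not the library code verbatim: MAX_TOKENS is shrunk to 10 and token counts are approximated with len() so the edge case is easy to trigger). If the very first text already exceeds the budget, the old code yielded an empty batch before anything had been accumulated:

# Minimal reimplementation of the batching loop with the fix applied.
from typing import Iterable, List

MAX_TOKENS = 10  # tiny budget for demonstration

def get_batches(texts: List[str]) -> Iterable[List[str]]:
    batch: List[str] = []
    batch_tokens = 0
    for text, text_tokens in zip(texts, [len(t) for t in texts]):
        if batch_tokens + text_tokens > MAX_TOKENS:
            if len(batch) > 0:
                # without this guard, a first text longer than MAX_TOKENS
                # would yield [] before any text was accumulated
                yield batch
            batch = [text]
            batch_tokens = text_tokens
        else:
            batch.append(text)
            batch_tokens += text_tokens
    if batch:
        yield batch

print(list(get_batches(["x" * 12, "ab", "cd"])))
# [['xxxxxxxxxxxx'], ['ab', 'cd']] -- no leading empty batch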
(Diffs for the remaining two changed files are not shown here.)
