diff --git a/libs/langchain/langchain/embeddings/openai.py b/libs/langchain/langchain/embeddings/openai.py
index 976c879f95e25..88cb7c9332639 100644
--- a/libs/langchain/langchain/embeddings/openai.py
+++ b/libs/langchain/langchain/embeddings/openai.py
@@ -87,8 +87,8 @@ async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:
 
 
 # https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
-def _check_response(response: dict) -> dict:
-    if any(len(d["embedding"]) == 1 for d in response["data"]):
+def _check_response(response: dict, skip_empty: bool = False) -> dict:
+    if any(len(d["embedding"]) == 1 for d in response["data"]) and not skip_empty:
         import openai
 
         raise openai.error.APIError("OpenAI API returned an empty embedding")
@@ -102,7 +102,7 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
     @retry_decorator
     def _embed_with_retry(**kwargs: Any) -> Any:
         response = embeddings.client.create(**kwargs)
-        return _check_response(response)
+        return _check_response(response, skip_empty=embeddings.skip_empty)
 
     return _embed_with_retry(**kwargs)
 
@@ -113,7 +113,7 @@ async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
     @_async_retry_decorator(embeddings)
     async def _async_embed_with_retry(**kwargs: Any) -> Any:
         response = await embeddings.client.acreate(**kwargs)
-        return _check_response(response)
+        return _check_response(response, skip_empty=embeddings.skip_empty)
 
     return await _async_embed_with_retry(**kwargs)
 
@@ -196,6 +196,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     """Whether to show a progress bar when embedding."""
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
+    skip_empty: bool = False
+    """Whether to skip empty strings when embedding or raise an error.
+    Defaults to not skipping."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -371,6 +374,8 @@ def _get_len_safe_embeddings(
         results: List[List[List[float]]] = [[] for _ in range(len(texts))]
         num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
         for i in range(len(indices)):
+            if self.skip_empty and len(batched_embeddings[i]) == 1:
+                continue
             results[indices[i]].append(batched_embeddings[i])
             num_tokens_in_batch[indices[i]].append(len(tokens[i]))
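
A minimal usage sketch of the new flag (the `OpenAIEmbeddings` class, its import path, and `skip_empty` are taken from the diff above; the example texts are made up):

from langchain.embeddings.openai import OpenAIEmbeddings

# With skip_empty=True, a length-1 ("empty") embedding in the API response
# no longer raises openai.error.APIError; the offending chunk is simply
# skipped in _get_len_safe_embeddings.
embeddings = OpenAIEmbeddings(skip_empty=True)  # defaults to False

vectors = embeddings.embed_documents(["first document", "second document"])

Note that a skipped chunk is never appended to its `results` bucket, so a text's final embedding is built only from the chunks that came back non-empty.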