From 5dbae94e04600cf0a868485e532a4c3c15290846 Mon Sep 17 00:00:00 2001 From: ElReyZero Date: Mon, 4 Sep 2023 16:10:36 -0500 Subject: [PATCH] OpenAIEmbeddings: Add optional an optional parameter to skip empty embeddings (#10196) ## Description ### Issue This pull request addresses a lingering issue identified in PR #7070. In that previous pull request, an attempt was made to address the problem of empty embeddings when using the `OpenAIEmbeddings` class. While PR #7070 introduced a mechanism to retry requests for embeddings, it didn't fully resolve the issue as empty embeddings still occasionally persisted. ### Problem In certain specific use cases, empty embeddings can be encountered when requesting data from the OpenAI API. In some cases, these empty embeddings can be skipped or removed without affecting the functionality of the application. However, they might not always be resolved through retries, and their presence can adversely affect the functionality of applications relying on the `OpenAIEmbeddings` class. ### Solution To provide a more robust solution for handling empty embeddings, we propose the introduction of an optional parameter, `skip_empty`, in the `OpenAIEmbeddings` class. When set to `True`, this parameter will enable the behavior of automatically skipping empty embeddings, ensuring that problematic empty embeddings do not disrupt the processing flow. The developer will be able to optionally toggle this behavior if needed without disrupting the application flow. ## Changes Made - Added an optional parameter, `skip_empty`, to the `OpenAIEmbeddings` class. - When `skip_empty` is set to `True`, empty embeddings are automatically skipped without causing errors or disruptions. ### Example Usage ```python from openai.embeddings import OpenAIEmbeddings # Initialize the OpenAIEmbeddings class with skip_empty=True embeddings = OpenAIEmbeddings(api_key="your_api_key", skip_empty=True) # Request embeddings, empty embeddings are automatically skipped. docs is a variable containing the already splitted text. results = embeddings.embed_documents(docs) # Process results without interruption from empty embeddings ``` --- libs/langchain/langchain/embeddings/openai.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/embeddings/openai.py b/libs/langchain/langchain/embeddings/openai.py index 976c879f95e25..88cb7c9332639 100644 --- a/libs/langchain/langchain/embeddings/openai.py +++ b/libs/langchain/langchain/embeddings/openai.py @@ -87,8 +87,8 @@ async def wrapped_f(*args: Any, **kwargs: Any) -> Callable: # https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings -def _check_response(response: dict) -> dict: - if any(len(d["embedding"]) == 1 for d in response["data"]): +def _check_response(response: dict, skip_empty: bool = False) -> dict: + if any(len(d["embedding"]) == 1 for d in response["data"]) and not skip_empty: import openai raise openai.error.APIError("OpenAI API returned an empty embedding") @@ -102,7 +102,7 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any: @retry_decorator def _embed_with_retry(**kwargs: Any) -> Any: response = embeddings.client.create(**kwargs) - return _check_response(response) + return _check_response(response, skip_empty=embeddings.skip_empty) return _embed_with_retry(**kwargs) @@ -113,7 +113,7 @@ async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> @_async_retry_decorator(embeddings) async def _async_embed_with_retry(**kwargs: Any) -> Any: response = await embeddings.client.acreate(**kwargs) - return _check_response(response) + return _check_response(response, skip_empty=embeddings.skip_empty) return await _async_embed_with_retry(**kwargs) @@ -196,6 +196,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings): """Whether to show a progress bar when embedding.""" model_kwargs: Dict[str, Any] = Field(default_factory=dict) """Holds any model parameters valid for `create` call not explicitly specified.""" + skip_empty: bool = False + """Whether to skip empty strings when embedding or raise an error. + Defaults to not skipping.""" class Config: """Configuration for this pydantic object.""" @@ -371,6 +374,8 @@ def _get_len_safe_embeddings( results: List[List[List[float]]] = [[] for _ in range(len(texts))] num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))] for i in range(len(indices)): + if self.skip_empty and len(batched_embeddings[i]) == 1: + continue results[indices[i]].append(batched_embeddings[i]) num_tokens_in_batch[indices[i]].append(len(tokens[i]))