OpenAIEmbeddings: Add an optional parameter to skip empty embeddings (#10196)

## Description

### Issue
This pull request addresses a lingering issue from PR #7070. That pull
request tried to fix the problem of empty embeddings in the
`OpenAIEmbeddings` class by retrying the embedding request, but it didn't
fully resolve the issue: empty embeddings still occasionally persisted
after the retries.

### Problem
In certain use cases, the OpenAI API returns empty embeddings. Retries
do not always resolve them, and when they persist their presence can
break applications relying on the `OpenAIEmbeddings` class. In some of
these cases, however, the empty embeddings can simply be skipped or
removed without affecting the functionality of the application.

### Solution
To handle empty embeddings more robustly, we propose an optional
parameter, `skip_empty`, on the `OpenAIEmbeddings` class. When set to
`True`, empty embeddings are skipped automatically instead of raising an
error, so they do not disrupt the processing flow. The parameter
defaults to `False`, so existing behavior is unchanged unless the
developer opts in.
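
As a rough, offline sketch of how the check is gated (not the library code itself): `check_response` below mirrors the `_check_response` change in this commit, while `EmptyEmbeddingError` and `fake_response` are hypothetical stand-ins for `openai.error.APIError` and a real API payload.

```python
from typing import Any, Dict


class EmptyEmbeddingError(Exception):
    """Hypothetical stand-in for openai.error.APIError so the sketch runs offline."""


def check_response(response: Dict[str, Any], skip_empty: bool = False) -> Dict[str, Any]:
    # This change treats an embedding of length 1 as "empty".
    has_empty = any(len(d["embedding"]) == 1 for d in response["data"])
    if has_empty and not skip_empty:
        raise EmptyEmbeddingError("OpenAI API returned an empty embedding")
    return response


# Made-up payload for illustration: one normal vector and one length-1 vector.
fake_response = {"data": [{"embedding": [0.1, 0.2, 0.3]}, {"embedding": [0.0]}]}

# With skip_empty=True the response is returned unchanged.
passed = check_response(fake_response, skip_empty=True)

# With the default (skip_empty=False) the same response raises.
try:
    check_response(fake_response)
    raised = False
except EmptyEmbeddingError:
    raised = True
```

Note that the check only stops the error from being raised; the length-1 entries are still present in the response and are filtered later, when the embeddings are grouped per text.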

## Changes Made
- Added an optional parameter, `skip_empty`, to the `OpenAIEmbeddings`
class.
- When `skip_empty` is set to `True`, empty embeddings are automatically
skipped without causing errors or disruptions.
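
The skipping step can be sketched in isolation as follows. This is a simplified illustration of the loop changed in `_get_len_safe_embeddings`, not the method itself: `group_embeddings` is a hypothetical helper name, the sample vectors are made up, and token counting is left out.

```python
from typing import List


def group_embeddings(
    indices: List[int],
    batched_embeddings: List[List[float]],
    n_texts: int,
    skip_empty: bool = False,
) -> List[List[List[float]]]:
    # One bucket of chunk embeddings per original text.
    results: List[List[List[float]]] = [[] for _ in range(n_texts)]
    for i in range(len(indices)):
        # A length-1 vector is treated as an empty embedding and dropped.
        if skip_empty and len(batched_embeddings[i]) == 1:
            continue
        results[indices[i]].append(batched_embeddings[i])
    return results


# Three chunks: the first two belong to text 0, the third to text 1.
indices = [0, 0, 1]
batched = [[0.1, 0.2], [0.0], [0.3, 0.4]]

kept = group_embeddings(indices, batched, n_texts=2, skip_empty=True)
unfiltered = group_embeddings(indices, batched, n_texts=2)
```

If every chunk of a text comes back empty, that text's bucket stays empty, so downstream code needs to tolerate an empty list for it.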

### Example Usage
```python
from langchain.embeddings.openai import OpenAIEmbeddings

# Initialize the OpenAIEmbeddings class with skip_empty=True
embeddings = OpenAIEmbeddings(openai_api_key="your_api_key", skip_empty=True)

# Request embeddings; empty embeddings are skipped automatically.
# `docs` is a list of strings containing the already split text.
results = embeddings.embed_documents(docs)

# Process results without interruption from empty embeddings
```
ElReyZero authored Sep 4, 2023
1 parent 8998060 commit 5dbae94
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions libs/langchain/langchain/embeddings/openai.py
@@ -87,8 +87,8 @@ async def wrapped_f(*args: Any, **kwargs: Any) -> Callable:


 # https://stackoverflow.com/questions/76469415/getting-embeddings-of-length-1-from-langchain-openaiembeddings
-def _check_response(response: dict) -> dict:
-    if any(len(d["embedding"]) == 1 for d in response["data"]):
+def _check_response(response: dict, skip_empty: bool = False) -> dict:
+    if any(len(d["embedding"]) == 1 for d in response["data"]) and not skip_empty:
         import openai

         raise openai.error.APIError("OpenAI API returned an empty embedding")
@@ -102,7 +102,7 @@ def embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) -> Any:
     @retry_decorator
     def _embed_with_retry(**kwargs: Any) -> Any:
         response = embeddings.client.create(**kwargs)
-        return _check_response(response)
+        return _check_response(response, skip_empty=embeddings.skip_empty)

     return _embed_with_retry(**kwargs)

@@ -113,7 +113,7 @@ async def async_embed_with_retry(embeddings: OpenAIEmbeddings, **kwargs: Any) ->
     @_async_retry_decorator(embeddings)
     async def _async_embed_with_retry(**kwargs: Any) -> Any:
         response = await embeddings.client.acreate(**kwargs)
-        return _check_response(response)
+        return _check_response(response, skip_empty=embeddings.skip_empty)

     return await _async_embed_with_retry(**kwargs)

@@ -196,6 +196,9 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     """Whether to show a progress bar when embedding."""
     model_kwargs: Dict[str, Any] = Field(default_factory=dict)
     """Holds any model parameters valid for `create` call not explicitly specified."""
+    skip_empty: bool = False
+    """Whether to skip empty strings when embedding or raise an error.
+    Defaults to not skipping."""

     class Config:
         """Configuration for this pydantic object."""
@@ -371,6 +374,8 @@ def _get_len_safe_embeddings(
         results: List[List[List[float]]] = [[] for _ in range(len(texts))]
         num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]
         for i in range(len(indices)):
+            if self.skip_empty and len(batched_embeddings[i]) == 1:
+                continue
             results[indices[i]].append(batched_embeddings[i])
             num_tokens_in_batch[indices[i]].append(len(tokens[i]))


1 comment on commit 5dbae94

@rabiaedayilmaz

After using this version, I get this error in my code: `AttributeError: 'OpenAIEmbeddings' object has no attribute 'skip_empty'`. Here is the code block where the error happens:

  def ask(self, user_prompt):
      content = self.prompt + "\n\n" + "Question: " + user_prompt + "\n\n"

      chain = RetrievalQAWithSourcesChain.from_chain_type(llm=self.llm_model,
                                                          retriever=self.store.as_retriever())
      gpt_response = chain(content)['answer']
      return self.control_response_has_source(gpt_response)
