Skip to content

Commit

Permalink
feat: Add batch_size parameter and cast timeout_config value to tuple…
Browse files Browse the repository at this point in the history
… for `WeaviateDocumentStore` (#5079)

* Add batch_size parameter and cast timeout_config to tuple

* Add unit test

* Remove debug tqdm

* Remove debug tqdm introduced in #5063
  • Loading branch information
bogdankostic authored Jun 6, 2023
1 parent 1777b22 commit da1f245
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 9 deletions.
2 changes: 1 addition & 1 deletion haystack/document_stores/search_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,7 @@ def query_by_embedding_batch(

body = []
all_documents = []
for query_emb, cur_filters in tqdm(zip(query_embs, filters)):
for query_emb, cur_filters in zip(query_embs, filters):
cur_query_body = self._construct_dense_query_body(
query_emb=query_emb, filters=cur_filters, top_k=top_k, return_embedding=return_embedding
)
Expand Down
32 changes: 24 additions & 8 deletions haystack/document_stores/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(
duplicate_documents: str = "overwrite",
recreate_index: bool = False,
replication_factor: int = 1,
batch_size: int = 10_000,
):
"""
:param host: Weaviate server connection URL for storing and processing documents and vectors.
Expand Down Expand Up @@ -138,6 +139,7 @@ def __init__(
lost if you choose to recreate the index.
:param replication_factor: Sets the Weaviate Class's replication factor in Weaviate at the time of Class creation.
See also [Weaviate documentation](https://weaviate.io/developers/weaviate/current/configuration/replication.html).
:param batch_size: The number of documents to index at once.
"""
super().__init__()

Expand All @@ -146,6 +148,9 @@ def __init__(
secret = self._get_auth_secret(
username, password, client_secret, access_token, expires_in, refresh_token, scope
)
# Timeout config can only be defined as a list in YAML, but Weaviate expects a tuple
if isinstance(timeout_config, list):
timeout_config = tuple(timeout_config)
self.weaviate_client = client.Client(
url=weaviate_url,
auth_client_secret=secret,
Expand Down Expand Up @@ -186,6 +191,7 @@ def __init__(
self.progress_bar = progress_bar
self.duplicate_documents = duplicate_documents
self.replication_factor = replication_factor
self.batch_size = batch_size

self._create_schema_and_index(self.index, recreate_index=recreate_index)
self.uuid_format_warning_raised = False
Expand Down Expand Up @@ -400,7 +406,7 @@ def get_documents_by_id(
self,
ids: List[str],
index: Optional[str] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
) -> List[Document]:
"""
Expand All @@ -410,6 +416,7 @@ def get_documents_by_id(
raise NotImplementedError("WeaviateDocumentStore does not support headers.")

index = self._sanitize_index_name(index) or self.index
batch_size = batch_size or self.batch_size
# We retrieve the JSON properties from the schema and convert them back to the Python dicts
json_properties = self._get_json_properties(index=index)
documents = []
Expand Down Expand Up @@ -557,7 +564,7 @@ def write_documents(
self,
documents: Union[List[dict], List[Document]],
index: Optional[str] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
duplicate_documents: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
):
Expand All @@ -567,6 +574,7 @@ def write_documents(
:param documents: List of `Dicts` or List of `Documents`. A dummy embedding vector for each document is automatically generated if it is not provided. The document id needs to be in uuid format. Otherwise a correctly formatted uuid will be automatically generated based on the provided id.
:param index: index name for storing the docs and metadata
:param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
If no batch_size is provided, self.batch_size is used.
:param duplicate_documents: Handle duplicate documents based on parameter options.
Parameter options : ( 'skip','overwrite','fail')
skip: Ignore the duplicate documents
Expand All @@ -580,6 +588,7 @@ def write_documents(
raise NotImplementedError("WeaviateDocumentStore does not support headers.")

index = self._sanitize_index_name(index) or self.index
batch_size = batch_size or self.batch_size
self._create_schema_and_index(index, recreate_index=False)
field_map = self._create_document_field_map()

Expand Down Expand Up @@ -764,7 +773,7 @@ def get_all_documents(
index: Optional[str] = None,
filters: Optional[FilterType] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
) -> List[Document]:
"""
Expand Down Expand Up @@ -817,11 +826,13 @@ def get_all_documents(
```
:param return_embedding: Whether to return the document embeddings.
:param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
If no batch_size is provided, self.batch_size is used.
"""
if headers:
raise NotImplementedError("WeaviateDocumentStore does not support headers.")

index = self._sanitize_index_name(index) or self.index
batch_size = batch_size or self.batch_size
result = self.get_all_documents_generator(
index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size
)
Expand All @@ -832,13 +843,14 @@ def _get_all_documents_in_index(
self,
index: Optional[str],
filters: Optional[FilterType] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
only_documents_without_embedding: bool = False,
) -> Generator[dict, None, None]:
"""
Return all documents in a specific index in the document store
"""
index = self._sanitize_index_name(index) or self.index
batch_size = batch_size or self.batch_size

# Build the properties to retrieve from Weaviate
properties = self._get_current_properties(index)
Expand Down Expand Up @@ -907,7 +919,7 @@ def get_all_documents_generator(
index: Optional[str] = None,
filters: Optional[FilterType] = None,
return_embedding: Optional[bool] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
) -> Generator[Document, None, None]:
"""
Expand Down Expand Up @@ -962,11 +974,13 @@ def get_all_documents_generator(
```
:param return_embedding: Whether to return the document embeddings.
:param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
If no batch_size is provided, self.batch_size is used.
"""
if headers:
raise NotImplementedError("WeaviateDocumentStore does not support headers.")

index = self._sanitize_index_name(index) or self.index
batch_size = batch_size or self.batch_size

if return_embedding is None:
return_embedding = self.return_embedding
Expand Down Expand Up @@ -1418,11 +1432,11 @@ def update_embeddings(
index: Optional[str] = None,
filters: Optional[FilterType] = None,
update_existing_embeddings: bool = True,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
):
"""
Updates the embeddings in the the document store using the encoding model specified in the retriever.
This can be useful if want to change the embeddings for your documents (e.g. after changing the retriever config).
Updates the embeddings in the document store using the encoding model specified in the retriever.
This can be useful if you want to change the embeddings for your documents (e.g. after changing the retriever config).
:param retriever: Retriever to use to update the embeddings.
:param index: Index name to update
Expand Down Expand Up @@ -1456,9 +1470,11 @@ def update_embeddings(
}
```
:param batch_size: When working with a large number of documents, batching can help reduce the memory footprint.
If no batch_size is specified, self.batch_size is used.
:return: None
"""
index = self._sanitize_index_name(index) or self.index
batch_size = batch_size or self.batch_size

if not self.embedding_field:
raise RuntimeError("Specify the arg `embedding_field` when initializing WeaviateDocumentStore()")
Expand Down
6 changes: 6 additions & 0 deletions test/document_stores/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,9 @@ def test_list_of_dict_metadata(self, mocked_ds):
)
retrieved_docs = mocked_ds.get_all_documents()
assert retrieved_docs[0].meta["list_dict_field"] == [{"key": "value"}, {"key": "value"}]

@pytest.mark.unit
def test_write_documents_req_for_each_batch(self, mocked_ds, documents):
mocked_ds.batch_size = 2
mocked_ds.write_documents(documents)
assert mocked_ds.weaviate_client.batch.create_objects.call_count == 5

0 comments on commit da1f245

Please sign in to comment.