From da1f245a84339068dfadbe41e1e4888e10f35b24 Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Tue, 6 Jun 2023 17:06:10 +0200 Subject: [PATCH] feat: Add batch_size parameter and cast timeout_config value to tuple for `WeaviateDocumentStore` (#5079) * Add batch_size parameter and cast timeout_config to tuple * Add unit test * Remove debug tqdm * Remove debug tqdm introduced in #5063 --- haystack/document_stores/search_engine.py | 2 +- haystack/document_stores/weaviate.py | 32 +++++++++++++++++------ test/document_stores/test_weaviate.py | 6 +++++ 3 files changed, 31 insertions(+), 9 deletions(-) diff --git a/haystack/document_stores/search_engine.py b/haystack/document_stores/search_engine.py index dcea7ed728..5e6805e627 100644 --- a/haystack/document_stores/search_engine.py +++ b/haystack/document_stores/search_engine.py @@ -1310,7 +1310,7 @@ def query_by_embedding_batch( body = [] all_documents = [] - for query_emb, cur_filters in tqdm(zip(query_embs, filters)): + for query_emb, cur_filters in zip(query_embs, filters): cur_query_body = self._construct_dense_query_body( query_emb=query_emb, filters=cur_filters, top_k=top_k, return_embedding=return_embedding ) diff --git a/haystack/document_stores/weaviate.py b/haystack/document_stores/weaviate.py index 048f9ea348..477e84d072 100644 --- a/haystack/document_stores/weaviate.py +++ b/haystack/document_stores/weaviate.py @@ -97,6 +97,7 @@ def __init__( duplicate_documents: str = "overwrite", recreate_index: bool = False, replication_factor: int = 1, + batch_size: int = 10_000, ): """ :param host: Weaviate server connection URL for storing and processing documents and vectors. @@ -138,6 +139,7 @@ def __init__( lost if you choose to recreate the index. :param replication_factor: Sets the Weaviate Class's replication factor in Weaviate at the time of Class creation. See also [Weaviate documentation](https://weaviate.io/developers/weaviate/current/configuration/replication.html). + :param batch_size: The number of documents to index at once. """ super().__init__() @@ -146,6 +148,9 @@ def __init__( secret = self._get_auth_secret( username, password, client_secret, access_token, expires_in, refresh_token, scope ) + # Timeout config can only be defined as a list in YAML, but Weaviate expects a tuple + if isinstance(timeout_config, list): + timeout_config = tuple(timeout_config) self.weaviate_client = client.Client( url=weaviate_url, auth_client_secret=secret, @@ -186,6 +191,7 @@ def __init__( self.progress_bar = progress_bar self.duplicate_documents = duplicate_documents self.replication_factor = replication_factor + self.batch_size = batch_size self._create_schema_and_index(self.index, recreate_index=recreate_index) self.uuid_format_warning_raised = False @@ -400,7 +406,7 @@ def get_documents_by_id( self, ids: List[str], index: Optional[str] = None, - batch_size: int = 10_000, + batch_size: Optional[int] = None, headers: Optional[Dict[str, str]] = None, ) -> List[Document]: """ @@ -410,6 +416,7 @@ def get_documents_by_id( raise NotImplementedError("WeaviateDocumentStore does not support headers.") index = self._sanitize_index_name(index) or self.index + batch_size = batch_size or self.batch_size # We retrieve the JSON properties from the schema and convert them back to the Python dicts json_properties = self._get_json_properties(index=index) documents = [] @@ -557,7 +564,7 @@ def write_documents( self, documents: Union[List[dict], List[Document]], index: Optional[str] = None, - batch_size: int = 10_000, + batch_size: Optional[int] = None, duplicate_documents: Optional[str] = None, headers: Optional[Dict[str, str]] = None, ): @@ -567,6 +574,7 @@ def write_documents( :param documents: List of `Dicts` or List of `Documents`. A dummy embedding vector for each document is automatically generated if it is not provided. The document id needs to be in uuid format. Otherwise a correctly formatted uuid will be automatically generated based on the provided id. :param index: index name for storing the docs and metadata :param batch_size: When working with large number of documents, batching can help reduce memory footprint. + If no batch_size is provided, self.batch_size is used. :param duplicate_documents: Handle duplicates document based on parameter options. Parameter options : ( 'skip','overwrite','fail') skip: Ignore the duplicates documents @@ -580,6 +588,7 @@ def write_documents( raise NotImplementedError("WeaviateDocumentStore does not support headers.") index = self._sanitize_index_name(index) or self.index + batch_size = batch_size or self.batch_size self._create_schema_and_index(index, recreate_index=False) field_map = self._create_document_field_map() @@ -764,7 +773,7 @@ def get_all_documents( index: Optional[str] = None, filters: Optional[FilterType] = None, return_embedding: Optional[bool] = None, - batch_size: int = 10_000, + batch_size: Optional[int] = None, headers: Optional[Dict[str, str]] = None, ) -> List[Document]: """ @@ -817,11 +826,13 @@ def get_all_documents( ``` :param return_embedding: Whether to return the document embeddings. :param batch_size: When working with large number of documents, batching can help reduce memory footprint. + If no batch_size is provided, self.batch_size is used. """ if headers: raise NotImplementedError("WeaviateDocumentStore does not support headers.") index = self._sanitize_index_name(index) or self.index + batch_size = batch_size or self.batch_size result = self.get_all_documents_generator( index=index, filters=filters, return_embedding=return_embedding, batch_size=batch_size ) @@ -832,13 +843,14 @@ def _get_all_documents_in_index( self, index: Optional[str], filters: Optional[FilterType] = None, - batch_size: int = 10_000, + batch_size: Optional[int] = None, only_documents_without_embedding: bool = False, ) -> Generator[dict, None, None]: """ Return all documents in a specific index in the document store """ index = self._sanitize_index_name(index) or self.index + batch_size = batch_size or self.batch_size # Build the properties to retrieve from Weaviate properties = self._get_current_properties(index) @@ -907,7 +919,7 @@ def get_all_documents_generator( index: Optional[str] = None, filters: Optional[FilterType] = None, return_embedding: Optional[bool] = None, - batch_size: int = 10_000, + batch_size: Optional[int] = None, headers: Optional[Dict[str, str]] = None, ) -> Generator[Document, None, None]: """ @@ -962,11 +974,13 @@ def get_all_documents_generator( ``` :param return_embedding: Whether to return the document embeddings. :param batch_size: When working with large number of documents, batching can help reduce memory footprint. + If no batch_size is provided, self.batch_size is used. """ if headers: raise NotImplementedError("WeaviateDocumentStore does not support headers.") index = self._sanitize_index_name(index) or self.index + batch_size = batch_size or self.batch_size if return_embedding is None: return_embedding = self.return_embedding @@ -1418,11 +1432,11 @@ def update_embeddings( index: Optional[str] = None, filters: Optional[FilterType] = None, update_existing_embeddings: bool = True, - batch_size: int = 10_000, + batch_size: Optional[int] = None, ): """ - Updates the embeddings in the the document store using the encoding model specified in the retriever. - This can be useful if want to change the embeddings for your documents (e.g. after changing the retriever config). + Updates the embeddings in the document store using the encoding model specified in the retriever. + This can be useful if you want to change the embeddings for your documents (e.g. after changing the retriever config). :param retriever: Retriever to use to update the embeddings. :param index: Index name to update @@ -1456,9 +1470,11 @@ def update_embeddings( } ``` :param batch_size: When working with large number of documents, batching can help reduce memory footprint. + If no batch_size is specified, self.batch_size is used. :return: None """ index = self._sanitize_index_name(index) or self.index + batch_size = batch_size or self.batch_size if not self.embedding_field: raise RuntimeError("Specify the arg `embedding_field` when initializing WeaviateDocumentStore()") diff --git a/test/document_stores/test_weaviate.py b/test/document_stores/test_weaviate.py index 74e7993b5b..6b0a967388 100644 --- a/test/document_stores/test_weaviate.py +++ b/test/document_stores/test_weaviate.py @@ -449,3 +449,9 @@ def test_list_of_dict_metadata(self, mocked_ds): ) retrieved_docs = mocked_ds.get_all_documents() assert retrieved_docs[0].meta["list_dict_field"] == [{"key": "value"}, {"key": "value"}] + + @pytest.mark.unit + def test_write_documents_req_for_each_batch(self, mocked_ds, documents): + mocked_ds.batch_size = 2 + mocked_ds.write_documents(documents) + assert mocked_ds.weaviate_client.batch.create_objects.call_count == 5