Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add batching for querying in ElasticsearchDocumentStore and OpenSearchDocumentStore #5063

Merged
merged 9 commits into from
Jun 1, 2023
4 changes: 4 additions & 0 deletions haystack/document_stores/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
synonyms: Optional[List] = None,
synonym_type: str = "synonym",
use_system_proxy: bool = False,
batch_size: int = 10_000,
):
"""
A DocumentStore using Elasticsearch to store and query the documents for our search.
Expand Down Expand Up @@ -127,6 +128,8 @@ def __init__(
Synonym or Synonym_graph to handle synonyms, including multi-word synonyms correctly during the analysis process.
More info at https://www.elastic.co/guide/en/elasticsearch/reference/current/analysis-synonym-graph-tokenfilter.html
:param use_system_proxy: Whether to use system proxy.
:param batch_size: Number of Documents to index at once / Number of queries to execute at once. If you face
memory issues, decrease the batch_size.

"""
# Base constructor might need the client to be ready, create it first
Expand Down Expand Up @@ -167,6 +170,7 @@ def __init__(
skip_missing_embeddings=skip_missing_embeddings,
synonyms=synonyms,
synonym_type=synonym_type,
batch_size=batch_size,
)

# Let the base class trap the right exception from the elasticpy client
Expand Down
17 changes: 13 additions & 4 deletions haystack/document_stores/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
knn_engine: str = "nmslib",
knn_parameters: Optional[Dict] = None,
ivf_train_size: Optional[int] = None,
batch_size: int = 10_000,
):
"""
Document Store using OpenSearch (https://opensearch.org/). It is compatible with the Amazon OpenSearch Service.
Expand Down Expand Up @@ -165,6 +166,8 @@ def __init__(
index type and knn parameters). If `0`, training doesn't happen automatically but needs
to be triggered manually via the `train_index` method.
Default: `None`
:param batch_size: Number of Documents to index at once / Number of queries to execute at once. If you face
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The write_documents method of the OpenSearchDocumentStore also has a batch_size parameter with a default of 10_000. If we introduce a batch_size param in the init of the document store, we should also use self.batch_size or batch_size in write_documents and make the parameter batch_size: Optional[int] = None in the signature of the write_documents method of the OpenSearchDocumentStore.
Can we make the default 10_000 instead of 1_000 then to prevent a breaking change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed write_documents and changed the default back to 10_000. I initially changed it to 1000 because I found this on the Elasticsearch documentation, but this information is probably outdated as it is for a former version.

memory issues, decrease the batch_size.
"""
# These parameters aren't used by Opensearch at the moment but could be in the future, see
# https://github.com/opensearch-project/security/issues/1504. Let's not deprecate them for
Expand Down Expand Up @@ -243,6 +246,7 @@ def __init__(
skip_missing_embeddings=skip_missing_embeddings,
synonyms=synonyms,
synonym_type=synonym_type,
batch_size=batch_size,
)

# Let the base class catch the right error from the Opensearch client
Expand Down Expand Up @@ -321,7 +325,7 @@ def write_documents(
self,
documents: Union[List[dict], List[Document]],
index: Optional[str] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
duplicate_documents: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
):
Expand Down Expand Up @@ -358,6 +362,8 @@ def write_documents(
if index is None:
index = self.index

batch_size = batch_size or self.batch_size

if self.knn_engine == "faiss" and self.similarity == "cosine":
field_map = self._create_document_field_map()
documents = [Document.from_dict(d, field_map=field_map) if isinstance(d, dict) else d for d in documents]
Expand Down Expand Up @@ -529,6 +535,7 @@ def query_by_embedding_batch(
return_embedding: Optional[bool] = None,
headers: Optional[Dict[str, str]] = None,
scale_score: bool = True,
batch_size: Optional[int] = None,
) -> List[List[Document]]:
"""
Find the documents that are most similar to the provided `query_embs` by using a vector similarity metric.
Expand Down Expand Up @@ -605,17 +612,19 @@ def query_by_embedding_batch(
Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
:return:
Otherwise, raw similarity scores (e.g. cosine or dot_product) will be used.
:param batch_size: Number of query embeddings to process at once. If not specified, self.batch_size is used.
"""
if index is None:
index = self.index

batch_size = batch_size or self.batch_size

if self.index_type in ["ivf", "ivf_pq"] and not self._ivf_model_exists(index=index):
self._ivf_index_not_trained_error(index=index, headers=headers)

return super().query_by_embedding_batch(
query_embs, filters, top_k, index, return_embedding, headers, scale_score
query_embs, filters, top_k, index, return_embedding, headers, scale_score, batch_size
)

def query(
Expand Down
76 changes: 50 additions & 26 deletions haystack/document_stores/search_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
skip_missing_embeddings: bool = True,
synonyms: Optional[List] = None,
synonym_type: str = "synonym",
batch_size: int = 10_000,
):
super().__init__()

Expand Down Expand Up @@ -98,6 +99,7 @@ def __init__(
self.skip_missing_embeddings: bool = skip_missing_embeddings
self.duplicate_documents = duplicate_documents
self.refresh_type = refresh_type
self.batch_size = batch_size
if similarity in ["cosine", "dot_product", "l2"]:
self.similarity: str = similarity
else:
Expand Down Expand Up @@ -367,7 +369,7 @@ def write_documents(
self,
documents: Union[List[dict], List[Document]],
index: Optional[str] = None,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
duplicate_documents: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
):
Expand All @@ -390,6 +392,7 @@ def write_documents(
to what you have set for self.content_field and self.name_field.
:param index: search index where the documents should be indexed. If you don't specify it, self.index is used.
:param batch_size: Number of documents that are passed to the bulk function at each round.
If not specified, self.batch_size is used.
:param duplicate_documents: Handle duplicate documents based on parameter options.
Parameter options: ( 'skip','overwrite','fail')
skip: Ignore the duplicate documents
Expand All @@ -407,6 +410,9 @@ def write_documents(

if index is None:
index = self.index

batch_size = batch_size or self.batch_size

duplicate_documents = duplicate_documents or self.duplicate_documents
assert (
duplicate_documents in self.duplicate_documents_options
Expand Down Expand Up @@ -923,9 +929,10 @@ def query_batch(
headers: Optional[Dict[str, str]] = None,
all_terms_must_match: bool = False,
scale_score: bool = True,
batch_size: Optional[int] = None,
) -> List[List[Document]]:
"""
Scan through documents in DocumentStore and return a small number documents
Scan through documents in DocumentStore and return a small number of documents
that are most relevant to the provided queries as defined by keyword matching algorithms like BM25.

This method lets you find relevant documents for list of query strings (output: List of Lists of Documents).
Expand Down Expand Up @@ -1005,17 +1012,19 @@ def query_batch(
:param headers: Custom HTTP headers to pass to document store client if supported (e.g. {'Authorization': 'Basic YWRtaW46cm9vdA=='} for basic authentication)
:param all_terms_must_match: Whether all terms of the query must match the document.
If true all query terms must be present in a document in order to be retrieved (i.e the AND operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy AND fish AND restaurant").
Otherwise at least one query term must be present in a document in order to be retrieved (i.e the OR operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy OR fish OR restaurant").
Otherwise, at least one query term must be present in a document in order to be retrieved (i.e the OR operator is being used implicitly between query terms: "cozy fish restaurant" -> "cozy OR fish OR restaurant").
Defaults to False.
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
Otherwise, raw similarity scores (e.g. cosine or dot_product) will be used.
:param batch_size: Number of queries that are processed at once. If not specified, self.batch_size is used.
"""

if index is None:
index = self.index
if headers is None:
headers = {}
batch_size = batch_size or self.batch_size

if isinstance(filters, list):
if len(filters) != len(queries):
Expand All @@ -1027,6 +1036,7 @@ def query_batch(
filters = [filters] * len(queries)

body = []
all_documents = []
for query, cur_filters in zip(queries, filters):
cur_query_body = self._construct_query_body(
query=query,
Expand All @@ -1038,17 +1048,27 @@ def query_batch(
body.append(headers)
body.append(cur_query_body)

responses = self.client.msearch(index=index, body=body)
if len(body) == 2 * batch_size:
cur_documents = self._execute_msearch(index=index, body=body, scale_score=scale_score)
all_documents.extend(cur_documents)
body = []

all_documents = []
cur_documents = []
for response in responses["responses"]:
cur_result = response["hits"]["hits"]
cur_documents = [self._convert_es_hit_to_document(hit, scale_score=scale_score) for hit in cur_result]
all_documents.append(cur_documents)
if len(body) > 0:
cur_documents = self._execute_msearch(index=index, body=body, scale_score=scale_score)
all_documents.extend(cur_documents)

return all_documents

def _execute_msearch(self, index: str, body: List[Dict[str, Any]], scale_score: bool) -> List[List[Document]]:
    """
    Execute a single multi-search (`msearch`) request and convert the raw hits
    of every sub-response into Documents.

    :param index: Search index to run the multi-search against.
    :param body: Alternating header/query dicts, as required by the msearch API.
    :param scale_score: Whether to scale similarity scores to the [0, 1] range
                        when converting hits to Documents.
    :return: One list of Documents per sub-query, in request order.
    """
    raw_responses = self.client.msearch(index=index, body=body)
    return [
        [self._convert_es_hit_to_document(hit, scale_score=scale_score) for hit in resp["hits"]["hits"]]
        for resp in raw_responses["responses"]
    ]

def _construct_query_body(
self,
query: Optional[str],
Expand Down Expand Up @@ -1188,6 +1208,7 @@ def query_by_embedding_batch(
return_embedding: Optional[bool] = None,
headers: Optional[Dict[str, str]] = None,
scale_score: bool = True,
batch_size: Optional[int] = None,
) -> List[List[Document]]:
"""
Find the documents that are most similar to the provided `query_embs` by using a vector similarity metric.
Expand Down Expand Up @@ -1264,8 +1285,8 @@ def query_by_embedding_batch(
Check out https://www.elastic.co/guide/en/elasticsearch/reference/current/http-clients.html for more information.
:param scale_score: Whether to scale the similarity score to the unit interval (range of [0,1]).
If true (default) similarity scores (e.g. cosine or dot_product) which naturally have a different value range will be scaled to a range of [0,1], where 1 means extremely relevant.
Otherwise raw similarity scores (e.g. cosine or dot_product) will be used.
:return:
Otherwise, raw similarity scores (e.g. cosine or dot_product) will be used.
:param batch_size: Number of query embeddings to process at once. If not specified, self.batch_size is used.
"""
if index is None:
index = self.index
Expand All @@ -1276,6 +1297,8 @@ def query_by_embedding_batch(
if headers is None:
headers = {}

batch_size = batch_size or self.batch_size

if not self.embedding_field:
raise DocumentStoreError("Please set a valid `embedding_field` for OpenSearchDocumentStore")

Expand All @@ -1289,25 +1312,24 @@ def query_by_embedding_batch(
filters = [filters] * len(query_embs) if filters is not None else [{}] * len(query_embs)

body = []
for query_emb, cur_filters in zip(query_embs, filters):
all_documents = []
for query_emb, cur_filters in tqdm(zip(query_embs, filters)):
cur_query_body = self._construct_dense_query_body(
query_emb=query_emb, filters=cur_filters, top_k=top_k, return_embedding=return_embedding
)
body.append(headers)
body.append(cur_query_body)

logger.debug("Retriever query: %s", body)
responses = self.client.msearch(index=index, body=body)
if len(body) >= batch_size * 2:
logger.debug("Retriever query: %s", body)
cur_documents = self._execute_msearch(index=index, body=body, scale_score=scale_score)
all_documents.extend(cur_documents)
body = []

all_documents = []
cur_documents = []
for response in responses["responses"]:
cur_result = response["hits"]["hits"]
cur_documents = [
self._convert_es_hit_to_document(hit, adapt_score_for_embedding=True, scale_score=scale_score)
for hit in cur_result
]
all_documents.append(cur_documents)
if len(body) > 0:
logger.debug("Retriever query: %s", body)
cur_documents = self._execute_msearch(index=index, body=body, scale_score=scale_score)
all_documents.extend(cur_documents)

return all_documents

Expand All @@ -1323,7 +1345,7 @@ def update_embeddings(
index: Optional[str] = None,
filters: Optional[FilterType] = None,
update_existing_embeddings: bool = True,
batch_size: int = 10_000,
batch_size: Optional[int] = None,
headers: Optional[Dict[str, str]] = None,
):
"""
Expand Down Expand Up @@ -1370,6 +1392,8 @@ def update_embeddings(
if index is None:
index = self.index

batch_size = batch_size or self.batch_size

if self.refresh_type == "false":
self.client.indices.refresh(index=index, headers=headers)

Expand Down
9 changes: 8 additions & 1 deletion test/document_stores/test_elasticsearch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import os
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
Expand Down Expand Up @@ -344,3 +344,10 @@ def test_get_document_by_id_excluded_meta_data_has_no_influence(self, mocked_doc
# assert the resulting body is not affected by the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.search.call_args
assert kwargs["_source"] == {"excludes": ["embedding"]}

@pytest.mark.unit
def test_write_documents_req_for_each_batch(self, mocked_document_store, documents):
    """One bulk request must be issued per batch when batch_size comes from the store."""
    # With batch_size=2, the documents fixture is expected to split into 5 batches.
    mocked_document_store.batch_size = 2
    bulk_target = "haystack.document_stores.elasticsearch.bulk"
    with patch(bulk_target) as bulk_mock:
        mocked_document_store.write_documents(documents)
    assert bulk_mock.call_count == 5
7 changes: 7 additions & 0 deletions test/document_stores/test_opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,3 +1291,10 @@ def test_get_document_by_id_excluded_meta_data_has_no_influence(self, mocked_doc
# assert the resulting body is not affected by the `excluded_meta_data` value
_, kwargs = mocked_document_store.client.search.call_args
assert kwargs["body"]["_source"] == {"excludes": ["embedding"]}

@pytest.mark.unit
def test_write_documents_req_for_each_batch(self, mocked_document_store, documents):
    """One bulk request must be issued per batch when batch_size comes from the store."""
    # With batch_size=2, the documents fixture is expected to split into 5 batches.
    mocked_document_store.batch_size = 2
    bulk_target = "haystack.document_stores.opensearch.bulk"
    with patch(bulk_target) as bulk_mock:
        mocked_document_store.write_documents(documents)
    assert bulk_mock.call_count == 5
14 changes: 14 additions & 0 deletions test/document_stores/test_search_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from unittest.mock import MagicMock

import numpy as np
import pytest
from haystack.document_stores.search_engine import SearchEngineDocumentStore, prepare_hosts

Expand Down Expand Up @@ -167,6 +169,18 @@ def test_get_all_labels_legacy_document_id(self, mocked_document_store, mocked_g
labels = mocked_document_store.get_all_labels()
assert labels[0].answer.document_ids == ["fc18c987a8312e72a47fb1524f230bb0"]

@pytest.mark.unit
def test_query_batch_req_for_each_batch(self, mocked_document_store):
    """query_batch must split queries into batches: 3 queries / batch_size 2 -> 2 msearch calls."""
    mocked_document_store.batch_size = 2
    queries = [self.query] * 3
    mocked_document_store.query_batch(queries)
    assert mocked_document_store.client.msearch.call_count == 2

@pytest.mark.unit
def test_query_by_embedding_batch_req_for_each_batch(self, mocked_document_store):
    """query_by_embedding_batch must batch embeddings: 3 embeddings / batch_size 2 -> 2 msearch calls."""
    mocked_document_store.batch_size = 2
    embeddings = [np.array([1, 2, 3])] * 3
    mocked_document_store.query_by_embedding_batch(embeddings)
    assert mocked_document_store.client.msearch.call_count == 2


@pytest.mark.document_store
class TestSearchEngineDocumentStore:
Expand Down