Skip to content

Commit

Permalink
fix: avoid conflicts with opensearch / elasticsearch magic attributes…
Browse files Browse the repository at this point in the history
… during bulk requests (#5113)

* use _source on opensearch bulk requests

* fix label bulk requests

* add tests

* fix test

* apply feedback

---------

Co-authored-by: Vladimir Blagojevic <dovlex@gmail.com>
  • Loading branch information
tstadel and vblagoje authored Jul 7, 2023
1 parent 13bed30 commit 9acb275
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 35 deletions.
77 changes: 42 additions & 35 deletions haystack/document_stores/search_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,31 +415,14 @@ def write_documents(
)
documents_to_index = []
for doc in document_objects:
_doc = {
index_message: Dict[str, Any] = {
"_op_type": "index" if duplicate_documents == "overwrite" else "create",
"_index": index,
**doc.to_dict(field_map=self._create_document_field_map()),
} # type: Dict[str, Any]

# cast embedding type as ES cannot deal with np.array
if _doc[self.embedding_field] is not None:
if type(_doc[self.embedding_field]) == np.ndarray:
_doc[self.embedding_field] = _doc[self.embedding_field].tolist()

# rename id for elastic
_doc["_id"] = str(_doc.pop("id"))

# don't index query score and empty fields
_ = _doc.pop("score", None)
_doc = {k: v for k, v in _doc.items() if v is not None}

# In order to have a flat structure in elastic + similar behaviour to the other DocumentStores,
# we "unnest" all value within "meta"
if "meta" in _doc.keys():
for k, v in _doc["meta"].items():
_doc[k] = v
_doc.pop("meta")
documents_to_index.append(_doc)
"_id": str(doc.id),
# use _source explicitly to avoid conflicts with automatic field detection by ES/OS clients (e.g. "version")
"_source": self._get_source(doc, field_map),
}
documents_to_index.append(index_message)

# Pass batch_size number of documents to bulk
if len(documents_to_index) % batch_size == 0:
Expand All @@ -449,6 +432,27 @@ def write_documents(
if documents_to_index:
self._bulk(documents_to_index, refresh=self.refresh_type, headers=headers)

def _get_source(self, doc: Document, field_map: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build the payload stored under "_source" in an ES/OS bulk index message.

    :param doc: The Document to serialize via ``doc.to_dict``.
    :param field_map: Field mapping that is forwarded unchanged to ``doc.to_dict``.
    :return: A flat dictionary without "id" and "score", without top-level fields whose value is None,
             and with every entry of "meta" merged into the top level.
    """
    raw_fields: Dict[str, Any] = doc.to_dict(field_map=field_map)

    # ES/OS cannot deal with np.array, so a numpy embedding is sent as a plain list
    embedding = raw_fields.get(self.embedding_field)
    if isinstance(embedding, np.ndarray):
        raw_fields[self.embedding_field] = embedding.tolist()

    # the id already travels as "_id" in the index message and the query score must not be indexed
    for transient_key in ("id", "score"):
        raw_fields.pop(transient_key, None)

    # empty (None) fields are not indexed
    payload: Dict[str, Any] = {}
    for key, value in raw_fields.items():
        if value is not None:
            payload[key] = value

    # To keep a flat structure in ES/OS (and behave like the other DocumentStores),
    # all values within "meta" are "unnested" into the top level
    meta = payload.pop("meta", None)
    if meta:
        payload.update(meta)
    return payload

def write_labels(
self,
labels: Union[List[Label], List[dict]],
Expand Down Expand Up @@ -481,24 +485,27 @@ def write_labels(
labels_to_index = []
for label in label_list:
# create timestamps if not available yet
if not label.created_at: # type: ignore
label.created_at = time.strftime("%Y-%m-%d %H:%M:%S") # type: ignore
if not label.updated_at: # type: ignore
label.updated_at = label.created_at # type: ignore
if not label.created_at:
label.created_at = time.strftime("%Y-%m-%d %H:%M:%S")
if not label.updated_at:
label.updated_at = label.created_at

_label = {
index_message: Dict[str, Any] = {
"_op_type": "index"
if self.duplicate_documents == "overwrite" or label.id in duplicate_ids
else "create", # type: ignore
else "create",
"_index": index,
**label.to_dict(), # type: ignore
} # type: Dict[str, Any]
}

_source = label.to_dict()

# rename id for elastic
if label.id is not None: # type: ignore
_label["_id"] = str(_label.pop("id")) # type: ignore
# set id for elastic
if _source.get("id") is not None:
index_message["_id"] = str(_source.pop("id"))

labels_to_index.append(_label)
# use _source explicitly to avoid conflicts with automatic field detection by ES/OS clients (e.g. "version")
index_message["_source"] = _source
labels_to_index.append(index_message)

# Pass batch_size number of labels to bulk
if len(labels_to_index) % batch_size == 0:
Expand Down
24 changes: 24 additions & 0 deletions test/document_stores/test_search_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,30 @@ def test_query_by_embedding_batch_req_for_each_batch(self, mocked_document_store
mocked_document_store.query_by_embedding_batch([np.array([1, 2, 3])] * 3)
assert mocked_document_store.client.msearch.call_count == 2

@pytest.mark.integration
def test_document_with_version_metadata(self, ds: SearchEngineDocumentStore):
    """Write a document whose meta contains a "version" key and check the value can be read back."""
    version_doc = {"content": "test", "meta": {"version": "2023.1"}}
    ds.write_documents([version_doc])

    first_document = ds.get_all_documents()[0]
    assert first_document.meta["version"] == "2023.1"

@pytest.mark.integration
def test_label_with_version_metadata(self, ds: SearchEngineDocumentStore):
    """Write a label whose meta contains a "version" key and check the value can be read back."""
    version_label = {
        "query": "test",
        "document": {"content": "test"},
        "is_correct_answer": True,
        "is_correct_document": True,
        "origin": "gold-label",
        "meta": {"version": "2023.1"},
        "answer": None,
    }
    ds.write_labels([version_label])

    first_label = ds.get_all_labels()[0]
    assert first_label.meta["version"] == "2023.1"


@pytest.mark.document_store
class TestSearchEngineDocumentStore:
Expand Down

0 comments on commit 9acb275

Please sign in to comment.