Skip to content

Commit

Permalink
fix: make JoinDocuments correctly handle duplicate documents w null…
Browse files Browse the repository at this point in the history
… scores (#6261)

* fix error with null values

* release note

* simplify
  • Loading branch information
anakin87 authored Nov 9, 2023
1 parent 676da68 commit 2b3c77e
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
6 changes: 3 additions & 3 deletions haystack/nodes/other/join_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class JoinDocuments(JoinNode):
A node to join documents outputted by multiple retriever nodes.
The node allows multiple join modes:
* concatenate: combine the documents from multiple nodes. Any duplicate documents are discarded.
The score is only determined by the last node that outputs the document.
* concatenate: combine the documents from multiple nodes.
In case of duplicate documents, the one with the highest score is kept.
* merge: merge scores of documents from multiple nodes. Optionally, each input score can be given a different
`weight` & a `top_k` limit can be set. This mode can also be used for "reranking" retrieved documents.
* reciprocal_rank_fusion: combines the documents based on their rank in multiple nodes.
Expand Down Expand Up @@ -130,7 +130,7 @@ def _concatenate_results(self, results, document_map):
for doc in result:
if doc.id == idx:
tmp.append(doc)
item_best_score = max(tmp, key=lambda x: x.score)
item_best_score = max(tmp, key=lambda x: x.score if x.score is not None else -inf)
scores_map.update({idx: item_best_score.score})
return scores_map

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
fixes:
- |
When using `JoinDocuments` with `join_mode=concatenate` (default) and
passing duplicate documents, including some with a null score, this
node raised an exception.
Now this case is handled correctly and the documents are joined as expected.
35 changes: 35 additions & 0 deletions test/nodes/test_join_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,38 @@ def test_joindocuments_concatenate_keep_only_highest_ranking_duplicate():
result, _ = join_docs.run(inputs)
assert len(result["documents"]) == 2
assert result["documents"] == expected_outputs["documents"]


@pytest.mark.unit
def test_joindocuments_concatenate_duplicate_docs_null_score():
"""
Test that the concatenate method correctly handles duplicate documents,
when one has a null score.
"""
inputs = [
{
"documents": [
Document(content="text document 1", content_type="text", score=0.2),
Document(content="text document 2", content_type="text", score=0.3),
Document(content="text document 3", content_type="text", score=None),
]
},
{
"documents": [
Document(content="text document 2", content_type="text", score=0.7),
Document(content="text document 1", content_type="text", score=None),
]
},
]
expected_outputs = {
"documents": [
Document(content="text document 2", content_type="text", score=0.7),
Document(content="text document 1", content_type="text", score=0.2),
Document(content="text document 3", content_type="text", score=None),
]
}

join_docs = JoinDocuments(join_mode="concatenate")
result, _ = join_docs.run(inputs)
assert len(result["documents"]) == 3
assert result["documents"] == expected_outputs["documents"]

0 comments on commit 2b3c77e

Please sign in to comment.