From ac44e58d0a94b9f571f0ca41e004af31dcef3b1b Mon Sep 17 00:00:00 2001 From: Daniel Buades Marcos Date: Sat, 7 Dec 2024 23:02:15 +0100 Subject: [PATCH] fix(bm25s): search implementation (#1566) fix: bm25s implementation --- mteb/evaluation/evaluators/RetrievalEvaluator.py | 6 +++--- mteb/models/bm25.py | 12 +++++++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/mteb/evaluation/evaluators/RetrievalEvaluator.py b/mteb/evaluation/evaluators/RetrievalEvaluator.py index 20a29b3ad..8ec28c14e 100644 --- a/mteb/evaluation/evaluators/RetrievalEvaluator.py +++ b/mteb/evaluation/evaluators/RetrievalEvaluator.py @@ -477,10 +477,10 @@ def __call__( if self.is_cross_encoder: return self.retriever.search_cross_encoder(corpus, queries, self.top_k) elif ( - hasattr(self.retriever.model, "mteb_model_meta") - and self.retriever.model.mteb_model_meta.name == "bm25s" + hasattr(self.retriever.model.model, "mteb_model_meta") + and self.retriever.model.model.mteb_model_meta.name == "bm25s" ): - return self.retriever.model.search( + return self.retriever.model.model.search( corpus, queries, self.top_k, diff --git a/mteb/models/bm25.py b/mteb/models/bm25.py index 1848b9e4e..fdc86fb21 100644 --- a/mteb/models/bm25.py +++ b/mteb/models/bm25.py @@ -58,7 +58,17 @@ def search( ) -> dict[str, dict[str, float]]: logger.info("Encoding Corpus...") corpus_ids = list(corpus.keys()) - corpus_with_ids = [{"doc_id": cid, **corpus[cid]} for cid in corpus_ids] + corpus_with_ids = [ + { + "doc_id": cid, + **( + {"text": corpus[cid]} + if isinstance(corpus[cid], str) + else corpus[cid] + ), + } + for cid in corpus_ids + ] corpus_texts = [ "\n".join([doc.get("title", ""), doc["text"]])