Simplify retrieval #233

Merged · 13 commits · Mar 6, 2024
Changes from all commits
1 change: 0 additions & 1 deletion .github/workflows/python-package.yml
@@ -30,7 +30,6 @@ jobs:
python -m pip install flake8 pytest
python -m pip install .
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
if [ -f requirements.txt ]; then pip install beir; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
3 changes: 0 additions & 3 deletions .gitignore
@@ -5,9 +5,6 @@
# Result folder
results/

# BeIR datasets folder
datasets/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
62 changes: 28 additions & 34 deletions README.md

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions mteb/abstasks/AbsTask.py
@@ -24,8 +24,7 @@ def load_data(self, **kwargs):
"""
Load dataset from HuggingFace hub
"""
if self.data_loaded:
return
if self.data_loaded: return

# TODO: add split argument
self.dataset = datasets.load_dataset(
216 changes: 127 additions & 89 deletions mteb/abstasks/AbsTaskRetrieval.py
@@ -1,17 +1,126 @@
import logging
import json
import os
from collections import defaultdict
from time import time
from typing import Dict, List
from typing import Dict, Tuple

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, WordEmbeddings
from datasets import load_dataset, Value, Features

from ..evaluation.evaluators import RetrievalEvaluator
from .AbsTask import AbsTask

logger = logging.getLogger(__name__)

DRES_METHODS = ["encode_queries", "encode_corpus"]
# Adapted from https://github.com/beir-cellar/beir/blob/f062f038c4bfd19a8ca942a9910b1e0d218759d4/beir/datasets/data_loader_hf.py#L10
class HFDataLoader:
def __init__(self, hf_repo: str = None, hf_repo_qrels: str = None, data_folder: str = None, prefix: str = None, corpus_file: str = "corpus.jsonl", query_file: str = "queries.jsonl",
qrels_folder: str = "qrels", qrels_file: str = "", streaming: bool = False, keep_in_memory: bool = False):
self.corpus = {}
self.queries = {}
self.qrels = {}
self.hf_repo = hf_repo
if hf_repo:
# By default, fetch qrels from the same repo rather than from a second "-qrels" repo as in the original loader
self.hf_repo_qrels = hf_repo_qrels if hf_repo_qrels else hf_repo
else:
# data folder would contain these files:
# (1) fiqa/corpus.jsonl (format: jsonlines)
# (2) fiqa/queries.jsonl (format: jsonlines)
# (3) fiqa/qrels/test.tsv (format: tsv ("\t"))
if prefix:
query_file = prefix + "-" + query_file
qrels_folder = prefix + "-" + qrels_folder

self.corpus_file = os.path.join(data_folder, corpus_file) if data_folder else corpus_file
self.query_file = os.path.join(data_folder, query_file) if data_folder else query_file
self.qrels_folder = os.path.join(data_folder, qrels_folder) if data_folder else None
self.qrels_file = qrels_file
self.streaming = streaming
self.keep_in_memory = keep_in_memory

@staticmethod
def check(fIn: str, ext: str):
if not os.path.exists(fIn):
raise ValueError("File {} not present! Please provide accurate file.".format(fIn))

if not fIn.endswith(ext):
raise ValueError("File {} must be present with extension {}".format(fIn, ext))

def load(self, split="test") -> Tuple[Dict[str, Dict[str, str]], Dict[str, str], Dict[str, Dict[str, int]]]:
if not self.hf_repo:
self.qrels_file = os.path.join(self.qrels_folder, split + ".tsv")
self.check(fIn=self.corpus_file, ext="jsonl")
self.check(fIn=self.query_file, ext="jsonl")
self.check(fIn=self.qrels_file, ext="tsv")

if not len(self.corpus):
logger.info("Loading Corpus...")
self._load_corpus()
logger.info("Loaded %d %s Documents.", len(self.corpus), split.upper())
logger.info("Doc Example: %s", self.corpus[0])

if not len(self.queries):
logger.info("Loading Queries...")
self._load_queries()

self._load_qrels(split)
# filter queries with no qrels
qrels_dict = defaultdict(dict)

def qrels_dict_init(row):
qrels_dict[row['query-id']][row['corpus-id']] = int(row['score'])
self.qrels.map(qrels_dict_init)
self.qrels = qrels_dict
self.queries = self.queries.filter(lambda x: x['id'] in self.qrels)
logger.info("Loaded %d %s Queries.", len(self.queries), split.upper())
logger.info("Query Example: %s", self.queries[0])

return self.corpus, self.queries, self.qrels

def load_corpus(self) -> Dict[str, Dict[str, str]]:
if not self.hf_repo:
self.check(fIn=self.corpus_file, ext="jsonl")

if not len(self.corpus):
logger.info("Loading Corpus...")
self._load_corpus()
logger.info("Loaded %d %s Documents.", len(self.corpus))
logger.info("Doc Example: %s", self.corpus[0])

return self.corpus

def _load_corpus(self):
if self.hf_repo:
corpus_ds = load_dataset(self.hf_repo, 'corpus', keep_in_memory=self.keep_in_memory, streaming=self.streaming)
else:
corpus_ds = load_dataset('json', data_files=self.corpus_file, streaming=self.streaming, keep_in_memory=self.keep_in_memory)
corpus_ds = next(iter(corpus_ds.values())) # get first split
corpus_ds = corpus_ds.cast_column('_id', Value('string'))
corpus_ds = corpus_ds.rename_column('_id', 'id')
corpus_ds = corpus_ds.remove_columns([col for col in corpus_ds.column_names if col not in ['id', 'text', 'title']])
self.corpus = corpus_ds

def _load_queries(self):
if self.hf_repo:
queries_ds = load_dataset(self.hf_repo, 'queries', keep_in_memory=self.keep_in_memory, streaming=self.streaming)
else:
queries_ds = load_dataset('json', data_files=self.query_file, streaming=self.streaming, keep_in_memory=self.keep_in_memory)
queries_ds = next(iter(queries_ds.values())) # get first split
queries_ds = queries_ds.cast_column('_id', Value('string'))
queries_ds = queries_ds.rename_column('_id', 'id')
queries_ds = queries_ds.remove_columns([col for col in queries_ds.column_names if col not in ['id', 'text']])
self.queries = queries_ds

def _load_qrels(self, split):
if self.hf_repo:
qrels_ds = load_dataset(self.hf_repo_qrels, keep_in_memory=self.keep_in_memory, streaming=self.streaming)[split]
else:
qrels_ds = load_dataset('csv', data_files=self.qrels_file, delimiter='\t', keep_in_memory=self.keep_in_memory)
features = Features({'query-id': Value('string'), 'corpus-id': Value('string'), 'score': Value('float')})
qrels_ds = qrels_ds.cast(features)
self.qrels = qrels_ds
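
For context, here is a minimal usage sketch of the loader above. It is not part of this diff; the repository id and local folder are placeholders, not real datasets.

# Illustrative only -- repository id and local folder are placeholders.
loader = HFDataLoader(hf_repo="some-org/some-retrieval-dataset", streaming=False, keep_in_memory=False)
corpus, queries, qrels = loader.load(split="test")
# corpus and queries are datasets.Dataset objects with "id"/"text" (plus "title" for corpus) columns;
# qrels is a plain dict mapping query id -> {corpus id: relevance score}.

# A local folder with corpus.jsonl, queries.jsonl and qrels/test.tsv works the same way:
local_loader = HFDataLoader(data_folder="datasets/fiqa")
local_corpus, local_queries, local_qrels = local_loader.load(split="test")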


class AbsTaskRetrieval(AbsTask):
"""
@@ -21,63 +21,29 @@ class AbsTaskRetrieval(AbsTask):
self.queries = Dict[id, str] #id => query
self.relevant_docs = List[id, id, score]
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)

@staticmethod
def is_dres_compatible(model):
for method in DRES_METHODS:
op = getattr(model, method, None)
if not (callable(op)):
return False
return True
def load_data(self, **kwargs):
if self.data_loaded: return
self.corpus, self.queries, self.relevant_docs = {}, {}, {}
hf_repo_qrels = self.description["hf_hub_name"] + "-qrels" if "clarin-knext" in self.description["hf_hub_name"] else None
for split in kwargs.get("eval_splits", self.description["eval_splits"]):
corpus, queries, qrels = HFDataLoader(hf_repo=self.description["hf_hub_name"], hf_repo_qrels=hf_repo_qrels, streaming=False, keep_in_memory=False).load(split=split)
# Conversion from DataSet
queries = {query['id']: query['text'] for query in queries}
corpus = {doc['id']: {'title': doc['title'] , 'text': doc['text']} for doc in corpus}
self.corpus[split], self.queries[split], self.relevant_docs[split] = corpus, queries, qrels

self.data_loaded = True
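
# Illustrative note, not part of the diff (ids are invented): after load_data the
# per-split structures look roughly like
#   self.queries["test"]       == {"q1": "what is ...?", ...}
#   self.corpus["test"]        == {"d1": {"title": "...", "text": "..."}, ...}
#   self.relevant_docs["test"] == {"q1": {"d1": 1}, ...}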

def evaluate(
self,
model,
split="test",
batch_size=128,
corpus_chunk_size=None,
score_function="cos_sim",
parallel_retrieval=False,
**kwargs
):
try:
from beir.retrieval.evaluation import EvaluateRetrieval
except ImportError:
raise Exception("Retrieval tasks require beir package. Please install it with `pip install mteb[beir]`")

if not self.data_loaded:
self.load_data(parallel_retrieval=parallel_retrieval)

model = model if self.is_dres_compatible(model) else DRESModel(model)

if not parallel_retrieval:
# Non-distributed
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
model = DRES(
model,
batch_size=batch_size,
corpus_chunk_size=corpus_chunk_size if corpus_chunk_size is not None else 50000,
**kwargs,
)

else:
# Distributed (multi-GPU)
from beir.retrieval.search.dense import (
DenseRetrievalParallelExactSearch as DRPES,
)
model = DRPES(
model,
batch_size=batch_size,
corpus_chunk_size=corpus_chunk_size,
**kwargs,
)



retriever = EvaluateRetrieval(model, score_function=score_function) # or "cos_sim" or "dot"
retriever = RetrievalEvaluator(model, **kwargs)

scores = {}
if self.is_multilingual:
@@ -92,7 +92,7 @@ def evaluate(

def _evaluate_monolingual(self, retriever, corpus, queries, relevant_docs, lang=None, **kwargs):
start_time = time()
results = retriever.retrieve(corpus, queries)
results = retriever(corpus, queries)
end_time = time()
logger.info("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))

@@ -123,40 +123,3 @@ def _evaluate_monolingual(self, retriever, corpus, queries, relevant_docs, lang=
**{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()},
}
return scores
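
# Illustrative note, not part of the diff: the scores dict assembled above holds one entry per
# metric and cutoff. The mrr_at_{k} keys are visible here; the remaining names (e.g. ndcg_at_10,
# map_at_10, recall_at_10, precision_at_10) are assumed, since those lines are not shown in this excerpt.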


class DRESModel:
"""
Dense Retrieval Exact Search (DRES) in BeIR requires an encode_queries & encode_corpus method.
This class converts a MTEB model (with just an .encode method) into BeIR DRES format.
"""

def __init__(self, model, sep=" ", **kwargs):
self.model = model
self.sep = sep
self.use_sbert_model = isinstance(model, SentenceTransformer)

def encode_queries(self, queries: List[str], batch_size: int, **kwargs):
if self.use_sbert_model:
if isinstance(self.model._first_module(), Transformer):
logger.info(f"Queries will be truncated to {self.model.get_max_seq_length()} tokens.")
elif isinstance(self.model._first_module(), WordEmbeddings):
logger.warning(
"Queries will not be truncated. This could lead to memory issues. In that case please lower the batch_size."
)
return self.model.encode(queries, batch_size=batch_size, **kwargs)

def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs):
if type(corpus) is dict:
sentences = [
(corpus["title"][i] + self.sep + corpus["text"][i]).strip()
if "title" in corpus
else corpus["text"][i].strip()
for i in range(len(corpus["text"]))
]
else:
sentences = [
(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
for doc in corpus
]
return self.model.encode(sentences, batch_size=batch_size, **kwargs)
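
Taken together, this file now loads data with HFDataLoader and scores it with RetrievalEvaluator, so retrieval tasks no longer require the beir package. A rough end-to-end sketch of running a retrieval task after this change follows; the model checkpoint and task name are placeholders, and the runner call is assumed from the library's README rather than from this diff.

# Sketch, assuming the standard MTEB runner API; checkpoint and task name are placeholders.
from mteb import MTEB
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
evaluation = MTEB(tasks=["SciFact"])
results = evaluation.run(model, output_folder="results")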
42 changes: 0 additions & 42 deletions mteb/abstasks/BeIRKOTask.py

This file was deleted.

39 changes: 0 additions & 39 deletions mteb/abstasks/BeIRPLTask.py

This file was deleted.
