Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Multi-node Evaluation #132

Merged
merged 11 commits into from
Aug 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 26 additions & 71 deletions mteb/abstasks/AbsTaskRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from time import time
from typing import Dict, List

import torch.multiprocessing as mp
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, WordEmbeddings
import os

from .AbsTask import AbsTask

Expand Down Expand Up @@ -46,7 +47,6 @@ def evaluate(
split="test",
batch_size=128,
corpus_chunk_size=None,
target_devices=None,
score_function="cos_sim",
**kwargs
):
Expand All @@ -60,11 +60,21 @@ def evaluate(

corpus, queries, relevant_docs = self.corpus[split], self.queries[split], self.relevant_docs[split]

try:
raise ImportError("MTEB is temporarily incompatible with HFDataLoader")
if os.getenv("RANK", None) is None:
# Non-distributed
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

model = model if self.is_dres_compatible(model, is_parallel=False) else DRESModel(model)

if self.description["beir_name"].startswith("cqadupstack"):
raise ImportError("CQADupstack is incompatible with latest BEIR")
model = DRES(
model,
batch_size=batch_size,
corpus_chunk_size=corpus_chunk_size if corpus_chunk_size is not None else 50000,
**kwargs,
)

else:
# Distributed (multi-GPU)
from beir.retrieval.search.dense import (
DenseRetrievalParallelExactSearch as DRPES,
)
Expand All @@ -74,35 +84,19 @@ def evaluate(
model = DRPES(
model,
batch_size=batch_size,
target_devices=target_devices,
corpus_chunk_size=corpus_chunk_size,
**kwargs,
)
except ImportError:
if target_devices is not None:
logger.warning(
"DenseRetrievalParallelExactSearch could not be imported from beir. Using DenseRetrievalExactSearch instead."
)
logger.warning("The parameter target_devices is ignored.")

from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

model = model if self.is_dres_compatible(model, is_parallel=False) else DRESModel(model)

model = DRES(
model,
batch_size=batch_size,
corpus_chunk_size=corpus_chunk_size if corpus_chunk_size is not None else 50000,
**kwargs,
)

retriever = EvaluateRetrieval(model, score_function=score_function) # or "cos_sim" or "dot"
start_time = time()
results = retriever.retrieve(corpus, queries)
end_time = time()
logger.info("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))

ndcg, _map, recall, precision = retriever.evaluate(relevant_docs, results, retriever.k_values)
ndcg, _map, recall, precision = retriever.evaluate(relevant_docs, results, retriever.k_values, ignore_identical_ids=kwargs.get("ignore_identical_ids", True))
mrr = retriever.evaluate_custom(relevant_docs, results, retriever.k_values, "mrr")

scores = {
Expand All @@ -125,32 +119,16 @@ class DRESModel:
def __init__(self, model, sep=" ", **kwargs):
self.model = model
self.sep = sep

def start_multi_process_pool(self, target_devices: List[str] = None) -> Dict[str, object]:
logger.info("Start multi-process pool on devices: {}".format(", ".join(map(str, target_devices))))

ctx = mp.get_context("spawn")
input_queue = ctx.Queue()
output_queue = ctx.Queue()
processes = []

for process_id, device_name in enumerate(target_devices):
p = ctx.Process(
target=SentenceTransformer._encode_multi_process_worker,
args=(process_id, device_name, self.model, input_queue, output_queue),
daemon=True,
)
p.start()
processes.append(p)

return {"input": input_queue, "output": output_queue, "processes": processes}

def stop_multi_process_pool(self, pool: Dict[str, object]):
output_queue = pool["output"]
[output_queue.get() for _ in range(len(pool["processes"]))]
return self.model.stop_multi_process_pool(pool)
self.use_sbert_model = isinstance(model, SentenceTransformer)

def encode_queries(self, queries: List[str], batch_size: int, **kwargs):
if self.use_sbert_model:
if isinstance(self.model._first_module(), Transformer):
logger.info(f"Queries will be truncated to {self.model.get_max_seq_length()} tokens.")
elif isinstance(self.model._first_module(), WordEmbeddings):
logger.warning(
"Queries will not be truncated. This could lead to memory issues. In that case please lower the batch_size."
)
return self.model.encode(queries, batch_size=batch_size, **kwargs)

def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs):
Expand All @@ -166,27 +144,4 @@ def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs)
(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
for doc in corpus
]
return self.model.encode(sentences, batch_size=batch_size, **kwargs)

def encode_corpus_parallel(
self, corpus: List[Dict[str, str]], pool: Dict[str, object], batch_size: int, chunk_id: int, **kwargs
):
if type(corpus) is dict:
sentences = [
(corpus["title"][i] + self.sep + corpus["text"][i]).strip()
if "title" in corpus
else corpus["text"][i].strip()
for i in range(len(corpus["text"]))
]
else:
sentences = [
(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
for doc in corpus
]

if chunk_id is not None and chunk_id >= len(pool["processes"]):
output_queue = pool["output"]
output_queue.get()

input_queue = pool["input"]
input_queue.put([chunk_id, batch_size, sentences])
return self.model.encode(sentences, batch_size=batch_size, **kwargs)
27 changes: 15 additions & 12 deletions mteb/abstasks/BeIRTask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import datasets

from .AbsTask import AbsTask
import logging

logger = logging.getLogger(__name__)

class BeIRTask(AbsTask):
def __init__(self, **kwargs):
Expand All @@ -18,17 +20,18 @@ def load_data(self, eval_splits=None, **kwargs):
except ImportError:
raise Exception("Retrieval tasks require beir package. Please install it with `pip install mteb[beir]`")

USE_BEIR_DEVELOPMENT = False
try:
raise ImportError("MTEB is temporarily incompatible with HFDataLoader")
USE_HF_DATASETS = False

# TODO @nouamane: move non-distributed to `HFDataLoader`
if os.getenv("RANK", None) is not None:
if self.description["beir_name"].startswith("cqadupstack"):
raise ImportError("CQADupstack is incompatible with latest BEIR")
from beir.datasets.data_loader_hf import HFDataLoader as BeirDataLoader

USE_BEIR_DEVELOPMENT = True
except ImportError:
from beir.datasets.data_loader import GenericDataLoader as BeirDataLoader
raise ImportError("CQADupstack is incompatible with BEIR's HFDataLoader in a distributed setting")
from beir.datasets.data_loader_hf import HFDataLoader
logger.info("Using HFDataLoader for BeIR")
USE_HF_DATASETS = True
else:
from beir.datasets.data_loader import GenericDataLoader
logger.info("Using GenericDataLoader for BeIR")

if self.data_loaded:
return
Expand All @@ -39,16 +42,16 @@ def load_data(self, eval_splits=None, **kwargs):

self.corpus, self.queries, self.relevant_docs = {}, {}, {}
for split in eval_splits:
if USE_BEIR_DEVELOPMENT:
self.corpus[split], self.queries[split], self.relevant_docs[split] = BeirDataLoader(
if USE_HF_DATASETS:
self.corpus[split], self.queries[split], self.relevant_docs[split] = HFDataLoader(
hf_repo=f"BeIR/{dataset}"
).load(split=split)
else:
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
download_path = os.path.join(datasets.config.HF_DATASETS_CACHE, "BeIR")
data_path = util.download_and_unzip(url, download_path)
data_path = f"{data_path}/{sub_dataset}" if sub_dataset else data_path
self.corpus[split], self.queries[split], self.relevant_docs[split] = BeirDataLoader(
self.corpus[split], self.queries[split], self.relevant_docs[split] = GenericDataLoader(
data_folder=data_path
).load(split=split)
self.data_loaded = True
5 changes: 5 additions & 0 deletions mteb/evaluation/MTEB.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
else:
self._task_types = task_types
self._task_categories = task_categories
self._tasks = None

self._task_langs = task_langs if task_langs is not None else []
if type(self._task_langs) is str:
Expand Down Expand Up @@ -101,6 +102,10 @@ def _extend_lang_pairs(self):
return

def _display_tasks(self, task_list, name=None):
# disable logging for other ranks
if int(os.getenv("RANK", 0)) != 0:
return

console = Console()
if name:
console.rule(f"[bold]{name}\n", style="grey15")
Expand Down
3 changes: 1 addition & 2 deletions mteb/logging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging



def _get_library_name() -> str:
    """Return the first dotted component of this module's ``__name__``."""
    top_level, _, _ = __name__.partition(".")
    return top_level

Expand All @@ -25,4 +24,4 @@ def enable_explicit_format() -> None:

for handler in handlers:
formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
handler.setFormatter(formatter)
handler.setFormatter(formatter)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def description(self):
"type": "Classification",
"category": "s2s",
"eval_splits": ["test", "validation"],
"eval_langs": ["nb"], # assumed to be bokmål
"eval_langs": ["nb"], # assumed to be bokmål
"main_score": "accuracy",
"n_experiments": 10,
"samples_per_label": 16,
Expand Down
2 changes: 0 additions & 2 deletions mteb/tasks/Classification/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@


from .AmazonCounterfactualClassification import *
from .AmazonPolarityClassification import *
from .AmazonReviewsClassification import *
Expand Down
88 changes: 88 additions & 0 deletions scripts/retrieval.slurm
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#!/bin/bash
#SBATCH --job-name=mteb-retrieval
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node!
#SBATCH --cpus-per-task=96
#SBATCH --gres=gpu:8
#SBATCH --mem-per-cpu=11G # This is essentially 1.1T / 96
#SBATCH --exclusive
#SBATCH --partition=production-cluster
#SBATCH -o /fsx/nouamane/logs/%x-%j-train.out

# Multi-node MTEB retrieval evaluation.
#
# Slurm starts one task per node; each task runs the torch.distributed.run
# launcher, which spawns GPUS_PER_NODE worker processes executing
# scripts/retrieval_multigpu.py. All launchers rendezvous at
# MASTER_ADDR:MASTER_PORT on the first node of the allocation.

set -x -e

# source ~/.bashrc
# TODO Replace with your env name
# conda activate 2-0-cu-117

echo "START TIME: $(date)"

MTEB_REPO=/fsx/nouamane/projects/mteb

pushd "$MTEB_REPO"

GPUS_PER_NODE=8
NNODES=$SLURM_NNODES

# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
MASTER_PORT=6000

export USE_DIST=1

CMD=" \
    scripts/retrieval_multigpu.py
    "

export LAUNCHER="python -u -m torch.distributed.run \
    --nproc_per_node $GPUS_PER_NODE \
    --nnodes $NNODES \
    --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
    --rdzv_backend c10d \
    --max_restarts 0 \
    --tee 3 \
    "

echo $CMD

# hide duplicated errors using this hack - will be properly fixed in pt-1.12
# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json

# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1
# export NCCL_ALGO=Ring

# AWS specific
export NCCL_PROTO=simple
export RDMAV_FORK_SAFE=1
export FI_EFA_FORK_SAFE=1
export FI_EFA_USE_DEVICE_RDMA=1
export FI_PROVIDER=efa
export FI_LOG_LEVEL=1
export NCCL_IB_DISABLE=1
export NCCL_SOCKET_IFNAME=ens

export CUDA_HOME=/usr/local/cuda-11.7
# This is needed for torch1.12.1 otherwise it doesn't link correctly, not sure what the issue was.
# export LD_PRELOAD=$CUDA_HOME/lib/libnccl.so # warning: doesn't work with torch 2.0
export LD_LIBRARY_PATH=$CUDA_HOME/efa/lib:$CUDA_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH

# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
# SRUN_ARGS=" \
#     --wait=60 \
#     --kill-on-bad-exit=1 \
#     "

# SLURM_PROCID and SLURMD_NODENAME are escaped on purpose: they must be
# expanded by the per-node shell that srun starts, not by this batch shell.
# Unescaped, every node would receive the batch shell's values (the same
# node rank and the same node name in --role).
# SRUN_ARGS is left unquoted so it word-splits into separate options (and
# vanishes when unset).
srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank \$SLURM_PROCID --role \$SLURMD_NODENAME: $CMD"

# srun doesn't work if this script was called with `srun` itself
# bash -c "$LAUNCHER $CMD"

echo "END TIME: $(date)"
32 changes: 32 additions & 0 deletions scripts/retrieval_multigpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import logging

from sentence_transformers import SentenceTransformer

from mteb import MTEB
import torch.distributed as dist
import torch
import os


# To run this script on multiple GPUs, you need to install the following branch of BEIR
# pip install git+https://github.com/NouamaneTazi/beir@nouamane/better-multi-gpu

# Then use this command to run on 2 GPUs for example
# torchrun --nproc_per_node=2 scripts/retrieval_multigpu.py

if __name__ == "__main__":
    # Join the distributed job using the NCCL backend.
    dist.init_process_group("nccl")

    # Bind this process to its node-local GPU, taken from LOCAL_RANK
    # (defaults to 0 when unset). `torch.cuda.device` is a context manager,
    # not a device handle, so the plain index is passed to `set_device`.
    device_id = int(os.getenv("LOCAL_RANK", 0))
    torch.cuda.set_device(device_id)

    # Only global rank 0 logs at INFO; every other rank is limited to
    # WARNING and above to avoid duplicated output.
    rank = int(os.getenv("RANK", 0))
    logging.basicConfig(level=logging.INFO if rank == 0 else logging.WARNING)

    model = SentenceTransformer("intfloat/e5-large", device="cuda")
    # To evaluate a single task instead: MTEB(tasks=["MSMARCO"])
    # Named `evaluation` rather than `eval` to avoid shadowing the builtin.
    evaluation = MTEB(task_types=["Retrieval"])
    evaluation.run(model, batch_size=1024, overwrite_results=True, eval_splits=["test"])