feat: Evaluate missing splits (#1525)
* fix: evaluate missing splits (#1268)

* implement partial evaluation for missing splits

* lint

* requested changes done from scratch

* test for missing split evaluation added

* uncomment test

* lint

* avoid circular import

* use TaskResult

* skip tests for now

---------

Co-authored-by: Isaac Chung <chungisaac1217@gmail.com>

* got test_all_splits_evaluated passing

* tests passing

* address review comments

* make lint

* handle None cases for kg_co2_emissions

* use new results info

---------

Co-authored-by: Thivyanth <thivyanth2004@gmail.com>
isaac-chung and thivyanth authored Nov 29, 2024
1 parent ba09b11 commit 8e12250
Showing 3 changed files with 248 additions and 26 deletions.
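In practice, this change means a repeated `run` call only evaluates the splits missing from the cached `TaskResult` on disk and merges the new scores into the stored file. A minimal sketch of that workflow, mirroring `test_one_missing_split` below (the mock classes come from this repository's test suite; `"results_demo"` is an arbitrary output folder chosen for illustration):

```python
from mteb import MTEB
from tests.test_benchmark.mock_models import MockSentenceTransformer
from tests.test_benchmark.mock_tasks import MockRetrievalTask

model = MockSentenceTransformer()
evaluation = MTEB(tasks=[MockRetrievalTask()])

# First pass: only the "val" split is evaluated and written to disk.
evaluation.run(model, eval_splits=["val"], output_folder="results_demo")
print(evaluation.get_last_evaluated_splits())  # {"MockRetrievalTask": ["val"]}

# Second pass: "val" is loaded from the existing TaskResult, so only the
# missing "test" split is evaluated and merged into the stored result file.
evaluation.run(
    model,
    eval_splits=["val", "test"],
    output_folder="results_demo",
    overwrite_results=True,
)
print(evaluation.get_last_evaluated_splits())  # {"MockRetrievalTask": ["test"]}
```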
145 changes: 123 additions & 22 deletions mteb/evaluation/MTEB.py
@@ -5,7 +5,7 @@
import os
import traceback
from collections.abc import Iterable
from copy import copy
from copy import copy, deepcopy
from datetime import datetime
from itertools import chain
from pathlib import Path
@@ -15,6 +15,7 @@
import datasets
from sentence_transformers import CrossEncoder, SentenceTransformer

from mteb.abstasks.AbsTask import ScoresDict
from mteb.encoder_interface import Encoder
from mteb.model_meta import ModelMeta
from mteb.models import model_meta_from_sentence_transformers
@@ -84,6 +85,8 @@ def __init__(
self._version = version
self.err_logs_path = err_logs_path

self.last_evaluated_splits = {}

self.select_tasks(**kwargs)

def deprecation_warning(
@@ -307,6 +310,70 @@ def _run_eval(
tock = time()
return results, tick, tock

@staticmethod
def _get_missing_splits(
existing_results: TaskResult | None, task_eval_splits: list[str]
) -> list[str]:
if existing_results is None:
return task_eval_splits

missing_splits = []
for split in task_eval_splits:
if split not in existing_results.scores:
missing_splits.append(split)
elif not existing_results.scores[
split
]: # Check if the split has any scores
missing_splits.append(split)

return missing_splits

@staticmethod
def _merge_results(
existing_results: TaskResult, new_results: TaskResult
) -> TaskResult:
merged_scores = existing_results.scores.copy()

for split, scores in new_results.scores.items():
if split in merged_scores:
merged_scores[split] = MTEB._merge_split_scores(
merged_scores[split], scores
)
else:
merged_scores[split] = scores

existing_kg_co2_emissions = (
existing_results.kg_co2_emissions
if existing_results.kg_co2_emissions
else 0
)
new_kg_co2_emissions = (
new_results.kg_co2_emissions if new_results.kg_co2_emissions else 0
)
merged_kg_co2_emissions = None
if existing_kg_co2_emissions and new_kg_co2_emissions:
merged_kg_co2_emissions = existing_kg_co2_emissions + new_kg_co2_emissions
merged_results = TaskResult(
dataset_revision=new_results.dataset_revision,
task_name=new_results.task_name,
mteb_version=new_results.mteb_version,
scores=merged_scores,
evaluation_time=existing_results.evaluation_time
+ new_results.evaluation_time,
kg_co2_emissions=merged_kg_co2_emissions,
)

return merged_results

@staticmethod
def _merge_split_scores(
existing_scores: list[ScoresDict], new_scores: list[ScoresDict]
) -> list[ScoresDict]:
merged = {score["hf_subset"]: score for score in existing_scores}
for score in new_scores:
merged[score["hf_subset"]] = score
return list(merged.values())

def run(
self,
model: SentenceTransformer | Encoder,
@@ -378,38 +445,62 @@ def run(
original_tasks = (
self.tasks.copy()
) # save them in case we re-use the object (e.g. for reranking)

# To evaluate missing splits, we keep track of the task name and the corresponding splits.
self.last_evaluated_splits = {}

while len(self.tasks) > 0:
task = self.tasks[0]
logger.info(
f"\n\n********************** Evaluating {task.metadata.name} **********************"
)

# skip evaluation if results folder exists and overwrite_results is False
if output_path:
save_path = output_path / f"{task.metadata.name}{task.save_suffix}.json"
if save_path.exists() and not overwrite_results:
logger.info(
f"{task.metadata.name} results already exists. Loading results from disk. Set overwrite_results=True to overwrite."
)
mteb_results = TaskResult.from_disk(save_path)
evaluation_results.append(mteb_results)
del self.tasks[0] # empty memory
continue
try:
existing_results = None
if save_path.exists():
existing_results = TaskResult.from_disk(save_path)

if not overwrite_results:
logger.info(
f"{task.metadata.name} results already exists. Loading results from disk. Set overwrite_results=True to overwrite."
)
evaluation_results.append(existing_results)
del self.tasks[0] # empty memory
continue

task_eval_splits = (
eval_splits if eval_splits is not None else task.eval_splits
)
missing_splits = self._get_missing_splits(
existing_results, task_eval_splits
)

if not missing_splits and existing_results:
evaluation_results.append(existing_results)

# no splits are evaluated.
self.last_evaluated_splits[task.metadata.name] = []
del self.tasks[0]
continue

if missing_splits:
logger.info(
f"Running evaluation for missing splits: {missing_splits}"
)

# load data
logger.info(f"Loading dataset for {task.metadata_dict['name']}")
try:
task.check_if_dataset_is_superseeded()
task.load_data(eval_splits=task_eval_splits, **kwargs)

# run evaluation
task_results = {}
evaluation_time = 0
kg_co2_emissions: int | None = 0 if co2_tracker else None
for split in task_eval_splits:

self.last_evaluated_splits[task.metadata.name] = []

for split in missing_splits:
if co2_tracker:
try:
from codecarbon import EmissionsTracker
@@ -443,6 +534,8 @@ def run(
**kwargs,
)

self.last_evaluated_splits[task.metadata.name].append(split)

logger.info(
f"Evaluation for {task.metadata_dict['name']} on {split} took {tock - tick:.2f} seconds"
)
@@ -452,21 +545,22 @@ def run(
if verbosity >= 1:
logger.info(f"Scores: {results}")

mteb_task_result = TaskResult.from_task_results(
new_results = TaskResult.from_task_results(
task,
task_results,
evaluation_time=evaluation_time,
kg_co2_emissions=kg_co2_emissions,
)

# save results
if existing_results:
merged_results = self._merge_results(existing_results, new_results)
else:
merged_results = new_results

if output_path:
with open(save_path, "w") as f_out:
json.dump(
mteb_task_result.to_dict(), f_out, indent=2, sort_keys=True
)
merged_results.to_disk(save_path)

evaluation_results.append(mteb_task_result)
evaluation_results.append(merged_results)

except Exception as e:
logger.error(
@@ -485,7 +579,6 @@ def run(
# empty memory
del self.tasks[0]

# restore original tasks
self.tasks = original_tasks
return evaluation_results

@@ -536,3 +629,11 @@ def _save_model_metadata(model_meta: ModelMeta, output_folder: Path) -> None:

with save_path.open("w") as f:
json.dump(model_meta.to_dict(), f)

def get_last_evaluated_splits(self):
"""Returns a dictionary of tasks and their evaluated splits from the most recent run.
Tasks with empty lists indicate that results already existed and no splits were evaluated.
"""
return deepcopy(
{task: list(splits) for task, splits in self.last_evaluated_splits.items()}
)
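The merge helpers key each split's score list on `hf_subset`, so a re-evaluated subset replaces the stored entry rather than being duplicated, while untouched subsets are carried over. A small sketch of that behaviour, calling the private helper directly for illustration (the score dicts are made-up values, not real results):

```python
from mteb import MTEB

existing = [
    {"hf_subset": "default", "main_score": 0.50},
    {"hf_subset": "en", "main_score": 0.40},
]
new = [
    {"hf_subset": "en", "main_score": 0.45},  # replaces the stored "en" entry
    {"hf_subset": "de", "main_score": 0.55},  # newly evaluated subset is appended
]

merged = MTEB._merge_split_scores(existing, new)
print([(s["hf_subset"], s["main_score"]) for s in merged])
# [('default', 0.5), ('en', 0.45), ('de', 0.55)]
```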
38 changes: 34 additions & 4 deletions tests/test_benchmark/mock_tasks.py
@@ -1284,35 +1284,65 @@ class MockRetrievalTask(AbsTaskRetrieval):
"average_relevant_docs_per_query": 2.0,
"max_relevant_docs_per_query": 2,
"unique_relevant_docs": 2,
}
},
"val": {
"number_of_characters": 112,
"num_samples": 4,
"num_queries": 2,
"num_documents": 2,
"min_document_length": 23,
"average_document_length": 26.0,
"max_document_length": 29,
"unique_documents": 2,
"min_query_length": 27,
"average_query_length": 30.0,
"max_query_length": 33,
"unique_queries": 2,
"min_relevant_docs_per_query": 2,
"average_relevant_docs_per_query": 2.0,
"max_relevant_docs_per_query": 2,
"unique_relevant_docs": 2,
},
}

metadata = TaskMetadata(
type="Retrieval",
name="MockRetrievalTask",
main_score="ndcg_at_10",
**general_args, # type: ignore
**dict(general_args | {"eval_splits": ["val", "test"]}), # type: ignore
)

def load_data(self, **kwargs):
self.queries = {
"test": {
"q1": "This is a test sentence",
"q2": "This is another test sentence",
}
},
"val": {
"q1": "This is a test sentence",
"q2": "This is another test sentence",
},
}
self.corpus = {
"test": {
"d1": "This is a positive sentence",
"d2": "This is another positive sentence",
}
},
"val": {
"d1": "This is a positive sentence",
"d2": "This is another positive sentence",
},
}

self.relevant_docs = {
"test": {
"q1": {"d1": 1, "d2": 0},
"q2": {"d1": 0, "d2": 1},
},
"val": {
"q1": {"d1": 1, "d2": 0},
"q2": {"d1": 0, "d2": 1},
},
}
self.data_loaded = True

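With the added `val` split, the mock retrieval task can exercise the partial-evaluation path directly; a quick sketch of inspecting it (import path as in the test suite below, attribute access assuming the usual `TaskMetadata` fields):

```python
from tests.test_benchmark.mock_tasks import MockRetrievalTask

task = MockRetrievalTask()
task.load_data()

print(task.metadata.eval_splits)   # ['val', 'test']
print(sorted(task.queries))        # ['test', 'val']
print(sorted(task.relevant_docs))  # ['test', 'val']
```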
91 changes: 91 additions & 0 deletions tests/test_evaluation/test_split_evaluation.py
@@ -0,0 +1,91 @@
from __future__ import annotations

import pytest

from mteb import MTEB
from tests.test_benchmark.mock_models import (
MockSentenceTransformer,
)
from tests.test_benchmark.mock_tasks import (
MockRetrievalTask,
)


@pytest.fixture
def model():
return MockSentenceTransformer()


@pytest.fixture
def tasks():
return [MockRetrievalTask()]


def test_all_splits_evaluated(model, tasks, tmp_path):
evaluation = MTEB(tasks=tasks)
results = evaluation.run(
model,
eval_splits=["val", "test"],
output_folder=str(tmp_path / "all_splits_evaluated"),
verbosity=2,
)

assert "MockRetrievalTask" == results[0].task_name
last_evaluated_splits = evaluation.get_last_evaluated_splits()
assert set(last_evaluated_splits["MockRetrievalTask"]) == {"val", "test"}
assert len(last_evaluated_splits["MockRetrievalTask"]) == 2


def test_one_missing_split(model, tasks, tmp_path):
evaluation = MTEB(tasks=tasks)
results = evaluation.run(
model,
eval_splits=["val"],
output_folder=str(tmp_path / "testcase2"),
verbosity=2,
)

assert "MockRetrievalTask" == results[0].task_name
last_evaluated_splits = evaluation.get_last_evaluated_splits()
assert set(last_evaluated_splits["MockRetrievalTask"]) == {"val"}
assert len(last_evaluated_splits["MockRetrievalTask"]) == 1

results2 = evaluation.run(
model,
eval_splits=["val", "test"],
output_folder=str(tmp_path / "testcase2"),
verbosity=2,
overwrite_results=True,
)

assert "MockRetrievalTask" == results2[0].task_name
last_evaluated_splits = evaluation.get_last_evaluated_splits()
assert set(last_evaluated_splits["MockRetrievalTask"]) == {"test"}
assert len(last_evaluated_splits["MockRetrievalTask"]) == 1


def test_no_missing_splits(model, tasks, tmp_path):
evaluation = MTEB(tasks=tasks)
_ = evaluation.run(
model,
eval_splits=["val", "test"],
output_folder=str(tmp_path / "testcase3"),
verbosity=2,
)

last_evaluated_splits = evaluation.get_last_evaluated_splits()
assert "MockRetrievalTask" in last_evaluated_splits
assert len(last_evaluated_splits["MockRetrievalTask"]) == 2

evaluation = MTEB(tasks=tasks)
_ = evaluation.run(
model,
eval_splits=["val", "test"],
output_folder=str(tmp_path / "testcase3"),
verbosity=2,
overwrite_results=True,
)

last_evaluated_splits = evaluation.get_last_evaluated_splits()
assert "MockRetrievalTask" in last_evaluated_splits
assert len(last_evaluated_splits["MockRetrievalTask"]) == 0
