feat: Evaluate missing splits (#1525)
* fix: evaluate missing splits (#1268)
* implement partial evaluation for missing splits
* lint
* requested changes done from scratch
* test for missing split evaluation added
* uncomment test
* lint
* avoid circular import
* use TaskResult
* skip tests for now

---------

Co-authored-by: Isaac Chung <chungisaac1217@gmail.com>

* got test_all_splits_evaluated passing
* tests passing
* address review comments
* make lint
* handle None cases for kg_co2_emissions
* use new results info

---------

Co-authored-by: Thivyanth <thivyanth2004@gmail.com>
1 parent ba09b11 · commit 8e12250
3 changed files with 248 additions and 26 deletions
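For context before the diff: the sketch below illustrates the re-run behavior this commit adds, using only the API calls that appear in the new test file that follows (`MTEB`, `MTEB.run` with `eval_splits`/`output_folder`/`overwrite_results`, and `get_last_evaluated_splits`). It is a minimal sketch built on the mock model and task from the test suite, not a canonical usage example; the output folder name is illustrative.

    from mteb import MTEB
    from tests.test_benchmark.mock_models import MockSentenceTransformer
    from tests.test_benchmark.mock_tasks import MockRetrievalTask

    model = MockSentenceTransformer()
    evaluation = MTEB(tasks=[MockRetrievalTask()])

    # First run evaluates only the "val" split.
    evaluation.run(model, eval_splits=["val"], output_folder="results")

    # The second run requests both splits; per the tests below, only the
    # missing "test" split is actually evaluated.
    evaluation.run(
        model,
        eval_splits=["val", "test"],
        output_folder="results",
        overwrite_results=True,
    )

    # Expected to report only the newly evaluated split,
    # e.g. {"MockRetrievalTask": ["test"]} (shape per the tests below).
    print(evaluation.get_last_evaluated_splits())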
@@ -0,0 +1,91 @@
from __future__ import annotations

import pytest

from mteb import MTEB
from tests.test_benchmark.mock_models import (
    MockSentenceTransformer,
)
from tests.test_benchmark.mock_tasks import (
    MockRetrievalTask,
)


@pytest.fixture
def model():
    return MockSentenceTransformer()


@pytest.fixture
def tasks():
    return [MockRetrievalTask()]


def test_all_splits_evaluated(model, tasks, tmp_path):
    # A fresh run evaluates every requested split.
    evaluation = MTEB(tasks=tasks)
    results = evaluation.run(
        model,
        eval_splits=["val", "test"],
        output_folder=str(tmp_path / "all_splits_evaluated"),
        verbosity=2,
    )

    assert "MockRetrievalTask" == results[0].task_name
    last_evaluated_splits = evaluation.get_last_evaluated_splits()
    assert set(last_evaluated_splits["MockRetrievalTask"]) == {"val", "test"}
    assert len(last_evaluated_splits["MockRetrievalTask"]) == 2


def test_one_missing_split(model, tasks, tmp_path):
    # First run covers only the "val" split.
    evaluation = MTEB(tasks=tasks)
    results = evaluation.run(
        model,
        eval_splits=["val"],
        output_folder=str(tmp_path / "testcase2"),
        verbosity=2,
    )

    assert "MockRetrievalTask" == results[0].task_name
    last_evaluated_splits = evaluation.get_last_evaluated_splits()
    assert set(last_evaluated_splits["MockRetrievalTask"]) == {"val"}
    assert len(last_evaluated_splits["MockRetrievalTask"]) == 1

    # The second run requests both splits, but only the missing "test"
    # split should actually be evaluated.
    results2 = evaluation.run(
        model,
        eval_splits=["val", "test"],
        output_folder=str(tmp_path / "testcase2"),
        verbosity=2,
        overwrite_results=True,
    )

    assert "MockRetrievalTask" == results2[0].task_name
    last_evaluated_splits = evaluation.get_last_evaluated_splits()
    assert set(last_evaluated_splits["MockRetrievalTask"]) == {"test"}
    assert len(last_evaluated_splits["MockRetrievalTask"]) == 1


def test_no_missing_splits(model, tasks, tmp_path):
    evaluation = MTEB(tasks=tasks)
    _ = evaluation.run(
        model,
        eval_splits=["val", "test"],
        output_folder=str(tmp_path / "testcase3"),
        verbosity=2,
    )

    last_evaluated_splits = evaluation.get_last_evaluated_splits()
    assert "MockRetrievalTask" in last_evaluated_splits
    assert len(last_evaluated_splits["MockRetrievalTask"]) == 2

    # Re-running over splits that already have results evaluates nothing.
    evaluation = MTEB(tasks=tasks)
    _ = evaluation.run(
        model,
        eval_splits=["val", "test"],
        output_folder=str(tmp_path / "testcase3"),
        verbosity=2,
        overwrite_results=True,
    )

    last_evaluated_splits = evaluation.get_last_evaluated_splits()
    assert "MockRetrievalTask" in last_evaluated_splits
    assert len(last_evaluated_splits["MockRetrievalTask"]) == 0