Skip to content

Commit

Permalink
Add ignore_idx to retrieval metrics (#676)
Browse files Browse the repository at this point in the history
* implemented feature #672
* fixed typo in tests
* changelog
* added Raises section to all retrieval metrics

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
5 people authored Dec 13, 2021
1 parent d071eb2 commit 01f88fe
Show file tree
Hide file tree
Showing 21 changed files with 396 additions and 26 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `MinMaxMetric` to wrappers ([#556](https://github.com/PyTorchLightning/metrics/pull/556))


- Added `ignore_index` to to retrieval metrics ([#676](https://github.com/PyTorchLightning/metrics/pull/676))


### Changed

- Scalar metrics will now consistently have additional dimensions squeezed ([#622](https://github.com/PyTorchLightning/metrics/pull/622))
Expand Down
33 changes: 31 additions & 2 deletions tests/retrieval/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from tests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes as _irs_mis_sz
from tests.retrieval.inputs import _input_retrieval_scores_mismatching_sizes_func as _irs_mis_sz_fn
from tests.retrieval.inputs import _input_retrieval_scores_no_target as _irs_no_tgt
from tests.retrieval.inputs import _input_retrieval_scores_with_ignore_index as _irs_ii
from tests.retrieval.inputs import _input_retrieval_scores_wrong_targets as _irs_bad_tgt

seed_all(42)
Expand Down Expand Up @@ -72,6 +73,7 @@ def _compute_sklearn_metric(
indexes: np.ndarray = None,
metric: Callable = None,
empty_target_action: str = "skip",
ignore_index: int = None,
reverse: bool = False,
**kwargs,
) -> Tensor:
Expand All @@ -90,6 +92,10 @@ def _compute_sklearn_metric(
assert isinstance(preds, np.ndarray)
assert isinstance(target, np.ndarray)

if ignore_index is not None:
valid_positions = target != ignore_index
indexes, preds, target = indexes[valid_positions], preds[valid_positions], target[valid_positions]

indexes = indexes.flatten()
preds = preds.flatten()
target = target.flatten()
Expand Down Expand Up @@ -196,6 +202,14 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
"`empty_target_action` received a wrong value `casual_argument`.",
dict(empty_target_action="casual_argument"),
),
# check ignore_index is valid
(
_irs.indexes,
_irs.preds,
_irs.target,
"Argument `ignore_index` must be an integer or None.",
dict(ignore_index=-100.0),
),
# check input shapes are consistent
(
_irs_mis_sz.indexes,
Expand Down Expand Up @@ -242,6 +256,14 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
"`empty_target_action` received a wrong value `casual_argument`.",
dict(empty_target_action="casual_argument"),
),
# check ignore_index is valid
(
_irs.indexes,
_irs.preds,
_irs.target,
"Argument `ignore_index` must be an integer or None.",
dict(ignore_index=-100.0),
),
# check input shapes are consistent
(
_irs_mis_sz.indexes,
Expand Down Expand Up @@ -292,6 +314,13 @@ def _concat_tests(*tests: Tuple[Dict]) -> Dict:
],
)

_default_metric_class_input_arguments_ignore_index = dict(
argnames="indexes,preds,target",
argvalues=[
(_irs_ii.indexes, _irs_ii.preds, _irs_ii.target),
],
)

_default_metric_class_input_arguments_with_non_binary_target = dict(
argnames="indexes,preds,target",
argvalues=[
Expand Down Expand Up @@ -444,7 +473,7 @@ def metric_functional_ignore_indexes(preds, target, indexes):
metric_module=metric_module,
metric_functional=metric_functional_ignore_indexes,
metric_args={"empty_target_action": "neg"},
indexes=indexes, # every additional argument will be passed to RetrievalMAP and _sk_metric_adapted
indexes=indexes, # every additional argument will be passed to the retrieval metric and _sk_metric_adapted
)

def run_precision_test_gpu(
Expand All @@ -467,7 +496,7 @@ def metric_functional_ignore_indexes(preds, target, indexes):
metric_module=metric_module,
metric_functional=metric_functional_ignore_indexes,
metric_args={"empty_target_action": "neg"},
indexes=indexes, # every additional argument will be passed to RetrievalMAP and _sk_metric_adapted
indexes=indexes, # every additional argument will be passed to retrieval metric and _sk_metric_adapted
)

@staticmethod
Expand Down
8 changes: 8 additions & 0 deletions tests/retrieval/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@
target=torch.rand(NUM_BATCHES, 2 * BATCH_SIZE),
)

_input_retrieval_scores_with_ignore_index = Input(
indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)),
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)).masked_fill(
mask=torch.randn(NUM_BATCHES, BATCH_SIZE) > 0.5, value=-100
),
)

# with errors
_input_retrieval_scores_no_target = Input(
indexes=torch.randint(high=10, size=(NUM_BATCHES, BATCH_SIZE)),
Expand Down
34 changes: 33 additions & 1 deletion tests/retrieval/test_fallout.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
RetrievalMetricTester,
_concat_tests,
_default_metric_class_input_arguments,
_default_metric_class_input_arguments_ignore_index,
_default_metric_functional_input_arguments,
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_k,
Expand Down Expand Up @@ -55,6 +56,7 @@ class TestFallOut(RetrievalMetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_class_metric(
Expand All @@ -65,9 +67,39 @@ def test_class_metric(
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
ignore_index: int,
k: int,
):
metric_args = {"empty_target_action": empty_target_action, "k": k}
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index)

self.run_class_metric_test(
ddp=ddp,
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalFallOut,
sk_metric=_fallout_at_k,
dist_sync_on_step=dist_sync_on_step,
reverse=True,
metric_args=metric_args,
)

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index)
def test_class_metric_ignore_index(
self,
ddp: bool,
indexes: Tensor,
preds: Tensor,
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
k: int,
):
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100)

self.run_class_metric_test(
ddp=ddp,
Expand Down
33 changes: 32 additions & 1 deletion tests/retrieval/test_hit_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
RetrievalMetricTester,
_concat_tests,
_default_metric_class_input_arguments,
_default_metric_class_input_arguments_ignore_index,
_default_metric_functional_input_arguments,
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_k,
Expand Down Expand Up @@ -52,6 +53,7 @@ class TestHitRate(RetrievalMetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_class_metric(
Expand All @@ -62,9 +64,38 @@ def test_class_metric(
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
ignore_index: int,
k: int,
):
metric_args = {"empty_target_action": empty_target_action, "k": k}
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index)

self.run_class_metric_test(
ddp=ddp,
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalHitRate,
sk_metric=_hit_rate_at_k,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
)

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index)
def test_class_metric_ignore_index(
self,
ddp: bool,
indexes: Tensor,
preds: Tensor,
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
k: int,
):
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100)

self.run_class_metric_test(
ddp=ddp,
Expand Down
31 changes: 30 additions & 1 deletion tests/retrieval/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
RetrievalMetricTester,
_concat_tests,
_default_metric_class_input_arguments,
_default_metric_class_input_arguments_ignore_index,
_default_metric_functional_input_arguments,
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_no_pos_target,
Expand All @@ -35,6 +36,7 @@ class TestMAP(RetrievalMetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail
@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_class_metric(
self,
Expand All @@ -44,8 +46,35 @@ def test_class_metric(
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
ignore_index: int,
):
metric_args = {"empty_target_action": empty_target_action}
metric_args = dict(empty_target_action=empty_target_action, ignore_index=ignore_index)

self.run_class_metric_test(
ddp=ddp,
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalMAP,
sk_metric=sk_average_precision_score,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
)

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index)
def test_class_metric_ignore_index(
self,
ddp: bool,
indexes: Tensor,
preds: Tensor,
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
):
metric_args = dict(empty_target_action=empty_target_action, ignore_index=-100)

self.run_class_metric_test(
ddp=ddp,
Expand Down
31 changes: 30 additions & 1 deletion tests/retrieval/test_mrr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
RetrievalMetricTester,
_concat_tests,
_default_metric_class_input_arguments,
_default_metric_class_input_arguments_ignore_index,
_default_metric_functional_input_arguments,
_errors_test_class_metric_parameters_default,
_errors_test_class_metric_parameters_no_pos_target,
Expand Down Expand Up @@ -57,6 +58,7 @@ class TestMRR(RetrievalMetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail
@pytest.mark.parametrize(**_default_metric_class_input_arguments)
def test_class_metric(
self,
Expand All @@ -66,8 +68,35 @@ def test_class_metric(
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
ignore_index: int,
):
metric_args = {"empty_target_action": empty_target_action}
metric_args = dict(empty_target_action=empty_target_action, ignore_index=ignore_index)

self.run_class_metric_test(
ddp=ddp,
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalMRR,
sk_metric=_reciprocal_rank,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
)

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index)
def test_class_metric_ignore_index(
self,
ddp: bool,
indexes: Tensor,
preds: Tensor,
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
):
metric_args = dict(empty_target_action=empty_target_action, ignore_index=-100)

self.run_class_metric_test(
ddp=ddp,
Expand Down
33 changes: 32 additions & 1 deletion tests/retrieval/test_ndcg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tests.retrieval.helpers import (
RetrievalMetricTester,
_concat_tests,
_default_metric_class_input_arguments_ignore_index,
_default_metric_class_input_arguments_with_non_binary_target,
_default_metric_functional_input_arguments_with_non_binary_target,
_errors_test_class_metric_parameters_k,
Expand Down Expand Up @@ -51,6 +52,7 @@ class TestNDCG(RetrievalMetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("ignore_index", [None, 3]) # avoid setting 0, otherwise test with all 0 targets will fail
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments_with_non_binary_target)
def test_class_metric(
Expand All @@ -61,9 +63,38 @@ def test_class_metric(
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
ignore_index: int,
k: int,
):
metric_args = {"empty_target_action": empty_target_action, "k": k}
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=ignore_index)

self.run_class_metric_test(
ddp=ddp,
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalNormalizedDCG,
sk_metric=_ndcg_at_k,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
)

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize("empty_target_action", ["skip", "neg", "pos"])
@pytest.mark.parametrize("k", [None, 1, 4, 10])
@pytest.mark.parametrize(**_default_metric_class_input_arguments_ignore_index)
def test_class_metric_ignore_index(
self,
ddp: bool,
indexes: Tensor,
preds: Tensor,
target: Tensor,
dist_sync_on_step: bool,
empty_target_action: str,
k: int,
):
metric_args = dict(empty_target_action=empty_target_action, k=k, ignore_index=-100)

self.run_class_metric_test(
ddp=ddp,
Expand Down
Loading

0 comments on commit 01f88fe

Please sign in to comment.