From 22be4fd731e8d767a82772c6ae3c01c7d28c6c0b Mon Sep 17 00:00:00 2001 From: rnyak Date: Tue, 8 Nov 2022 13:02:17 -0800 Subject: [PATCH] fix typo and metric error (#522) --- tests/torch/test_ranking_metrics.py | 4 ++-- transformers4rec/torch/ranking_metric.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/torch/test_ranking_metrics.py b/tests/torch/test_ranking_metrics.py index f59c47dc58..9e0a558ab2 100644 --- a/tests/torch/test_ranking_metrics.py +++ b/tests/torch/test_ranking_metrics.py @@ -17,7 +17,7 @@ import pytest import torch -from transformers4rec.torch.ranking_metric import MeanRecipricolRankAt +from transformers4rec.torch.ranking_metric import MeanReciprocalRankAt tr = pytest.importorskip("transformers4rec.torch") @@ -47,7 +47,7 @@ def test_score_with_transform_onehot(torch_ranking_metrics_inputs, metric): def test_mean_recipricol_rank(): - metric = MeanRecipricolRankAt() + metric = MeanReciprocalRankAt() metric.top_ks = [1, 2, 3, 4] metric.labels_onehot = False result = metric( diff --git a/transformers4rec/torch/ranking_metric.py b/transformers4rec/torch/ranking_metric.py index 94bae9b98f..9c3dcf07f4 100644 --- a/transformers4rec/torch/ranking_metric.py +++ b/transformers4rec/torch/ranking_metric.py @@ -280,9 +280,9 @@ def _metric( @ranking_metrics_registry.register_with_multiple_names("mrr_at", "mrr") -class MeanRecipricolRankAt(RankingMetric): +class MeanReciprocalRankAt(RankingMetric): def __init__(self, top_ks=None, labels_onehot=False): - super(MeanRecipricolRankAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot) + super(MeanReciprocalRankAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot) def _metric( self, ks: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor, log_base: int = 2 @@ -310,6 +310,8 @@ def _metric( device=scores.device, dtype=torch.float32 ) for index, k in enumerate(ks): - values, _ = (topk_labels[:, :k] / (torch.arange(k) + 1)).max(dim=1) + values, _ = (topk_labels[:, :k] / (torch.arange(k) + 1).to(device=scores.device)).max( + dim=1 + ) results[:, index] = values return results