Skip to content

Commit

Permalink
fix typo and metric error (#522)
Browse files Browse the repository at this point in the history
  • Loading branch information
rnyak committed Nov 8, 2022
1 parent 0adb10c commit 22be4fd
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions tests/torch/test_ranking_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest
import torch

from transformers4rec.torch.ranking_metric import MeanRecipricolRankAt
from transformers4rec.torch.ranking_metric import MeanReciprocalRankAt

tr = pytest.importorskip("transformers4rec.torch")

Expand Down Expand Up @@ -47,7 +47,7 @@ def test_score_with_transform_onehot(torch_ranking_metrics_inputs, metric):


def test_mean_recipricol_rank():
metric = MeanRecipricolRankAt()
metric = MeanReciprocalRankAt()
metric.top_ks = [1, 2, 3, 4]
metric.labels_onehot = False
result = metric(
Expand Down
8 changes: 5 additions & 3 deletions transformers4rec/torch/ranking_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,9 @@ def _metric(


@ranking_metrics_registry.register_with_multiple_names("mrr_at", "mrr")
class MeanRecipricolRankAt(RankingMetric):
class MeanReciprocalRankAt(RankingMetric):
def __init__(self, top_ks=None, labels_onehot=False):
super(MeanRecipricolRankAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot)
super(MeanReciprocalRankAt, self).__init__(top_ks=top_ks, labels_onehot=labels_onehot)

def _metric(
self, ks: torch.Tensor, scores: torch.Tensor, labels: torch.Tensor, log_base: int = 2
Expand Down Expand Up @@ -310,6 +310,8 @@ def _metric(
device=scores.device, dtype=torch.float32
)
for index, k in enumerate(ks):
values, _ = (topk_labels[:, :k] / (torch.arange(k) + 1)).max(dim=1)
values, _ = (topk_labels[:, :k] / (torch.arange(k) + 1).to(device=scores.device)).max(
dim=1
)
results[:, index] = values
return results

0 comments on commit 22be4fd

Please sign in to comment.