diff --git a/transformers4rec/torch/masking.py b/transformers4rec/torch/masking.py index 4cf813694a..33342c55df 100644 --- a/transformers4rec/torch/masking.py +++ b/transformers4rec/torch/masking.py @@ -259,7 +259,9 @@ def _compute_masked_targets(self, item_ids: torch.Tensor, training=False) -> Mas labels.size(0), dtype=torch.long, device=item_ids.device # type: ignore ) last_item_sessions = mask_labels.sum(dim=1) - 1 - label_seq_trg_eval = torch.zeros(labels.shape, dtype=torch.long, device=item_ids.device) + label_seq_trg_eval = torch.zeros( + labels.shape, dtype=labels.dtype, device=item_ids.device + ) label_seq_trg_eval[rows_ids, last_item_sessions] = labels[rows_ids, last_item_sessions] # Updating labels and mask labels = label_seq_trg_eval