Fix CI test based on the requirements of the new merlin loader (#536)
* cast targets to int64 dtype required by torch.scatter_ (see the sketch after these notes)

* update the dataset schema to match the schema defined in the model

* update ecom-rees46 data used in CI with the hex feature converted to numerical
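
The cast matters because PyTorch's torch.Tensor.scatter_ only accepts an int64 (Long) index tensor; targets loaded in a narrower integer dtype raise a RuntimeError. A minimal, standalone sketch of that constraint (not taken from the commit itself):

    import torch

    targets = torch.tensor([1, 3, 0], dtype=torch.int32)  # labels arriving in a narrower dtype
    onehot = torch.zeros(3, 5)

    # scatter_ expects an int64 index, so the int32 index raises a RuntimeError
    try:
        onehot.scatter_(1, targets.unsqueeze(1), 1.0)
    except RuntimeError as err:
        print(err)

    # casting the index to int64 lets the same call succeed
    onehot.scatter_(1, targets.unsqueeze(1).to(torch.int64), 1.0)
    print(onehot)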
sararb committed Nov 16, 2022
1 parent 11be6b1 commit 1536eb0
Showing 4 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion ci/test_integration.sh
@@ -35,7 +35,7 @@ pip install -r requirements.txt
 cd t4r_paper_repro
 FEATURE_SCHEMA_PATH=../datasets_configs/ecom_rees46/rees46_schema.pbtxt
 pip install gdown
-gdown https://drive.google.com/uc?id=1payLwuwfa_QG6GvFVg4KT1w7dSkO-jPZ
+gdown https://drive.google.com/uc?id=1NCFZ5ya3zyxPsrmupEoc9UEm4sslAddV
 apt-get update -y
 apt-get install unzip -y
 DATA_PATH=/transformers4rec/examples/t4rec_paper_experiments/t4r_paper_repro/
2 changes: 1 addition & 1 deletion transformers4rec/torch/losses.py
@@ -31,7 +31,7 @@ def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing: float = 0.
     targets = (
         torch.empty(size=(targets.size(0), n_classes), device=targets.device)
         .fill_(smoothing / (n_classes - 1))
-        .scatter_(1, targets.data.unsqueeze(1), 1.0 - smoothing)
+        .scatter_(1, targets.data.unsqueeze(1).to(torch.int64), 1.0 - smoothing)
     )
     return targets

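For context, _smooth_one_hot builds label-smoothed one-hot targets: every class starts with smoothing / (n_classes - 1), and the true class is then overwritten with 1 - smoothing via scatter_, which is where the int64 index is required. A self-contained sketch of the same idea (function and variable names here are illustrative, not the library's API):

    import torch

    def smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing: float = 0.1) -> torch.Tensor:
        # spread the off-target probability mass uniformly over the other classes
        out = torch.full((targets.size(0), n_classes), smoothing / (n_classes - 1))
        # place 1 - smoothing at the true class; scatter_ needs an int64 index
        out.scatter_(1, targets.reshape(-1, 1).to(torch.int64), 1.0 - smoothing)
        return out

    print(smooth_one_hot(torch.tensor([2, 0], dtype=torch.int32), n_classes=4))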
1 change: 1 addition & 0 deletions transformers4rec/torch/utils/data_utils.py
@@ -225,6 +225,7 @@ def __init__(
         self.drop_last = drop_last
 
         self.set_dataset(buffer_size, engine, reader_kwargs)
+        self.dataset.schema = self.dataset.schema.select_by_name(conts + cats + labels)
 
         if (global_rank is not None) and (self.dataset.npartitions < global_size):
             logger.warning(
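The added line narrows the dataset's schema to exactly the continuous, categorical, and label columns the model was built with, so the new merlin loader does not yield extra columns. A hedged sketch of the same call against a merlin Dataset (this assumes merlin-core's Dataset and Schema.select_by_name; the file path and column names are illustrative):

    from merlin.io import Dataset

    conts, cats, labels = ["price"], ["item_id", "category_id"], ["click"]

    dataset = Dataset("train.parquet")  # illustrative path, not part of the commit
    # keep only the columns the model expects; everything else is dropped from the schema
    dataset.schema = dataset.schema.select_by_name(conts + cats + labels)
    print(dataset.schema.column_names)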
2 changes: 1 addition & 1 deletion transformers4rec/torch/utils/torch_utils.py
@@ -199,7 +199,7 @@ def create_output_placeholder(scores, ks):
 
 
 def tranform_label_to_onehot(labels, vocab_size):
-    return one_hot_1d(labels.reshape(-1), vocab_size, dtype=torch.float32).detach()
+    return one_hot_1d(labels.reshape(-1).to(torch.int64), vocab_size, dtype=torch.float32).detach()
 
 
 def one_hot_1d(
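The same dtype constraint shows up in one-hot encoding: torch.nn.functional.one_hot, for example, only accepts an int64 (Long) tensor, so labels coming out of the loader in a narrower dtype need the explicit cast before being expanded. A small standalone sketch using F.one_hot rather than the repository's one_hot_1d helper:

    import torch
    import torch.nn.functional as F

    labels = torch.tensor([[2], [0], [3]], dtype=torch.int32)  # e.g. labels read as int32

    # F.one_hot requires an int64 tensor, so cast before encoding
    onehot = F.one_hot(labels.reshape(-1).to(torch.int64), num_classes=4).to(torch.float32)
    print(onehot)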
