diff --git a/ci/test_integration.sh b/ci/test_integration.sh
index 98c6626fbf..adce2aab23 100755
--- a/ci/test_integration.sh
+++ b/ci/test_integration.sh
@@ -35,7 +35,7 @@ pip install -r requirements.txt
 cd t4r_paper_repro
 FEATURE_SCHEMA_PATH=../datasets_configs/ecom_rees46/rees46_schema.pbtxt
 pip install gdown
-gdown https://drive.google.com/uc?id=1payLwuwfa_QG6GvFVg4KT1w7dSkO-jPZ
+gdown https://drive.google.com/uc?id=1NCFZ5ya3zyxPsrmupEoc9UEm4sslAddV
 apt-get update -y
 apt-get install unzip -y
 DATA_PATH=/transformers4rec/examples/t4rec_paper_experiments/t4r_paper_repro/
diff --git a/transformers4rec/torch/losses.py b/transformers4rec/torch/losses.py
index 99ad23efa7..aca9f388f7 100644
--- a/transformers4rec/torch/losses.py
+++ b/transformers4rec/torch/losses.py
@@ -31,7 +31,7 @@ def _smooth_one_hot(targets: torch.Tensor, n_classes: int, smoothing: float = 0.
             targets = (
                 torch.empty(size=(targets.size(0), n_classes), device=targets.device)
                 .fill_(smoothing / (n_classes - 1))
-                .scatter_(1, targets.data.unsqueeze(1), 1.0 - smoothing)
+                .scatter_(1, targets.data.unsqueeze(1).to(torch.int64), 1.0 - smoothing)
             )
         return targets
 
diff --git a/transformers4rec/torch/utils/data_utils.py b/transformers4rec/torch/utils/data_utils.py
index 19421a5e2f..17790e9f42 100644
--- a/transformers4rec/torch/utils/data_utils.py
+++ b/transformers4rec/torch/utils/data_utils.py
@@ -225,6 +225,7 @@ def __init__(
         self.drop_last = drop_last
 
         self.set_dataset(buffer_size, engine, reader_kwargs)
+        self.dataset.schema = self.dataset.schema.select_by_name(conts + cats + labels)
 
         if (global_rank is not None) and (self.dataset.npartitions < global_size):
             logger.warning(
diff --git a/transformers4rec/torch/utils/torch_utils.py b/transformers4rec/torch/utils/torch_utils.py
index 86676e3814..a8821d0aa4 100644
--- a/transformers4rec/torch/utils/torch_utils.py
+++ b/transformers4rec/torch/utils/torch_utils.py
@@ -199,7 +199,7 @@ def create_output_placeholder(scores, ks):
 
 
 def tranform_label_to_onehot(labels, vocab_size):
-    return one_hot_1d(labels.reshape(-1), vocab_size, dtype=torch.float32).detach()
+    return one_hot_1d(labels.reshape(-1).to(torch.int64), vocab_size, dtype=torch.float32).detach()
 
 
 def one_hot_1d(
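
Note on the dtype casts in `losses.py` and `torch_utils.py`: `torch.Tensor.scatter_` requires its index tensor to be `int64`, so labels that reach `_smooth_one_hot` (and presumably the `one_hot_1d` helper) as `int32` fail at runtime. Below is a minimal sketch, not part of the patch, showing the constraint and the cast; it assumes `int32` labels (e.g. as produced by the GPU dataloader) and uses illustrative values only:

```python
import torch

n_classes, smoothing = 5, 0.1
labels = torch.tensor([1, 3, 0], dtype=torch.int32)  # assumed int32, as from the dataloader

# scatter_ expects an int64 index tensor; without the .to(torch.int64) cast
# the call below raises a RuntimeError complaining that the index must be int64.
smoothed = (
    torch.empty(size=(labels.size(0), n_classes))
    .fill_(smoothing / (n_classes - 1))
    .scatter_(1, labels.unsqueeze(1).to(torch.int64), 1.0 - smoothing)
)
print(smoothed)  # each row: 1 - smoothing at the label index, smoothing / (n_classes - 1) elsewhere
```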