Implemented sampled softmax for NextItemPredictionTask #671
Conversation
Documentation preview: https://nvidia-merlin.github.io/Transformers4Rec/review/pr-671
The PR looks good to me. I just left some remarks/questions to understand the code base.
else:
-    logits = self.output_layer(inputs)
+    logits = inputs @ output_weights
We might need to keep the bias parameter `self.output_layer_bias`: `logits = inputs @ output_weights + self.output_layer_bias`?
I removed the bias because I think it would not be available if ANN is used later for serving. Does that make sense?
I can run some benchmarks later to see if the bias helps to improve accuracy.
You're right, that makes sense! Otherwise, we'd need to save the `output_bias` vector in addition to the pre-trained candidate embeddings.
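For context, here is a minimal sketch (illustrative shapes and names, not the PR's code) of the bias-free scoring being discussed: with weight tying, the logits are just a dot product between the session representation and the item embedding table, which is exactly what an ANN index can reproduce at serving time.

```python
import torch

# Illustrative sketch: bias-free scoring with a tied item-embedding table.
d_model, n_items, batch_size = 64, 1000, 8
item_embeddings = torch.nn.Embedding(n_items, d_model)

inputs = torch.randn(batch_size, d_model)   # session/user representations
output_weights = item_embeddings.weight     # tied output layer, no bias term
logits = inputs @ output_weights.t()        # (batch_size, n_items)
```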
return predictions

logits = torch.cat([positive_scores, negative_scores], axis=1)
new_targets = torch.zeros(logits.shape[0], dtype=torch.int64)
The first element of each row should be `1` instead of `0` to account for the positive target, right?
The targets are the sparse id representation, not the one-hot representation. Does that make sense?
Oh I see, so `new_targets` is a 1-D vector that contains the index of the positive item in the `logits` tensor (which always corresponds to index `0`).
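To make the all-zeros targets concrete, here is a minimal, self-contained sketch (shapes and names are illustrative, not the PR's actual code): the positive score is placed in column 0 and the sampled negatives in the remaining columns, so the target class index is 0 for every row.

```python
import torch

batch_size, n_negatives = 4, 6
positive_scores = torch.randn(batch_size, 1)             # score of the true next item
negative_scores = torch.randn(batch_size, n_negatives)   # scores of sampled negatives

# The positive item always occupies column 0 of the logits
logits = torch.cat([positive_scores, negative_scores], dim=1)
new_targets = torch.zeros(logits.shape[0], dtype=torch.int64)  # index of the positive column

loss = torch.nn.CrossEntropyLoss()(logits, new_targets)
```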
-def forward(self, inputs):
-    return self.module(inputs)
+def forward(self, inputs, **kwargs):
+    return self.module(inputs, **kwargs)
Can you explain why we need the extra `**kwargs` here?
" [`sum`, `none`, `mean`]" | ||
) | ||
return loss | ||
return torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction=reduction, **kwargs) |
It's great to see that `label_smoothing` was added in the latest version of `CrossEntropyLoss`!
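For reference, a minimal sketch of the built-in option (native in PyTorch since 1.10); the values below are just examples:

```python
import torch

# label_smoothing is now a native argument of CrossEntropyLoss
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.2, reduction="mean")

logits = torch.randn(8, 100)                              # (batch, n_items)
targets = torch.randint(0, 100, (8,), dtype=torch.long)   # positive item ids
loss = loss_fn(logits, targets)
```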
y = labels_all
x, y = self.pre(x, targets=y, training=training, testing=testing)  # type: ignore

loss = self.loss(x, y)
return {
    "loss": loss,
    "labels": labels_all,
I understand that self.pre invokes the next-item task using the sampled softmax option, which returns logits x related to the list of [positive_item, sampled negatives]. So I wonder how these logits are connected to labels_all (which is a tensor of positive item ids) for metrics calculation.
    dist = self.unique_sampling_dist
else:
    dist = self.dist
dist = dist.to(device)
Would it be possible to move the definition of `dist` to the class constructor, to avoid copying the tensor to the GPU/CPU device multiple times?
The challenge is how to get the device in the constructor. Any ideas?
You can use `register_buffer` to register the variable `dist`. Then, the method `model.to(device)` will ensure that the buffer is copied to the right device. It is something like:
- in the constructor, you set `self.register_buffer('dist', dist)`
- in the method `sampled`, you can just call the registered buffer `self.dist`
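A minimal sketch of that suggestion (class and attribute names are illustrative, not the actual `LogUniformSampler` code):

```python
import torch

class SamplerWithBuffer(torch.nn.Module):
    def __init__(self, dist: torch.Tensor):
        super().__init__()
        # A buffer is moved by .to(device)/.cuda() and saved in the
        # state_dict, but is not a trainable parameter.
        self.register_buffer("dist", dist)

    def sample(self, n_samples: int) -> torch.Tensor:
        # self.dist is already on the module's device, no extra .to(device)
        return torch.multinomial(self.dist, n_samples, replacement=True)

sampler = SamplerWithBuffer(torch.ones(1000) / 1000)
# sampler.to("cuda")  # would move the "dist" buffer together with the module
negative_ids = sampler.sample(100)
```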
so we use `torch.multinomial(..., replacement=True).unique()` which doesn't guarantee
the same number of unique sampled items. You can try to increase
n_samples_multiplier_before_unique to increase the chances to have more
unique samples in that case.
+1 !! Thank you for creating this class. It was very helpful for learning how to approximate item frequency distributions for both sampling with and without repetition!
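A minimal sketch of the unique-sampling behaviour described in that docstring, assuming a log-uniform (Zipf-like) distribution over item ids; `n_samples_multiplier_before_unique` is taken from the quoted text and the other names are illustrative:

```python
import math
import torch

n_items = 10_000
ids = torch.arange(1, n_items + 1, dtype=torch.float64)
# Log-uniform sampling probabilities: P(i) = log1p(1/i) / log(n_items + 1)
probs = torch.log1p(1.0 / ids) / math.log(n_items + 1)

n_samples = 100
n_samples_multiplier_before_unique = 2

# Sampling with replacement and then deduplicating does not guarantee
# n_samples unique ids; oversampling increases the chances of getting enough.
raw = torch.multinomial(probs, n_samples * n_samples_multiplier_before_unique, replacement=True)
unique_negative_ids = raw.unique()  # may still contain fewer than n_samples ids
```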
Goals ⚽
Implements sampled softmax for `NextItemPredictionTask`. It allows for faster training and evaluation.
Implementation Details 🚧
- Refactored `NextItemPredictionTask` to have a standard output layer op (a dot product) whether `weight_tying` is enabled or not.
- Added a `sampled_softmax` option to `NextItemPredictionTask` (see the usage sketch after this list).
- Created `LogUniformSampler`, which is able to return probabilities for both `unique_sampling = True` or `False`, for use by `NextItemPredictionTask`.
- Changed `LabelSmoothCrossEntropyLoss` to be just an alias of `torch.nn.CrossEntropyLoss(label_smoothing=...)`, as PyTorch added `label_smoothing` in one of its latest versions. Added a `DeprecationWarning` to `LabelSmoothCrossEntropyLoss`.
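As a usage sketch for the new option (hypothetical call; the exact keyword names such as `sampled_softmax` and `max_n_samples` follow this description and may differ from the merged API, so check the `NextItemPredictionTask` signature in your version):

```python
import transformers4rec.torch as tr

prediction_task = tr.NextItemPredictionTask(
    weight_tying=True,      # reuse the item-id embedding table as the output layer
    sampled_softmax=True,   # score the positive item against sampled negatives only
    max_n_samples=1000,     # number of sampled negatives (assumed name, see note above)
)
```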
Testing Details 🔍
- Added `test_with_next_item_pred_sampled_softmax`
Benchmark 🔍
I performed a benchmark of sampled softmax in different configurations (weight tying enabled and disabled, and with different numbers of samples) to understand the impact of sampled softmax on training throughput and accuracy.
Setup
The experiments were performed using the T4Rec paper reproducibility script, which was changed to accept the new CLI args `--sampled_softmax` and `--sampled_softmax_max_n_samples`, and the REES46 preprocessed dataset. The benchmark was done using the Merlin PyTorch 23.02 container, with a manual update of the `core`, `dataloader` and `models` folders to pull and install their latest versions from GitHub.
Command line
The script performs incremental training and evaluation. I use the first five days for training, and evaluation is computed for each next day. Here is the base command line with the utilized hparams.
The hparams that are changed for the experiments are `--mf_constrained_embeddings` (enables weight tying if provided, i.e., reusing the item id embedding table as the output layer), `--sampled_softmax` (enables sampled softmax if provided) and `--sampled_softmax_max_n_samples` (number of negative samples).
cd /transformers4rec/examples/t4rec_paper_experiments/t4r_paper_repro
CUDA_VISIBLE_DEVICES=0 python3 transf_exp_main.py --output_dir ./tmp/ --overwrite_output_dir --do_train --do_eval --validate_every 10 --logging_steps 20 --save_steps 0 --data_path $DATA_PATH --features_schema_path "../datasets_configs/ecom_rees46/rees46_schema.pbtxt" --fp16 --data_loader_engine merlin --start_time_window_index 1 --final_time_window_index 6 --time_window_folder_pad_digits 4 --model_type albert --loss_type cross_entropy --per_device_eval_batch_size 128 --similarity_type concat_mlp --tf_out_activation tanh --inp_merge mlp --learning_rate_warmup_steps 0 --learning_rate_schedule linear_with_warmup --hidden_act gelu --num_train_epochs 5 --dataloader_drop_last --compute_metrics_each_n_steps 1 --session_seq_length_max 20 --eval_on_last_item_seq_only --layer_norm_featurewise --mlm --num_hidden_groups -1 --inner_group_num 1 --per_device_train_batch_size 512 --learning_rate 0.0004904752786458524 --dropout 0.0 --input_dropout 0.1 --weight_decay 9.565968888623912e-05 --d_model 320 --item_embedding_dim 320 --n_layer 2 --n_head 8 --stochastic_shared_embeddings_replacement_prob 0.06 --item_id_embeddings_init_std 0.11 --other_embeddings_init_std 0.025 --mlm_probability 0.6000000000000001 --eval_on_test_set --seed 100 --report_to none --label_smoothing 0.2 --mf_constrained_embeddings --sampled_softmax --sampled_softmax_max_n_samples 1000
Results
The results can be seen in the following table. Steps/sec represents the throughput, and Recall and NDCG are top-k accuracy metrics.
The gist is that it is possible to get both better training throughput and a gain in accuracy by using sampled softmax.
Some notes from these results:
Disclaimer: These experiments were not hyperparameter-tuned for every configuration. Furthermore, accuracy results might differ a lot across runs, in particular when a smaller number of samples (e.g. 1k) is used.