
Extend ContrastiveOutput to support sequential encoders #1086

Merged: 2 commits into main on May 12, 2023

Conversation

@sararb (Contributor) commented May 9, 2023

Goals ⚽

This PR adds support for negative sampling to the ContrastiveOutput class for session-based models where the query encoder returns a 3-D ragged tensor.

Implementation Details 🚧

  • Flatten the values of the query embeddings to match the 2-D sampled negative embeddings.
  • Reconstruct the ragged representation using the mask information of the input query.
  • Apply the same transformation to the positive candidates (in case they are sequential).
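The three steps above can be sketched framework-agnostically in NumPy (in the actual PR the tensors are TensorFlow ragged tensors; the shapes, names, and mask below are hypothetical):

```python
import numpy as np

# Hypothetical shapes: batch of 2 sessions, max length 3, embedding dim 4.
rng = np.random.default_rng(0)
query = rng.normal(size=(2, 3, 4))           # padded 3-D query embeddings
mask = np.array([[True, True, False],        # session 1 has 2 valid steps
                 [True, True, True]])        # session 2 has 3 valid steps
negatives = rng.normal(size=(5, 4))          # 5 sampled negative embeddings

# Step 1: flatten the valid query positions to 2-D so they can be scored
# against the 2-D sampled negative embeddings.
flat_query = query[mask]                     # shape: (5, 4)
logits = flat_query @ negatives.T            # shape: (5, 5)

# Step 2: reconstruct the ragged representation from the input mask.
row_lengths = mask.sum(axis=1)               # [2, 3]
ragged_logits = np.split(logits, np.cumsum(row_lengths)[:-1])
```

Step 3 would apply the same flatten/reconstruct pair to the positive-candidate embeddings when they are sequential.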

Testing Details 🔍

  • Add a unit test that defines a transformer-based model as a retrieval model and trains it with sampled softmax.

Benchmark 🔍

I used the session-based script (implemented here) to benchmark sampled softmax in various configurations, similar to the study conducted in T4Rec (available here).

Command line
I use the first four days for training, and evaluation is computed on the fifth day. Here is the base command line with the hparams used:

python3 session_based.py --metrics_log_frequency 20 --train_path /models/examples/session_based_script/ecomrees_five_days/train --eval_path /models/examples/session_based_script/ecomrees_five_days/valid --schema_path /models/examples/session_based_script/ecomrees_five_days --task multi_class_classification --embedding_dim 448  --d_model 192 --n_layer 3 --n_head 16 --label_smoothing 0.0  --model_type xlnet --eval_batch_size 128 --train_batch_size 128 --epochs 5 --weight_tying --xlnet_attn_type bi --training_task masked --evaluation_task last --masking_probability 0.30000000000000004 --lr 0.0006667377132554976 --transformer_dropout 0.0  --log_to_wandb --transformer_activation gelu --feature_normalization --input_dropout 0.1 --optimizer adamw --weight_decay 3.910060265627374e-05 --save_topk_predictions --emb_init_std 0.11 --sampled_softmax --num_negatives 1000 --logq_correction

The hparams that are changed for the experiments are --sampled_softmax (enables sampled softmax if provided), --logq_correction, and --num_negatives (number of negative samples).
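A minimal NumPy sketch of what these flags control (the function, names, and values below are illustrative, not the PR's implementation): sampled softmax scores each query against the positive plus a set of sampled negatives, and the logQ correction subtracts the log of each negative's sampling probability so frequently sampled (popular) items are not over-penalized.

```python
import numpy as np

def sampled_softmax_logits(pos_scores, neg_scores, neg_probs, logq_correction=True):
    """Hypothetical sketch: build the logits for a sampled-softmax loss.

    pos_scores: (batch,) score of each query against its positive item.
    neg_scores: (batch, num_negatives) scores against sampled negatives.
    neg_probs:  (num_negatives,) sampling probability of each negative.
    """
    if logq_correction:
        # logQ correction: subtract the log sampling probability from each
        # negative's score.
        neg_scores = neg_scores - np.log(neg_probs)
    # Column 0 holds the positive, so the label for every row is 0.
    return np.concatenate([pos_scores[:, None], neg_scores], axis=1)

pos = np.array([2.0, 1.5])
neg = np.array([[0.5, 1.0],
                [0.2, 0.8]])
probs = np.array([0.1, 0.01])
logits = sampled_softmax_logits(pos, neg, probs)
```

With the correction enabled, a negative sampled with probability 0.1 has 2.3 (= -log 0.1) added to its raw score before the softmax, which is what `--logq_correction` toggles in the experiments.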

Results
The results can be seen in the following table. Average examples/sec represents the throughput, and Recall and NDCG are top-k accuracy metrics.

[Results table image: Average examples/sec, Recall, and NDCG for each configuration]

@sararb sararb added this to the Merlin 23.05 milestone May 9, 2023
@sararb sararb self-assigned this May 9, 2023
@github-actions bot commented May 9, 2023

Documentation preview

https://nvidia-merlin.github.io/models/review/pr-1086

@gabrielspmoreira (Member) left a comment


Sounds good to me

if is_ragged:
    logits = logits.copy_with_updates(
        outputs=original_query_embedding.with_flat_values(logits.outputs),
        targets=original_target.with_flat_values(logits.targets),
    )

This with_flat_values is very useful! Seems faster than rebuilding the full ragged tensor.
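The speedup makes sense: `tf.RaggedTensor.with_flat_values` reuses the existing row partition and only swaps in new flat values, instead of recomputing the partition from the mask. A NumPy analogy (names and shapes are illustrative, not the TensorFlow internals):

```python
import numpy as np

# A ragged tensor is essentially a pair (flat_values, row partition).
row_lengths = np.array([2, 3])                 # partition from the original ragged query
new_flat_values = np.arange(10).reshape(5, 2)  # e.g. freshly computed flat logits

def with_flat_values(row_lengths, flat_values):
    """Re-attach new flat values to an existing row partition,
    mimicking tf.RaggedTensor.with_flat_values."""
    return np.split(flat_values, np.cumsum(row_lengths)[:-1])

ragged = with_flat_values(row_lengths, new_flat_values)
```

Here only the split indices are reused; no per-row bookkeeping is rebuilt, which is the cheap part the reviewer is pointing at.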

@gabrielspmoreira gabrielspmoreira merged commit 5f82b55 into main May 12, 2023