Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix discrepancy between evaluation and inference modes #551

Merged
merged 14 commits into from
Nov 30, 2022
2 changes: 1 addition & 1 deletion ci/test_integration.sh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ python3 transf_exp_main.py --output_dir ./tmp/ --overwrite_output_dir --do_train
python3 transf_exp_main.py --output_dir ./tmp/ --overwrite_output_dir --do_train --do_eval --validate_every 10 --logging_steps 20 --save_steps 0 --data_path $DATA_PATH --features_schema_path $FEATURE_SCHEMA_PATH --fp16 --data_loader_engine nvtabular --start_time_window_index 1 --final_time_window_index 2 --time_window_folder_pad_digits 4 --model_type xlnet --loss_type cross_entropy --per_device_eval_batch_size 128 --similarity_type concat_mlp --tf_out_activation tanh --inp_merge mlp --learning_rate_warmup_steps 0 --learning_rate_schedule linear_with_warmup --hidden_act gelu --num_train_epochs 5 --dataloader_drop_last --compute_metrics_each_n_steps 1 --session_seq_length_max 20 --eval_on_last_item_seq_only --mf_constrained_embeddings --layer_norm_featurewise --attn_type bi --plm --per_device_train_batch_size 128 --learning_rate 0.0003387925502203725 --dropout 0.0 --input_dropout 0.2 --weight_decay 2.1769664191492473e-05 --d_model 384 --item_embedding_dim 384 --n_layer 4 --n_head 16 --label_smoothing 0.7000000000000001 --stochastic_shared_embeddings_replacement_prob 0.02 --item_id_embeddings_init_std 0.13 --other_embeddings_init_std 0.005 --plm_probability 0.5 --plm_max_span_length 3 --eval_on_test_set --seed 100 --report_to none

### XLNet (MLM) - Item Id feature
python3 transf_exp_main.py --output_dir ./tmp/ --overwrite_output_dir --do_train --do_eval --validate_every 10 --logging_steps 20 --save_steps 0 --data_path $DATA_PATH --features_schema_path $FEATURE_SCHEMA_PATH --fp16 --data_loader_engine nvtabular --start_time_window_index 1 --final_time_window_index 2 --time_window_folder_pad_digits 4 --model_type xlnet --loss_type cross_entropy --per_device_eval_batch_size 128 --similarity_type concat_mlp --tf_out_activation tanh --inp_merge mlp --learning_rate_warmup_steps 0 --learning_rate_schedule linear_with_warmup --hidden_act gelu --num_train_epochs 5 --dataloader_drop_last --compute_metrics_each_n_steps 1 --session_seq_length_max 20 --eval_on_last_item_seq_only --mf_constrained_embeddings --layer_norm_featurewise --attn_type bi --mlm --per_device_train_batch_size 128 --learning_rate 0.0006667377132554976 --dropout 0.0 --input_dropout 0.1 --weight_decay 3.910060265627374e-05 --d_model 192 --item_embedding_dim 448 --n_layer 3 --n_head 16 --label_smoothing 0.0 --stochastic_shared_embeddings_replacement_prob 0.1 --item_id_embeddings_init_std 0.11 --other_embeddings_init_std 0.02 --mlm_probability 0.30000000000000004 --eval_on_test_set --seed 100 --max_steps 20 --report_to none
python3 transf_exp_main.py --output_dir ./tmp/ --overwrite_output_dir --do_train --do_eval --validate_every 10 --logging_steps 20 --save_steps 0 --data_path $DATA_PATH --features_schema_path $FEATURE_SCHEMA_PATH --fp16 --data_loader_engine nvtabular --start_time_window_index 1 --final_time_window_index 2 --time_window_folder_pad_digits 4 --model_type xlnet --loss_type cross_entropy --per_device_eval_batch_size 128 --similarity_type concat_mlp --tf_out_activation tanh --inp_merge mlp --learning_rate_warmup_steps 0 --learning_rate_schedule linear_with_warmup --hidden_act gelu --num_train_epochs 5 --dataloader_drop_last --compute_metrics_each_n_steps 1 --session_seq_length_max 20 --eval_on_last_item_seq_only --mf_constrained_embeddings --layer_norm_featurewise --attn_type bi --mlm --per_device_train_batch_size 128 --learning_rate 0.0006667377132554976 --dropout 0.0 --input_dropout 0.1 --weight_decay 3.910060265627374e-05 --d_model 192 --item_embedding_dim 448 --n_layer 3 --n_head 16 --label_smoothing 0.0 --stochastic_shared_embeddings_replacement_prob 0.1 --item_id_embeddings_init_std 0.11 --other_embeddings_init_std 0.02 --mlm_probability 0.30000000000000004 --eval_on_test_set --seed 100 --report_to none

### XLNET (MLM) - CONCAT + SOFT ONE-HOT ENCODING - All features
python3 transf_exp_main.py --output_dir ./tmp/ --overwrite_output_dir --do_train --do_eval --validate_every 10 --logging_steps 20 --save_steps 0 --data_path $DATA_PATH --features_schema_path $FEATURE_SCHEMA_PATH --fp16 --data_loader_engine nvtabular --start_time_window_index 1 --final_time_window_index 2 --time_window_folder_pad_digits 4 --model_type xlnet --loss_type cross_entropy --per_device_eval_batch_size 128 --similarity_type concat_mlp --tf_out_activation tanh --inp_merge mlp --learning_rate_warmup_steps 0 --learning_rate_schedule linear_with_warmup --hidden_act gelu --num_train_epochs 5 --dataloader_drop_last --compute_metrics_each_n_steps 1 --session_seq_length_max 20 --eval_on_last_item_seq_only --mf_constrained_embeddings --layer_norm_featurewise --attn_type bi --mlm --input_features_aggregation concat --per_device_train_batch_size 128 --learning_rate 0.00034029107417129616 --dropout 0.0 --input_dropout 0.1 --weight_decay 3.168336235732841e-05 --d_model 448 --item_embedding_dim 384 --n_layer 2 --n_head 8 --label_smoothing 0.6000000000000001 --stochastic_shared_embeddings_replacement_prob 0.0 --item_id_embeddings_init_std 0.06999999999999999 --other_embeddings_init_std 0.085 --mlm_probability 0.30000000000000004 --embedding_dim_from_cardinality_multiplier 1.0 --numeric_features_project_to_embedding_dim 20 --numeric_features_soft_one_hot_encoding_num_embeddings 5 --eval_on_test_set --seed 100 --use_side_information_features --report_to none
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os
from functools import partial

import numpy as np
import pandas as pd
import torch
import transformers
Expand All @@ -29,13 +30,15 @@
log_metric_results,
log_parameters,
)
from merlin.io import Dataset
from transf_exp_args import DataArguments, ModelArguments, TrainingArguments
from transformers import HfArgumentParser, set_seed
from transformers.trainer_utils import is_main_process

import transformers4rec.torch as t4r
from merlin_standard_lib import Schema, Tag
from transformers4rec.torch import Trainer
from transformers4rec.torch.utils.data_utils import MerlinDataLoader
from transformers4rec.torch.utils.examples_utils import wipe_memory

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -182,6 +185,63 @@ def main():

log_aot_metric_results(training_args.output_dir, results_avg_time)

# Mimic the inference by manually computing recall@10 using the evaluation data
# of the last time-index.
eval_path = os.path.join(
data_args.data_path,
str(
data_args.final_time_window_index,
).zfill(data_args.time_window_folder_pad_digits),
"test.parquet" if training_args.eval_on_test_set else "valid.parquet",
)
prediction_data = pd.read_parquet(eval_path)
# Extract label
labels = prediction_data["sess_pid_seq"].apply(lambda x: x[-1]).values

# Truncate input sequences up to last item - 1 to mimic the inference
def mask_last_interaction(x):
return list(x[:-1])

list_columns = schema.select_by_tag("list").column_names
for col in list_columns:
prediction_data[col] = prediction_data[col].apply(mask_last_interaction)
# Get top-10 predictions
test_loader = MerlinDataLoader.from_schema(
schema,
Dataset(prediction_data),
training_args.per_device_eval_batch_size,
max_sequence_length=training_args.max_sequence_length,
shuffle=False,
)
trainer.test_dataloader = test_loader
topk_preds = trainer.predict(test_loader).predictions[0]
# Compute recall@10
recall_10 = recall(topk_preds, labels)
logger.info(f"Recall@10 of manually masked test data = {str(recall_10)}")
output_file = os.path.join(training_args.output_dir, "eval_results_over_time.txt")
with open(output_file, "a") as writer:
writer.write(f"\n***** Recall@10 of simulated inference = {recall_10} *****\n")
# Verify that the recall@10 from train.evaluate() matches the recall@10 calculated manually
if not isinstance(input_module.masking, t4r.masking.PermutationLanguageModeling):
# TODO fix inference discrepancy for permutation language modeling
assert np.isclose(recall_10, results_over_time[2]["eval_/next-item/recall_at_10"], rtol=0.1)


def recall(predicted_items: np.ndarray, real_items: np.ndarray) -> float:
bs, top_k = predicted_items.shape
valid_rows = real_items != 0

# reshape predictions and labels to compare
# the top-10 predicted item-ids with the label id.
real_items = real_items.reshape(bs, 1, -1)
predicted_items = predicted_items.reshape(bs, 1, top_k)

num_relevant = real_items.shape[-1]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to filter num_relevant[valid_rows] as you do with predicted_correct_sum[valid_rows]

predicted_correct_sum = (predicted_items == real_items).sum(-1)
predicted_correct_sum = predicted_correct_sum[valid_rows]
recall_per_row = predicted_correct_sum / num_relevant
return np.mean(recall_per_row)


def incremental_train_eval(
trainer, start_time_index, end_time_index, input_dir, training_args, data_args
Expand Down
6 changes: 3 additions & 3 deletions examples/tutorial/03-Session-based-recsys.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,8 @@
"metadata": {},
"outputs": [],
"source": [
"# import NVTabular dependencies\n",
"from transformers4rec.torch.utils.data_utils import NVTabularDataLoader\n",
"# import Merlin Data Loader dependencies\n",
"from transformers4rec.torch.utils.data_utils import MerlinDataLoader\n",
"\n",
"x_cat_names, x_cont_names = ['product_id-list_seq'], []\n",
"\n",
Expand All @@ -413,7 +413,7 @@
"# Define a `get_dataloader` function to call in the training loop\n",
"def get_dataloader(path, batch_size=32):\n",
"\n",
" return NVTabularDataLoader.from_schema(\n",
" return MerlinDataLoader.from_schema(\n",
" schema,\n",
" path, \n",
" batch_size,\n",
Expand Down
20 changes: 15 additions & 5 deletions tests/torch/features/test_sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_sequential_tabular_features_with_masking(yoochoose_schema, torch_yoocho
def test_sequential_tabular_features_ignore_masking(yoochoose_schema, torch_yoochoose_like):
import numpy as np

from transformers4rec.torch.masking import CausalLanguageModeling
from transformers4rec.torch.masking import CausalLanguageModeling, MaskedLanguageModeling

input_module = tr.TabularSequenceFeatures.from_schema(
yoochoose_schema,
Expand All @@ -118,15 +118,25 @@ def test_sequential_tabular_features_ignore_masking(yoochoose_schema, torch_yooc

input_module._masking = CausalLanguageModeling(hidden_size=100)

output_ignore_masking = (
output_inference_masking = (
input_module(torch_yoochoose_like, training=False, testing=False).detach().cpu().numpy()
)
output_masking = (
output_clm_masking = (
input_module(torch_yoochoose_like, training=False, testing=True).detach().cpu().numpy()
)

assert np.allclose(output_wo_masking, output_ignore_masking, rtol=1e-04, atol=1e-08)
assert not np.allclose(output_wo_masking, output_masking, rtol=1e-04, atol=1e-08)
assert np.allclose(output_wo_masking, output_inference_masking, rtol=1e-04, atol=1e-08)
assert not np.allclose(output_wo_masking, output_clm_masking, rtol=1e-04, atol=1e-08)

input_module._masking = MaskedLanguageModeling(hidden_size=100)
output_inference_masking = (
input_module(torch_yoochoose_like, training=False, testing=False).detach().cpu().numpy()
)
output_eval_masking = (
input_module(torch_yoochoose_like, training=False, testing=True).detach().cpu().numpy()
)
# MLM extends the inputs with one position during inference
assert output_inference_masking.shape[1] == output_eval_masking.shape[1] + 1


def test_tabular_features_yoochoose_direct(yoochoose_schema, torch_yoochoose_like):
Expand Down
4 changes: 3 additions & 1 deletion tests/torch/test_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def test_mask_all_next_item_for_eval(torch_masking_inputs, task):
padding_idx=torch_masking_inputs["padding_idx"],
eval_on_last_item_seq_only=False,
)
masking_info = lm.compute_masked_targets(torch_masking_inputs["labels"], training=False)
masking_info = lm.compute_masked_targets(
torch_masking_inputs["labels"], training=False, testing=True
)
# get the labels from output
trgt_pad = masking_info.targets != torch_masking_inputs["padding_idx"]
labels = masking_info.targets[trgt_pad].flatten().numpy()
Expand Down
14 changes: 14 additions & 0 deletions transformers4rec/config/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def build(
axial_pos_shape_first_dim=4,
**kwargs
):
# To account for target positions at inference mode, we extend the maximum sequence length.
total_seq_length = total_seq_length + 2
return cls(
hidden_size=d_model,
attention_head_size=d_model,
Expand Down Expand Up @@ -166,6 +168,8 @@ def build(
log_attention_weights=False,
**kwargs
):
# To account for target positions at inference mode, we extend the maximum sequence length.
total_seq_length = total_seq_length + 2
return cls(
hidden_size=d_model,
num_hidden_layers=n_layer,
Expand Down Expand Up @@ -199,6 +203,8 @@ def build(
log_attention_weights=False,
**kwargs
):
# To account for target positions at inference mode, we extend the maximum sequence length.
total_seq_length = total_seq_length + 2
return cls(
hidden_size=d_model,
embedding_size=d_model,
Expand Down Expand Up @@ -234,6 +240,8 @@ def build(
log_attention_weights=False,
**kwargs
):
# To account for target positions at inference mode, we extend the maximum sequence length.
total_seq_length = total_seq_length + 2
return cls(
hidden_size=d_model,
num_attention_heads=n_head,
Expand Down Expand Up @@ -306,6 +314,8 @@ def build(
log_attention_weights=False,
**kwargs
):
# To account for target positions at inference mode, we extend the maximum sequence length.
total_seq_length = total_seq_length + 2
return cls(
hidden_size=d_model,
num_hidden_layers=n_layer,
Expand All @@ -316,6 +326,7 @@ def build(
dropout=dropout,
pad_token_id=pad_token,
output_attentions=log_attention_weights,
max_position_embeddings=total_seq_length,
vocab_size=1,
**kwargs,
)
Expand All @@ -338,6 +349,8 @@ def build(
log_attention_weights=False,
**kwargs
):
# To account for target positions at inference mode, we extend the maximum sequence length.
total_seq_length = total_seq_length + 2
return cls(
hidden_size=d_model,
num_hidden_layers=n_layer,
Expand All @@ -348,6 +361,7 @@ def build(
dropout=dropout,
pad_token_id=pad_token,
output_attentions=log_attention_weights,
max_position_embeddings=total_seq_length,
vocab_size=1,
**kwargs,
)
Expand Down
7 changes: 5 additions & 2 deletions transformers4rec/torch/features/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,12 @@ def forward(self, inputs, training=False, testing=False, **kwargs):
if self.projection_module:
outputs = self.projection_module(outputs)

if self.masking and (training or testing):
if self.masking:
outputs = self.masking(
outputs, item_ids=self.to_merge["categorical_module"].item_seq, training=training
outputs,
item_ids=self.to_merge["categorical_module"].item_seq,
training=training,
testing=testing,
)

return outputs
Expand Down
Loading