[QST] Unable to replicate evaluation metrics when using ignore_masking. #506

Closed
TassaraR opened this issue Oct 21, 2022 · 7 comments
Labels: area/examples, question (Further information is requested)


@TassaraR

❓ Questions & Help

Details

transformers4rec==0.1.13

I've been trying out t4r for a while and I decided to try and replicate the evaluation metrics on my own by performing offline predictions.

I've been using a model architecture similar to one in the provided examples and similar features.

Model:

```python
# Imports as in the transformers4rec 0.1.x example notebooks; `schema` is the dataset schema loaded beforehand.
from transformers4rec import torch as tr
from transformers4rec.config.trainer import T4RecTrainingArguments
from transformers4rec.torch import Trainer
from transformers4rec.torch.ranking_metric import AvgPrecisionAt, NDCGAt, PrecisionAt, RecallAt

inputs = tr.TabularSequenceFeatures.from_schema(
    schema,
    max_sequence_length=36,
    aggregation='concat',
    continuous_projection=64,
    d_output=64,
    masking="mlm",
)

# Define XLNetConfig class and set default parameters for HF XLNet config
transformer_config = tr.XLNetConfig.build(
    d_model=64, n_head=4, n_layer=2, total_seq_length=36
)
# Define the model block including: inputs, masking, projection and transformer block.
body = tr.SequentialBlock(
    inputs, tr.MLPBlock([64]), tr.TransformerBlock(transformer_config, masking=inputs.masking)
)

# Define the evaluation top-N metrics and the cut-offs
metrics = [
    NDCGAt(top_ks=[5, 10, 20], labels_onehot=True),
    RecallAt(top_ks=[5, 10, 20], labels_onehot=True),
    AvgPrecisionAt(top_ks=[5, 10, 20], labels_onehot=True),
    PrecisionAt(top_ks=[5, 10, 20], labels_onehot=True),
]

# Define a head related to the next-item prediction task
head = tr.Head(
    body,
    tr.NextItemPredictionTask(weight_tying=True, hf_format=True, metrics=metrics),
    inputs=inputs,
)

# Get the end-to-end Model class
model = tr.Model(head)

train_args = T4RecTrainingArguments(
    data_loader_engine='nvtabular',
    dataloader_drop_last=False,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=256,
    per_device_eval_batch_size=32,
    output_dir="./tmp",
    learning_rate=0.0005,
    lr_scheduler_type='cosine',
    num_train_epochs=5,
    max_sequence_length=36,
    report_to=[],
    logging_steps=200,
    no_cuda=False,
)

trainer = Trainer(
    model=model,
    args=train_args,
    schema=schema,
    compute_metrics=True,
)
```

When I evaluate the metrics using:

```python
trainer.eval_dataset_or_path = 'eval.parquet'
train_metrics = trainer.evaluate(metric_key_prefix='eval')
```

I get the following recall:

```
'eval_/next-item/recall_at_10': 0.1004810556769371
```


I tried writing some code to replicate the results and validate that the evaluation metrics returned by the model were correct.
(As the model was already trained, I moved everything from GPU to CPU.)

The following script converts my data from a pd.DataFrame to a dict in the format the T4Rec model expects, and also extracts the labels (the last product of each sequence).

For the moment I'm not removing the last item from the sessions in my dataset, as I know model(data, training=False) handles that for me.

```python
import numpy as np
import pandas as pd
import torch

# Load data
prediction_data = pd.read_parquet('eval.parquet')

# Create labels: the last non-padded item of each session
prods_arr = np.stack(prediction_data.products_padded)
last_item_idx = np.count_nonzero(prods_arr, axis=1) - 1
labels = prods_arr[np.arange(len(prods_arr)), last_item_idx]

# Transform data to PyTorch format (each column's dtype is taken from its first element)
pred_dtypes = prediction_data.applymap(lambda x: x[0]).dtypes
batch_pred = {}
for col, dtype in pred_dtypes.items():  # .iteritems() was removed in pandas 2.0
    tensor = np.stack(prediction_data[col])
    if dtype == 'float64':
        tensor = tensor.astype(np.float32)
    batch_pred[col] = torch.from_numpy(tensor)  # from_numpy already yields a CPU tensor
```

I also wrote my own function to compute recall:

```python
import numpy as np


def recall(predicted_items: np.ndarray, real_items: np.ndarray) -> float:
    """Mean recall over sessions; padding ids (0) are ignored on both sides."""
    recalls = np.zeros(len(predicted_items), dtype=np.float64)
    for idx, (real, pred) in enumerate(zip(real_items, predicted_items)):
        real = np.atleast_1d(real)  # labels may be scalars (one item per session)
        real = real[real > 0]
        pred = np.asarray(pred)     # also accepts a CPU torch tensor row
        pred = pred[pred > 0]

        real_found_in_pred = np.isin(pred, real, assume_unique=True)
        if real_found_in_pred.any():
            recalls[idx] = real_found_in_pred.sum() / len(real)
    return recalls.mean()
```
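Since every session here has exactly one label, the loop-based `recall()` above reduces to a hit rate: 1 if the label appears among the top-k predictions, else 0. A vectorized NumPy sketch of that special case, assuming `topk_pred` is an `(n_sessions, k)` array of item ids (call `.numpy()` on a CPU tensor first) and `labels` is a 1-D array of non-zero item ids; the function name is hypothetical:

```python
import numpy as np


def recall_at_k_single_label(topk_pred: np.ndarray, labels: np.ndarray) -> float:
    """Mean recall@k when each session has exactly one (non-padding) target item."""
    topk_pred = np.asarray(topk_pred)
    labels = np.asarray(labels)
    # Row-wise membership test: is the session's label anywhere in its top-k row?
    hits = (topk_pred == labels[:, None]).any(axis=1)
    return float(hits.mean())
```

With one label per session, this matches the loop version while avoiding Python-level iteration over sessions.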

I performed offline predictions by using the following code:

```python
predictions = model_cpu(batch_pred, training=False)['predictions']
_, topk_pred = torch.topk(predictions, k=10)
topk_pred = topk_pred.flip(dims=(1,))
```

Evaluating with my own function:

```python
recall(topk_pred, labels)
```

returns a nearly identical result:

```
Recall@10
t4r trainer eval metric result: 0.1004810556769371
my own metric result:           0.1004810550781038
```

But when I run the predictions "manually", masking the last item of each session and using ignore_masking=True, I get an entirely different result.

I re-ran my script for label extraction and DataFrame-to-dict conversion, but this time I mask the last item of each session:

```python
# Load data
prediction_data = pd.read_parquet('eval.parquet')

# Create labels: the last non-padded item of each session
prods_arr = np.stack(prediction_data.products_padded)
last_item_idx = np.count_nonzero(prods_arr, axis=1) - 1
labels = prods_arr[np.arange(len(prods_arr)), last_item_idx]

# Perform masking: set the last item of each session to 0 in every feature column
for n, idx in enumerate(last_item_idx):
    for col_nbr in range(prediction_data.shape[1]):
        arr = prediction_data.iloc[n, col_nbr].copy()
        arr[idx] = 0
        prediction_data.iloc[n, col_nbr] = arr

# Transform data to PyTorch format
pred_dtypes = prediction_data.applymap(lambda x: x[0]).dtypes
batch_pred = {}
for col, dtype in pred_dtypes.items():  # .iteritems() was removed in pandas 2.0
    tensor = np.stack(prediction_data[col])
    if dtype == 'float64':
        tensor = tensor.astype(np.float32)
    batch_pred[col] = torch.from_numpy(tensor)
```

I did check that the masking was performed correctly:

[Screenshot 2022-10-21 at 17.28.05]
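The per-row `iloc` loop above touches every cell individually, which gets slow on large frames. An equivalent vectorized sketch, assuming the features are already stacked into 2-D arrays of identical shape (the output of `np.stack(prediction_data[col])`, before conversion to tensors) and 0 is the padding id; `drop_last_item` is a hypothetical helper, not part of T4Rec:

```python
import numpy as np


def drop_last_item(features: dict, item_key: str, pad_id: int = 0):
    """Zero the last non-padded position in every feature array; return (masked, labels)."""
    items = features[item_key]
    rows = np.arange(items.shape[0])
    last_idx = np.count_nonzero(items != pad_id, axis=1) - 1  # last real item per session
    labels = items[rows, last_idx].copy()
    masked = {}
    for name, arr in features.items():
        arr = arr.copy()                  # leave the caller's arrays untouched
        arr[rows, last_idx] = 0           # same position in every feature column
        masked[name] = arr
    return masked, labels
```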

I re-ran the inference phase with ignore_masking=True

```python
model_results = model_cpu(batch_pred, training=False, ignore_masking=True)
predictions = model_results['predictions']
_, topk_pred = torch.topk(predictions, k=10)
topk_pred = topk_pred.flip(dims=(1,))
```

and got a different, disappointing result:

```
Recall@10
t4r trainer eval metric result: 0.1004810556769371
my own metric result:           0.1004810550781038
my own (ignore masking):        0.0686179125452678
```

I can't tell whether I'm missing something or did something wrong, but so far I can't find any problem on my side.

It would be of great help if someone could check this issue out and try to replicate this experiment.

Thanks in advance to anyone willing to look into this issue.

@sararb
Contributor

sararb commented Oct 27, 2022

@TassaraR,  thank you for testing the library and investigating the consistency of the scores returned by the model!!

I was able to replicate the same results using the code you shared with the end-to-end example data. After investigating, I found that the difference in scores you are observing comes from how masking is applied in T4Rec.

In fact, we do not replace each input feature with 0 at the masked position. Instead, we compute the interaction embeddings from the original inputs, then replace the embedding vector at the masked position with a trainable embedding vector initialized in the MaskingBlock (here). This mimics the special [MASK] token used for masking in NLP, which is distinct from the padding index. The special embedding vector allows the model to distinguish masked positions from padded ones.
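To make the difference concrete: zeroing an item id turns that position into padding, whereas the MaskingBlock keeps the original inputs, builds the interaction embeddings, and only then overwrites the masked positions with a learned vector. A toy NumPy sketch of that replacement step (with PyTorch tensors, `torch.where` behaves the same way); the shapes and the constant `mask_embedding` are made up for illustration and are not the library's values:

```python
import numpy as np

batch, seq_len, d_model = 2, 4, 3
rng = np.random.default_rng(0)

# Interaction embeddings computed from the ORIGINAL (unmasked) inputs.
embeddings = rng.normal(size=(batch, seq_len, d_model))

# True where a position is masked: here, the last non-padded item of each session.
masked_positions = np.array([[False, False, True, False],   # index 3 is padding
                             [False, False, False, True]])

# Stand-in for the trainable [MASK] embedding (a learned parameter in the real model).
mask_embedding = np.full(d_model, 0.5)

# Overwrite only the masked positions; every other embedding is left untouched.
masked_embeddings = np.where(masked_positions[..., None], mask_embedding, embeddings)
```

The model therefore always sees the [MASK] vector at the position it is asked to predict during training and evaluation, which a zeroed (padded) position does not reproduce.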

  • Here is detailed code showing how the T4Rec model applies the masking:

```python
# Split the model into its sub-modules
model = recsys_trainer.model.wrapper_module
input_block = model.heads[0].body.inputs
masking_block = model.heads[0].body.inputs.masking
# body[1] is the transformer here; with an MLPBlock before it (as in the model above), use body[2]
transformer_block = model.heads[0].body[1]
prediction_block = model.heads[0].prediction_task_dict['next-item']

# Get the 3-D interaction embeddings; each position aggregates the embeddings of its input features.
inputs_wo_masking = input_block(batch_pred, ignore_masking=True, training=False)

# Get the boolean mask of the positions to mask (i.e., the last non-padded item during evaluation)
masked_positions = masking_block._compute_masked_targets(batch_pred['item_id-list_seq'], training=False)

# Check that the last non-padded item is always set to True in the masking schema:
# masked_positions.schema

# Masking happens after computing the interaction embeddings, by replacing the masked positions with
# a trainable embedding vector (`masking_block.masked_item_embedding`)
apply_masking_to_inputs = masking_block.apply_mask_to_inputs(inputs_wo_masking, masked_positions.schema)

# Get the hidden representation from the transformer block
hidden_interactions = transformer_block(apply_masking_to_inputs, training=False)

# Get the final next-item prediction scores
predictions = prediction_block(hidden_interactions)['predictions']
```

Hope that helps you understand how the masking is used in T4Rec. Can you also please test the shared code at your end to validate the scores are matching the ones returned by trainer.evaluate()?

sararb added the question (Further information is requested) and area/examples labels and removed the status/needs-triage label on Oct 27, 2022
sararb self-assigned this on Oct 27, 2022
@TassaraR
Author

Thanks for your time and answer, @sararb! I'm going to test it out. It's going to take me a little while, though!

@TassaraR
Author

TassaraR commented Nov 2, 2022

So I tested the code, but I had to change it a bit to fit my own implementation:

```python
trainer.eval_dataset_or_path = 'short_eval.parquet'
train_metrics = trainer.evaluate(metric_key_prefix='eval')
```

Returns:

```
'eval_/next-item/recall_at_10': 0.0974999 ~ 0.0975
```

As I mentioned before, I moved the model from GPU to CPU:

```python
_ = model.eval()
_ = model.cpu()
```

I re-ran my custom code with my custom recall function:

```python
recall(topk_pred, labels)
```

Returns:

```
0.0975
```

And finally, I ran the code that you sent me, @sararb:

```python
mdl = trainer.model.wrapper_module
input_block = mdl.heads[0].body.inputs
masking_block = mdl.heads[0].body.inputs.masking
sequential_block = mdl.heads[0].body[1]
transformer_block = mdl.heads[0].body[2]
prediction_block = mdl.heads[0].prediction_task_dict['next-item']

inputs_wo_masking = input_block(batch_pred, ignore_masking=True, training=False)
masked_positions = masking_block._compute_masked_targets(batch_pred['products_padded'], training=False)
apply_masking_to_inputs = masking_block.apply_mask_to_inputs(inputs_wo_masking, masked_positions.schema)

sequential_pass = sequential_block(apply_masking_to_inputs, training=False)
transformer_pass = transformer_block(sequential_pass, training=False)

predictions = prediction_block(transformer_pass)['predictions']
_, topk_pred = torch.topk(predictions, k=10)

recall(topk_pred, labels)
```

Which returns the exact same value:

```
0.0975
```

Still, after testing this, I'm left with some questions, as I'm struggling with offline predictions.

For my current use case I can't rely on Triton, so I'm planning to build my own API to serve the model, which means I need to run inference offline. The problem I notice is that using:

```python
model(batch_pred, training=False)
```

always masks the last existing value of each sequence, so it tries to predict the masked value instead of the new/next item.

Given that, I would assume the correct way to perform inference is:

```python
model(batch_pred, training=False, ignore_masking=True)
```

since it should, of course, "ignore the mask".

As I understand it, during prediction the last value is replaced by a special "[MASK]" token instead of zero. The problem is that in a real-world scenario we won't have a [MASK] token. So I tried to evaluate the model in a "simulated real-life scenario": I manually masked the last value with a zero and ran inference over those sequences, since I already know the labels.

In this case (using ignore_masking=True) the model underperformed, as seen in the original question, scoring 0.068 instead of the 0.1 from the first evaluation.

I don't know if I'm missing or misunderstanding something, but I've been struggling with this for a while. (I also hope I made my point clear, as English is my second language.)

It would be awesome if some examples could be provided. I've checked the previous issues on this topic, but none of them helped me resolve this.

Again, thanks a lot for your time and effort!

@conway-abacus

Hi all, I am also a bit confused by this discrepancy. To summarize, it seems there are two ways in which one can get 'next-item' predictions given an input sequence:

  1. set ignore_masking=False and put a dummy item_id at the end of the sequence
  2. set ignore_masking=True and not alter the input item_id sequence at all

It seems that trainer.evaluate() is effectively using (1) when, in the validation set, we explicitly mask the last interaction (so we are predicting the last interaction in each session). However, some other code/PRs suggest to use ignore_masking=True in order to use the whole input sequence. So, which method should we use when generating predictions?

@sararb
Contributor

sararb commented Nov 9, 2022

@Conway, thanks for your questions, and we apologize for the confusion. The parameter ignore_masking was added to separate train/eval mode from test mode.

  1. During training (ignore_masking=False): Masking is used to select and mask random positions in the input sequence. 

  2. During evaluation (ignore_masking=False): We assume the input sequence contains the label as the last interaction and masking is applied to replace it with MASK embeddings. Based on the masked input, the model will predict the last interaction. 

  3. During test/inference (ignore_masking=True): We assume the input sequence does not contain the label, so masking should not be applied, i.e., the whole sequence is passed to the model to generate predictions.
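To make modes 2 and 3 concrete, here is a minimal NumPy sketch of which position gets masked in evaluation mode, assuming right-padded sequences with 0 as the padding id (as in the examples in this thread). `eval_mask` is a hypothetical helper written for illustration, not T4Rec's `_compute_masked_targets`:

```python
import numpy as np


def eval_mask(item_ids: np.ndarray, pad_id: int = 0):
    """Return (mask, targets) for evaluation-style masking of right-padded sequences."""
    last_idx = (item_ids != pad_id).sum(axis=1) - 1    # index of the last real item per row
    rows = np.arange(item_ids.shape[0])
    mask = np.zeros(item_ids.shape, dtype=bool)
    mask[rows, last_idx] = True                        # only the last real item is masked...
    targets = item_ids[rows, last_idx]                 # ...and it becomes the label
    return mask, targets


seqs = np.array([[5, 8, 3, 0],
                 [2, 9, 4, 7]])
mask, targets = eval_mask(seqs)
```

In inference mode (ignore_masking=True) no position is masked, so the whole sequence, including its last item, serves as context for predicting the next one.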


@sararb
Contributor

sararb commented Nov 9, 2022

@TassaraR, thank you for investigating the model’s performance in different modes. 

Regarding your point: In this case the model under-performed (As seen in the original question) by returning a score of 0.068 instead of 0.1 as shown in the first evaluation (when using ignore_masking=True).

  1. Indeed, your tests demonstrate that the model underperforms when the last element is set to zero instead of keeping the original information and replacing it with "MASK" embeddings. So there is definitely a discrepancy between using the "MASK" embedding during training and evaluation and not using it during inference.

  2. Using your simulated code, I added a dummy interaction at the end of each input sequence and used the model with ignore_masking=False to generate predictions. The recall matches the one returned by the trainer. 

Code to add a dummy interaction to manually masked sequences:

```python
# Create labels
prods_arr = np.stack(prediction_data['item_id-list_seq'])
last_item_idx = np.count_nonzero(prods_arr, axis=1) - 1
labels = np.array([prods_arr[n, idx] for n, idx in enumerate(last_item_idx)])

# Mimic real-world setting: remove the last interaction (the label) from every feature
for n, idx in enumerate(last_item_idx):
    for col_nbr in range(prediction_data.shape[1]):
        arr = prediction_data.iloc[n, col_nbr].copy()
        arr[idx] = 0
        prediction_data.iloc[n, col_nbr] = arr

# Add a dummy interaction (value 2 in every feature) right after the new last item,
# to account for masking: this is the position the model will mask and predict
prods_arr = np.stack(prediction_data['item_id-list_seq'])
last_item_idx = np.count_nonzero(prods_arr, axis=1) - 1
for n, idx in enumerate(last_item_idx):
    for col_nbr in range(prediction_data.shape[1]):
        arr = prediction_data.iloc[n, col_nbr].copy()
        arr[idx + 1] = 2
        prediction_data.iloc[n, col_nbr] = arr

# Transform data to PyTorch format
pred_dtypes = prediction_data.applymap(lambda x: x[0]).dtypes
batch_pred = {}
for col, dtype in pred_dtypes.items():  # .iteritems() was removed in pandas 2.0
    tensor = np.stack(prediction_data[col])
    if dtype == 'float64':
        tensor = tensor.astype(np.float32)
    batch_pred[col] = torch.from_numpy(tensor)

# Get predictions with masking enabled, so the dummy last position is replaced by the MASK embedding
model_results = recsys_trainer.model.wrapper_module.cpu()(batch_pred, training=False, ignore_masking=False)
predictions = model_results['predictions']
_, topk_pred = torch.topk(predictions, k=10)
topk_pred = topk_pred.flip(dims=(1,))
print("RECALL@10 by extending test sequence with a dummy last interaction and use model with masking: ", recall(topk_pred, labels))
```

Thank you for pointing out this discrepancy! I opened a bug ticket to track the issue and we are currently working on a fix to ensure that the performance of the model is not affected during inference.

@TassaraR
Author

Hi @sararb, Thanks for your answer!

First of all, it seems clear, then, that ignore_masking=True doesn't currently work for inference, as it makes the model underperform.

I had also planned to add a dummy item and let the model mask it (ignore_masking=False) to get predictions. There's just one issue:

Imagine a session stored in an array of maximum size 5, in which the user has a full cart, so all 5 positions hold items.

```
# Original session with 5 items
[456, 431, 678, 915, 231]
```

To get recommendations, we need to add a dummy item that can then be masked. The issue is that, in a real-world scenario, the only way to do that with this array may be to shift its values to the left and drop the first item, since there is no room left for the dummy item.

```
# shifted values of the array
[431, 678, 915, 231, 1]    ---->    [431, 678, 915, 231, <MASK>]
                     ^
                Dummy Variable
```

This way we actually lose information, because we had to remove the first item, 456.
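The shift-left workaround described above can be sketched as follows, assuming right-padded sequences with 0 as padding and `1` as the dummy id (both as in the example arrays). `append_dummy` is a hypothetical helper: if a padding slot is free, the dummy goes there and nothing is lost; if the sequence is full, the oldest item is dropped.

```python
import numpy as np


def append_dummy(seq, dummy_id: int = 1, pad_id: int = 0) -> np.ndarray:
    """Place a dummy item after the last real item, shifting left if the sequence is full."""
    seq = np.asarray(seq)
    n_items = int(np.count_nonzero(seq != pad_id))
    if n_items < len(seq):
        out = seq.copy()
        out[n_items] = dummy_id                       # free slot: no information lost
    else:
        out = np.concatenate([seq[1:], [dummy_id]])   # full: the oldest item is dropped
    return out
```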

One approach to solve this, and the one I have in mind, is to extend the input size by 1. Instead of a sequence of 5, we work with a sequence of 6 whose last slot is always 0 (so each session is always padded with at least one 0 at the end), and we train the model on this extended sequence.

```
# input size 6 where the last value is always 0
[456, 431, 678, 915, 231, 0] --->  [456, 431, 678, 915, 231, 1] 
                                                             ^
                                                   Dummy variable to mask
```

That slot can then be filled with the dummy item to mask, so we don't have to drop any existing values.
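Under that scheme, sessions are padded to `max_len + 1` so that the final slot is always free. A hypothetical sketch of the padding and dummy-filling steps, assuming right padding with 0 and `1` as the dummy id; the helper names are illustrative only:

```python
import numpy as np


def pad_with_reserved_slot(items, max_len: int, pad_id: int = 0) -> np.ndarray:
    """Keep the most recent max_len items and right-pad to max_len + 1 (last slot always free)."""
    items = list(items)[-max_len:]
    return np.array(items + [pad_id] * (max_len + 1 - len(items)))


def fill_reserved_slot(seq: np.ndarray, dummy_id: int = 1, pad_id: int = 0) -> np.ndarray:
    """Put the dummy item in the first free slot; it always exists thanks to the reserved slot."""
    out = seq.copy()
    out[np.count_nonzero(out != pad_id)] = dummy_id
    return out
```

Because at most `max_len` real items are kept, the index of the first free slot is at most `max_len`, so the dummy never overwrites a real item.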

I'm not sure, though, whether this may cause issues at inference, since the model was never trained on a session with all 6 positions filled. (I hope it doesn't cause any problems.)


4 participants