Fix early stopping code in dfencoder to use average loss of batches in validation set #908

Merged
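Previously, the early-stopping check in `_fit_centralized` compared against the `net_loss` of only the last validation batch; this change averages the loss over every batch in the validation set (via the new `_validate_dataframe` helper) before applying the patience check. A minimal sketch of that pattern, with illustrative helper names rather than the module's exact API:

```python
import numpy as np

def mean_validation_loss(model, loss_fn, batches):
    """Average the per-batch validation losses instead of keeping only the last batch's loss."""
    return float(np.mean([loss_fn(model(batch_in), batch_target) for batch_in, batch_target in batches]))

def should_stop(loss_history, patience):
    """Stop once the mean validation loss has increased for `patience` consecutive epochs."""
    increases = 0
    for prev, curr in zip(loss_history, loss_history[1:]):
        increases = increases + 1 if curr > prev else 0
    return increases >= patience

# Example: should_stop([0.10, 0.12, 0.15, 0.19], patience=3) -> True
```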
98 changes: 60 additions & 38 deletions morpheus/models/dfencoder/autoencoder.py
@@ -386,14 +386,14 @@ def _init_binary(self, df=None):

def _init_features(self, df=None):
"""Initializea the features of different types.
`df` is required if any of `preset_cats`, `preset_numerical_scaler_params`, and `binary_feature_list` are not provided
`df` is required if any of `preset_cats`, `preset_numerical_scaler_params`, and `binary_feature_list` are not provided
at model initialization.

Parameters
----------
df : pandas.DataFrame, optional
dataframe used to compute and extract feature information, by default None

Raises
------
ValueError
@@ -590,7 +590,7 @@ def preprocess_data(
Whether to process the df into an input tensor without swapping and include it in the returned data dict.
Note: training requires only the swapped input tensor, while validation can use both.
include_swapped_input_by_feature_type : bool
Whether to process the swapped df into num/bin/cat feature tensors and include them in the returned data dict.
This is useful for baseline performance evaluation for validation.

Returns
Expand Down Expand Up @@ -855,10 +855,6 @@ def _fit_centralized(self, df, epochs=1, val=None, run_validation=False, use_val
baseline = self.compute_baseline_performance(val_in, val_df)
LOG.debug(msg)

val_batches = len(val_df) // self.eval_batch_size
if len(val_df) % self.eval_batch_size != 0:
val_batches += 1

n_updates = len(df) // self.batch_size
if len(df) % self.batch_size > 0:
n_updates += 1
Expand All @@ -883,28 +879,10 @@ def _fit_centralized(self, df, epochs=1, val=None, run_validation=False, use_val
if run_validation and val is not None:
self.eval()
with torch.no_grad():
swapped_loss = []
id_loss = []
for i in range(val_batches):
start = i * self.eval_batch_size
stop = (i + 1) * self.eval_batch_size

slc_in = val_in.iloc[start:stop]
slc_in_tensor = self.build_input_tensor(slc_in)

slc_out = val_df.iloc[start:stop]
slc_out_tensor = self.build_input_tensor(slc_out)

num, bin, cat = self.model(slc_in_tensor)
_, _, _, net_loss = self.compute_loss(num, bin, cat, slc_out)
swapped_loss.append(net_loss)

num, bin, cat = self.model(slc_out_tensor)
_, _, _, net_loss = self.compute_loss(num, bin, cat, slc_out, _id=True)
id_loss.append(net_loss)
mean_id_loss, mean_swapped_loss = self._validate_dataframe(orig_df=val_df, swapped_df=val_in)

# Early stopping
current_net_loss = net_loss
current_net_loss = mean_id_loss
LOG.debug('The Current Net Loss: %s', current_net_loss)

if current_net_loss > last_loss:
Expand All @@ -924,16 +902,13 @@ def _fit_centralized(self, df, epochs=1, val=None, run_validation=False, use_val
self.logger.end_epoch()

if self.verbose:
swapped_loss = np.array(swapped_loss).mean()
id_loss = np.array(id_loss).mean()

msg = '\n'
msg += 'net validation loss, swapped input: \n'
msg += f"{round(swapped_loss, 4)} \n\n"
msg += f"{round(mean_swapped_loss, 4)} \n\n"
msg += 'baseline validation loss: '
msg += f"{round(baseline, 4)} \n\n"
msg += 'net validation loss, unaltered input: \n'
msg += f"{round(id_loss, 4)} \n\n\n"
msg += f"{round(mean_id_loss, 4)} \n\n\n"
LOG.debug(msg)

#Getting training loss statistics
Expand All @@ -949,6 +924,53 @@ def _fit_centralized(self, df, epochs=1, val=None, run_validation=False, use_val
i_loss = cce_loss[:, i]
self.feature_loss_stats[ft] = self._create_stat_dict(i_loss)

def _validate_dataframe(self, orig_df, swapped_df):
"""Runs a validation loop on the given validation pandas DataFrame, computing and returning the average loss of
both the original input and the input with swapped values.

Parameters
----------
orig_df : pandas.DataFrame, the original validation data
swapped_df: pandas.DataFrame, the swapped validation data

Returns
-------
Tuple[float]
A tuple containing two floats:
- mean_id_loss: the average net loss when passing the original df through the model
- mean_swapped_loss: the average net loss when passing the swapped df through the model
"""
val_batches = len(orig_df) // self.eval_batch_size
if len(orig_df) % self.eval_batch_size != 0:
val_batches += 1

swapped_loss = []
id_loss = []
for i in range(val_batches):
start = i * self.eval_batch_size
stop = (i + 1) * self.eval_batch_size

# calculate the loss of the swapped tensor compared to the target (original) tensor
slc_in = swapped_df.iloc[start:stop]
slc_in_tensor = self.build_input_tensor(slc_in)

slc_out = orig_df.iloc[start:stop]
slc_out_tensor = self.build_input_tensor(slc_out)

num, bin, cat = self.model(slc_in_tensor)
_, _, _, net_loss = self.compute_loss(num, bin, cat, slc_out)
swapped_loss.append(net_loss)

# calculate the loss of the original tensor
num, bin, cat = self.model(slc_out_tensor)
_, _, _, net_loss = self.compute_loss(num, bin, cat, slc_out, _id=True)
id_loss.append(net_loss)

mean_swapped_loss = np.array(swapped_loss).mean()
mean_id_loss = np.array(id_loss).mean()

return mean_id_loss, mean_swapped_loss

def _fit_distributed(
self,
train_data,
@@ -960,12 +982,12 @@ def _fit_distributed(
use_val_for_loss_stats=False,
):
"""Fit the model in the distributed fashion with early stopping based on validation loss.
If run_validation is True, the val_dataset will be used for validation during training and early stopping
will be applied based on patience argument.

Parameters
----------
train_data : pandas.DataFrame or torch.utils.data.Dataset or torch.utils.data.DataLoader
data object of training data
rank : int
the rank of the current process
@@ -980,7 +1002,7 @@ def _fit_distributed(
use_val_for_loss_stats : bool, optional
whether to populate loss stats in the main process (rank 0) for z-score calculation using the validation set.
If set to False, loss stats would be populated using the train_dataloader, which can be slow due to data size.
By default False, but using the validation set to populate loss stats is strongly recommended (for both efficiency
and model efficacy).

Raises
@@ -1096,7 +1118,7 @@ def _fit_distributed(
self._populate_loss_stats_from_dataset(dataset_for_loss_stats)

def _fit_batch(self, input_swapped, num_target, bin_target, cat_target, **kwargs):
"""Forward pass on the input_swapped, then computes the losses from the predicted outputs and actual targets, performs
"""Forward pass on the input_swapped, then computes the losses from the predicted outputs and actual targets, performs
backpropagation, updates the model parameters, and returns the net loss.

Parameters
Expand All @@ -1108,7 +1130,7 @@ def _fit_batch(self, input_swapped, num_target, bin_target, cat_target, **kwargs
bin_target : torch.Tensor
tensor of shape (batch_size, binary feature count) with binary targets
cat_target : List[torch.Tensor]
list of size (categorical feature count), each entry is a 1-d tensor of shape (batch_size) containing the categorical
targets

Returns
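For reference, a minimal sketch of the fit-batch pattern described in the `_fit_batch` docstring above: forward pass on the swapped input, losses against the unswapped targets, backward pass, and parameter update. The argument names mirror the docstring, but the `compute_loss` call here is illustrative rather than the module's exact signature:

```python
def fit_batch_sketch(model, optimizer, compute_loss, input_swapped, num_target, bin_target, cat_target):
    # Forward pass on the corrupted (value-swapped) input.
    num_pred, bin_pred, cat_pred = model(input_swapped)
    # Losses are computed against the original (unswapped) targets.
    net_loss = compute_loss(num_pred, bin_pred, cat_pred, num_target, bin_target, cat_target)
    # Backpropagate and update the model parameters.
    optimizer.zero_grad()
    net_loss.backward()
    optimizer.step()
    return net_loss.item()
```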
91 changes: 79 additions & 12 deletions tests/dfencoder/test_autoencoder.py
@@ -16,6 +16,7 @@

import os
import typing
from unittest.mock import patch

import pandas as pd
import pytest
@@ -59,18 +60,20 @@ def train_ae():
"""
Construct an AutoEncoder instance with the same values used by `train_ae_stage`
"""
yield autoencoder.AutoEncoder(encoder_layers=[512, 500],
decoder_layers=[512],
activation='relu',
swap_p=0.2,
lr=0.01,
lr_decay=.99,
batch_size=512,
verbose=False,
optimizer='sgd',
scaler='standard',
min_cats=1,
progress_bar=False)
yield autoencoder.AutoEncoder(
encoder_layers=[512, 500],
decoder_layers=[512],
activation='relu',
swap_p=0.2,
lr=0.01,
lr_decay=.99,
batch_size=512,
verbose=False,
optimizer='sgd',
scaler='standard',
min_cats=1,
progress_bar=False,
)


@pytest.fixture(scope="function")
@@ -280,6 +283,70 @@ def test_auto_encoder_fit(train_ae: autoencoder.AutoEncoder, train_df: pd.DataFr
train_ae.optim is train_ae.lr_decay.optimizer


def test_auto_encoder_fit_early_stopping(train_df: pd.DataFrame):
train_data = train_df.sample(frac=0.7, random_state=1)
validation_data = train_df.drop(train_data.index)

epochs = 10

# Test normal training loop with no early stopping
ae = autoencoder.AutoEncoder(patience=5)
ae.fit(train_data, val_data=validation_data, run_validation=True, use_val_for_loss_stats=True, epochs=epochs)
# assert that training runs through all epochs
assert ae.logger.n_epochs == epochs

class MockHelper:
"""A helper class for mocking the `_validate_dataframe` method in the `AutoEncoder` class."""

def __init__(self, orig_losses, swapped_loss=1.0):
""" Initialization.

Parameters:
-----------
orig_losses : list
A list of original validation losses to be returned by the mocked `_validate_dataframe` method.
swapped_loss : float, optional (default=1.0)
A fixed loss value to be returned by the mocked `_validate_dataframe` method as the `swapped_loss`.
Fixed as it's unrelated to the early-stopping functionality being tested here.
"""
self.swapped_loss = swapped_loss
self.orig_losses = orig_losses
# counter to keep track of the number of times the mocked `_validate_dataframe` method has been called
self.count = 0

def mocked_validate_dataframe(self, *args, **kwargs):
"""
A mocked version of the `_validate_dataframe` method in the `AutoEncoder` class for testing early stopping.

Returns:
--------
tuple of (float, float)
A tuple of original validation loss and swapped loss values for each epoch.
"""
orig_loss = self.orig_losses[self.count]
self.count += 1
return orig_loss, self.swapped_loss

# Test early stopping
orig_losses = [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]

ae = autoencoder.AutoEncoder(
patience=3) # should stop at epoch 3 as the first 3 losses are monotonically increasing
mock_helper = MockHelper(orig_losses=orig_losses) # validation loss is strictly increasing
with patch.object(ae, '_validate_dataframe', side_effect=mock_helper.mocked_validate_dataframe):
ae.fit(train_data, val_data=validation_data, run_validation=True, use_val_for_loss_stats=True, epochs=epochs)
# assert that training early-stops at epoch 3
assert ae.logger.n_epochs == 3

ae = autoencoder.AutoEncoder(
patience=5) # should stop at epoch 9 as losses from epoch 5-9 are monotonically increasing
mock_helper = MockHelper(orig_losses=orig_losses) # validation loss is strictly increasing
with patch.object(ae, '_validate_dataframe', side_effect=mock_helper.mocked_validate_dataframe):
ae.fit(train_data, val_data=validation_data, run_validation=True, use_val_for_loss_stats=True, epochs=epochs)
# assert that training early-stops at epoch 9
assert ae.logger.n_epochs == 9


@pytest.mark.usefixtures("manual_seed")
def test_auto_encoder_get_anomaly_score(train_ae: autoencoder.AutoEncoder, train_df: pd.DataFrame):
train_ae.fit(train_df, epochs=1)