Skip to content

Commit

Permalink
fix loss calculation to ensure model convergence
Browse files Browse the repository at this point in the history
  • Loading branch information
hsin-c committed Jan 10, 2024
1 parent 2e3f4e4 commit 01517bb
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
14 changes: 11 additions & 3 deletions morpheus/models/dfencoder/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,10 +698,18 @@ def compute_loss_from_targets(self, num, bin, cat, num_target, bin_target, cat_t

def do_backward(self, mse, bce, cce):
# Combines the numeric (mse), binary (bce) and categorical (cce) losses and backpropagates them.
# NOTE(review): this text is a rendered diff hunk without +/- markers, so old and new
# statements are interleaved. The `loss_fn` statements look like the three removed lines
# and the `loss` statements like the added ones -- confirm against the real file.
# Running `backward()` separately on mse/bce/cce is equivalent to summing them up and running `backward()` once.
loss_fn = mse + bce
loss = 0

# mse is only added when numeric features are configured.
# NOTE(review): presumably this keeps a loss term for an absent feature type out of the total -- confirm.
if len(self.numeric_fts) > 0:
loss += mse

# bce is only added when binary features are configured.
if len(self.binary_fts) > 0:
loss += bce

# cce is iterated, and each element is added to the running total.
for ls in cce:
loss_fn += ls
loss_fn.backward()
loss += ls

loss.backward()

def compute_baseline_performance(self, in_, out_):
"""
Expand Down
21 changes: 21 additions & 0 deletions tests/dfencoder/test_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from unittest.mock import patch

import pandas as pd
import numpy as np
import pytest
import torch
from torch.utils.data import Dataset as TorchDataset
Expand Down Expand Up @@ -483,3 +484,23 @@ def test_auto_encoder_get_results(train_ae: autoencoder.AutoEncoder, train_df: p
# Numpy float has different precision checks than python float, so we wrap it.
assert round(float(results.loc[0, 'mean_abs_z']), 3) == 0.335
assert results.loc[0, 'z_loss_scaler_type'] == 'z'


@pytest.mark.usefixtures("manual_seed")
def test_auto_encoder_num_only_convergence(train_ae: autoencoder.AutoEncoder):
    """Fit on a numeric-only frame and check the averaged training loss at least halves."""
    numeric_only_df = pd.DataFrame({
        'num_feat_1': [5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9],
        'num_feat_2': [3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1],
    })

    train_ae.fit(numeric_only_df, epochs=50)

    feature_logs = train_ae.logger.train_fts
    # Element 1 of every logged entry is taken as that feature's loss series
    # (presumably one value per epoch -- confirm against the logger implementation).
    per_feature_losses = [np.array(entry[1]) for entry in feature_logs.values()]

    # Element-wise average of the loss series across all logged features.
    avg_loss = np.sum(per_feature_losses, axis=0) / len(feature_logs)

    # Make sure the model converges with numerical feats only
    assert avg_loss[-1] < avg_loss[0] / 2

0 comments on commit 01517bb

Please sign in to comment.