[TO_REVIEW] Add epsilon in MCC to prevent log(0) (#270)
* Add epsilon in MCC to prevent log(0)

* Add eps as an arg
YanisLalou authored Oct 31, 2024
1 parent 425d301 commit 76c67a2
Showing 3 changed files with 34 additions and 4 deletions.
6 changes: 5 additions & 1 deletion skada/deep/_class_confusion.py
@@ -25,6 +25,8 @@ class MCCLoss(BaseDALoss):
T : float, default=1
Temperature parameter for the scaling.
If T=1, the scaling is a softmax function.
eps : float, default=1e-7
Small constant added inside the logarithm for numerical stability.
References
----------
@@ -33,9 +35,10 @@ class MCCLoss(BaseDALoss):
In ECCV, 2020.
"""

def __init__(self, T=1):
def __init__(self, T=1, eps=1e-7):
super().__init__()
self.T = T
self.eps = eps

def forward(
self,
@@ -51,6 +54,7 @@ def forward(
loss = mcc_loss(
y_pred_t,
T=self.T,
eps=self.eps,
)
return loss

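For orientation, a minimal usage sketch of the updated class (not part of the commit); it assumes MCCLoss is imported from the module shown above, and the argument values are purely illustrative:

from skada.deep._class_confusion import MCCLoss

# T=1 keeps the temperature scaling a plain softmax; the new eps argument
# is forwarded to mcc_loss so that the log in the entropy term stays finite.
criterion = MCCLoss(T=1, eps=1e-7)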
7 changes: 4 additions & 3 deletions skada/deep/losses.py
@@ -368,17 +368,18 @@ def probability_scaling(logits, temperature=1):
return torch.nn.functional.softmax(logits / temperature, dim=1)


def mcc_loss(y, T=1):
def mcc_loss(y, T=1, eps=1e-7):
"""Estimate the Frobenius norm divide by 4*n**2
for DeepCORAL method [33]_.
Parameters
----------
y : tensor
The output of target domain of the model.
T : float, default=1
The temperature for the scaling.
eps : float, default=1e-7
Small constant added inside the logarithm for numerical stability.
Returns
-------
@@ -395,7 +396,7 @@ def mcc_loss(y, T=1):
y_scaled = probability_scaling(y, temperature=T)

# Uncertainty Reweighting & class correlation matrix
H = -torch.sum(y_scaled * torch.log(y_scaled), axis=1)
H = -torch.sum(y_scaled * torch.log(y_scaled + eps), axis=1)
W = (1 + torch.exp(-H)) / torch.mean(1 + torch.exp(-H))
y_weighted = torch.matmul(torch.diag(W), y_scaled)
C = torch.einsum("ij,ik->jk", y_scaled, y_weighted)
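To see why the epsilon matters, here is a small self-contained sketch (not part of the commit) of the failure mode this change guards against:

import torch

# Extreme logits underflow to exact zeros after the softmax.
logits = torch.tensor([[100.0, -100.0, -100.0]])
p = torch.nn.functional.softmax(logits, dim=1)

# Without a safeguard, 0 * log(0) evaluates to 0 * (-inf) = nan,
# so the entropy term (and with it the whole loss) becomes NaN.
entropy_naive = -torch.sum(p * torch.log(p), dim=1)  # tensor([nan])

# With the epsilon introduced here, the logarithm stays finite.
eps = 1e-7
entropy_safe = -torch.sum(p * torch.log(p + eps), dim=1)  # finite, close to zero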
25 changes: 25 additions & 0 deletions skada/deep/tests/test_deep_class_confusion.py
@@ -8,9 +8,11 @@
pytest.importorskip("torch")

import numpy as np
import torch

from skada.datasets import make_shifted_datasets
from skada.deep import MCC
from skada.deep.losses import mcc_loss
from skada.deep.modules import ToyModule2D


@@ -54,3 +56,26 @@ def test_mcc(T):
history = method.history_

assert history[0]["train_loss"] > history[-1]["train_loss"]


def test_mcc_with_zeros():
"""Test that mcc_loss handles zero probabilities correctly."""
# Create logits with extreme values that will result in zeros
# after softmax operation due to numerical underflow
logits = torch.tensor(
[
[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0],
]
)

# Verify that we actually get zeros in y_scaled
y_scaled = torch.nn.functional.softmax(logits, dim=1)
assert torch.sum(y_scaled == 0.0) > 0, "Test setup failed: no zeros in y_scaled"

# This should not raise any errors due to the epsilon in log
loss = mcc_loss(logits, T=1.0)

assert torch.isfinite(loss) # Check that the loss is not NaN or infinite
assert loss >= 0 # MCC loss should be non-negative
