Skip to content

Commit

Permalink
[MRG] Handle edge case for DAN (#271)
Browse files Browse the repository at this point in the history
* Handle edge case for DAN

* Add eps as an arg

---------

Co-authored-by: Théo Gnassounou <66993815+tgnassou@users.noreply.github.com>
  • Loading branch information
YanisLalou and tgnassou authored Oct 31, 2024
1 parent 76c67a2 commit 781ef0d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
7 changes: 5 additions & 2 deletions skada/deep/_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ class DANLoss(BaseDALoss):
----------
sigmas : array-like, optional (default=None)
The sigmas for the Gaussian kernel.
eps : float, default=1e-7
Small constant added to median distance calculation for numerical stability.
References
----------
Expand All @@ -122,9 +124,10 @@ class DANLoss(BaseDALoss):
In ICML, 2015.
"""

def __init__(self, sigmas=None):
def __init__(self, sigmas=None, eps=1e-7):
super().__init__()
self.sigmas = sigmas
self.eps = eps

def forward(
self,
Expand All @@ -137,7 +140,7 @@ def forward(
features_t,
):
"""Compute the domain adaptation loss"""
loss = dan_loss(features_s, features_t, sigmas=self.sigmas)
loss = dan_loss(features_s, features_t, sigmas=self.sigmas, eps=self.eps)
return loss


Expand Down
8 changes: 6 additions & 2 deletions skada/deep/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def _maximum_mean_discrepancy(x, y, kernel):
return cost


def dan_loss(features_s, features_t, sigmas=None):
def dan_loss(features_s, features_t, sigmas=None, eps=1e-7):
"""Define the mmd loss based on multi-kernel defined in [14]_.
Parameters
Expand All @@ -162,6 +162,8 @@ def dan_loss(features_s, features_t, sigmas=None):
sigmas : array like, default=None,
If array, sigmas used for the multi gaussian kernel.
If None, uses sigmas proposed in [1]_.
eps : float, default=1e-7
Small constant added to median distance calculation for numerical stability.
Returns
-------
Expand All @@ -175,7 +177,9 @@ def dan_loss(features_s, features_t, sigmas=None):
In ICML, 2015.
"""
if sigmas is None:
median_pairwise_distance = torch.median(torch.cdist(features_s, features_s))
median_pairwise_distance = (
torch.median(torch.cdist(features_s, features_s)) + eps
)
sigmas = (
torch.tensor([2 ** (-8) * 2 ** (i * 1 / 2) for i in range(33)]).to(
features_s.device
Expand Down
19 changes: 19 additions & 0 deletions skada/deep/tests/test_deep_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
pytest.importorskip("torch")

import numpy as np
import torch

from skada.datasets import make_shifted_datasets
from skada.deep import CAN, DAN, DeepCoral
from skada.deep.losses import dan_loss
from skada.deep.modules import ToyModule2D


Expand Down Expand Up @@ -195,3 +197,20 @@ def test_can_with_custom_callbacks():
callback_classes = [cb.__class__.__name__ for cb in method.callbacks]
assert "EpochScoring" in callback_classes
assert "ComputeSourceCentroids" in callback_classes


def test_dan_loss_edge_cases():
# Create identical source features to get median distance = 0
features_s = torch.tensor([[1.0, 2.0], [1.0, 2.0]], dtype=torch.float32)
features_t = torch.tensor([[3.0, 4.0], [5.0, 6.0]], dtype=torch.float32)

# Verify median distance is 0
assert torch.median(torch.cdist(features_s, features_s)) == 0

# Test that dan_loss still works
loss = dan_loss(features_s, features_t)

# Loss should be finite and non-negative
assert not torch.isnan(loss)
assert not torch.isinf(loss)
assert loss >= 0

0 comments on commit 781ef0d

Please sign in to comment.