Skip to content

Commit

Permalink
7263 add diffusion loss (#7272)
Browse files Browse the repository at this point in the history
Fixes #7263.

### Description

Add diffusion loss. I also made a [demo
notebook](https://github.com/kvttt/deep-atlas/blob/main/diffusion_loss_scale_test.ipynb)
to provide some explanations and analyses of diffusion loss.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: kaibo <ktang@unc.edu>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
  • Loading branch information
kvttt and KumoLiu authored Dec 5, 2023
1 parent 88f8dd2 commit d5585c3
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ Registration Losses
.. autoclass:: BendingEnergyLoss
:members:

`DiffusionLoss`
~~~~~~~~~~~~~~~
.. autoclass:: DiffusionLoss
:members:

`LocalNormalizedCrossCorrelationLoss`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LocalNormalizedCrossCorrelationLoss
Expand Down
2 changes: 1 addition & 1 deletion monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .adversarial_loss import PatchAdversarialLoss
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
from .contrastive import ContrastiveLoss
from .deform import BendingEnergyLoss
from .deform import BendingEnergyLoss, DiffusionLoss
from .dice import (
Dice,
DiceCELoss,
Expand Down
82 changes: 82 additions & 0 deletions monai/losses/deform.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,85 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor:
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

return energy


class DiffusionLoss(_Loss):
"""
Calculate the diffusion based on first-order differentiation of pred using central finite difference.
For the original paper, please refer to
VoxelMorph: A Learning Framework for Deformable Medical Image Registration,
Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca
IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231.
Adapted from:
VoxelMorph (https://github.com/voxelmorph/voxelmorph)
"""

def __init__(self, normalize: bool = False, reduction: LossReduction | str = LossReduction.MEAN) -> None:
"""
Args:
normalize:
Whether to divide out spatial sizes in order to make the computation roughly
invariant to image scale (i.e. vector field sampling resolution). Defaults to False.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.normalize = normalize

def forward(self, pred: torch.Tensor) -> torch.Tensor:
"""
Args:
pred:
Predicted dense displacement field (DDF) with shape BCH[WD],
where C is the number of spatial dimensions.
Note that diffusion loss can only be calculated
when the sizes of the DDF along all spatial dimensions are greater than 2.
Raises:
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
ValueError: When ``pred`` is not 3-d, 4-d or 5-d.
ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2.
ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions.
"""
if pred.ndim not in [3, 4, 5]:
raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}")
for i in range(pred.ndim - 2):
if pred.shape[-i - 1] <= 2:
raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}")
if pred.shape[1] != pred.ndim - 2:
raise ValueError(
f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, "
f"does not match number of spatial dimensions, {pred.ndim - 2}"
)

# first order gradient
first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)]

# spatial dimensions in a shape suited for broadcasting below
if self.normalize:
spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,))

diffusion = torch.tensor(0)
for dim_1, g in enumerate(first_order_gradient):
dim_1 += 2
if self.normalize:
# We divide the partial derivative for each vector component at each voxel by the spatial size
# corresponding to that component relative to the spatial size of the vector component with respect
# to which the partial derivative is taken.
g *= pred.shape[dim_1] / spatial_dims
diffusion = diffusion + g**2

if self.reduction == LossReduction.MEAN.value:
diffusion = torch.mean(diffusion) # the batch and channel average
elif self.reduction == LossReduction.SUM.value:
diffusion = torch.sum(diffusion) # sum over the batch and channel dims
elif self.reduction != LossReduction.NONE.value:
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

return diffusion
116 changes: 116 additions & 0 deletions tests/test_diffusion_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses.deform import DiffusionLoss

device = "cuda" if torch.cuda.is_available() else "cpu"

TEST_CASES = [
# all first partials are zero, so the diffusion loss is also zero
[{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0],
# all first partials are one, so the diffusion loss is also one
[{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 1.0],
# before expansion, the first partials are 2, 4, 6, so the diffusion loss is (2^2 + 4^2 + 6^2) / 3 = 18.67
[
{"normalize": False},
{"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
56.0 / 3.0,
],
# same as the previous case
[
{"normalize": False},
{"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
56.0 / 3.0,
],
# same as the previous case
[{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0],
# we have shown in the demo notebook that
# diffusion loss is scale-invariant when the all axes have the same resolution
[
{"normalize": True},
{"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
56.0 / 3.0,
],
[
{"normalize": True},
{"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
56.0 / 3.0,
],
[{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0],
# for the following case, consider the following 2D matrix:
# tensor([[[[0, 1, 2],
# [1, 2, 3],
# [2, 3, 4],
# [3, 4, 5],
# [4, 5, 6]],
# [[0, 1, 2],
# [1, 2, 3],
# [2, 3, 4],
# [3, 4, 5],
# [4, 5, 6]]]])
# the first partials wrt x are all ones, and so are the first partials wrt y
# the diffusion loss, when normalization is not applied, is 1^2 + 1^2 = 2
[{"normalize": False}, {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, 2.0],
# consider the same matrix, this time with normalization applied, using the same notation as in the demo notebook,
# the coefficients to be divided out are (1, 5/3) for partials wrt x and (3/5, 1) for partials wrt y
# the diffusion loss is then (1/1)^2 + (1/(5/3))^2 + (1/(3/5))^2 + (1/1)^2 = (1 + 9/25 + 25/9 + 1) / 2 = 2.5689
[
{"normalize": True},
{"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)},
(1.0 + 9.0 / 25.0 + 25.0 / 9.0 + 1.0) / 2.0,
],
]


class TestDiffusionLoss(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_data, expected_val):
result = DiffusionLoss(**input_param).forward(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)

def test_ill_shape(self):
loss = DiffusionLoss()
# not in 3-d, 4-d, 5-d
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
loss.forward(torch.ones((1, 3), device=device))
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device))
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 2, 5, 5), device=device))
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 5, 2, 5)))
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 5, 5, 2)))

# number of vector components unequal to number of spatial dims
with self.assertRaisesRegex(ValueError, "Number of vector components"):
loss.forward(torch.ones((1, 2, 5, 5, 5)))
with self.assertRaisesRegex(ValueError, "Number of vector components"):
loss.forward(torch.ones((1, 2, 5, 5, 5)))

def test_ill_opts(self):
pred = torch.rand(1, 3, 5, 5, 5).to(device=device)
with self.assertRaisesRegex(ValueError, ""):
DiffusionLoss(reduction="unknown")(pred)
with self.assertRaisesRegex(ValueError, ""):
DiffusionLoss(reduction=None)(pred)


if __name__ == "__main__":
unittest.main()

0 comments on commit d5585c3

Please sign in to comment.