Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

6676 port losses from monai-generative #6729

Merged
merged 26 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a21c7f1
Adds losses
marksgraham Jul 14, 2023
79a3784
Adds LPIPS as a requirement
marksgraham Jul 14, 2023
89c4e83
Updates docstrings
marksgraham Jul 19, 2023
3faaca5
Adds external dependency to relevant files, excludes from min tests a…
marksgraham Jul 19, 2023
086b8a9
Uses optional_import for torchvision too
marksgraham Jul 19, 2023
f500f4a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2023
105c3b8
Fixes typing issues in perceptual loss
marksgraham Jul 31, 2023
700096d
Fixes typing issues in adversarial loss
marksgraham Jul 31, 2023
ea158ca
Fixes merge conflicts
marksgraham Jul 31, 2023
4b1d801
Merge branch 'dev' into 6676_port_generative_losses
marksgraham Jul 31, 2023
cd01b59
Fixes more typing errors
marksgraham Jul 31, 2023
a5e092e
Formatting fix
marksgraham Jul 31, 2023
a241919
DCO Remediation Commit for Mark Graham <markgraham539@gmail.com>
marksgraham Jul 31, 2023
36d7ed6
Empty commit workaround undo
marksgraham Jul 31, 2023
d9d1140
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2023
413f394
Merge branch '6676_port_generative_losses' of github.com:marksgraham/…
marksgraham Jul 31, 2023
c118d10
DCO Remediation Commit for Mark Graham <markgraham539@gmail.com>
marksgraham Jul 31, 2023
f8f1a3f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2023
6260984
Merge branch 'dev' into 6676_port_generative_losses
wyli Jul 31, 2023
a4e16c8
Fix errors with earlier torchvision versions
marksgraham Aug 2, 2023
569e4e2
Merge branch '6676_port_generative_losses' of github.com:marksgraham/…
marksgraham Aug 2, 2023
42414ff
Adds warning if user specifies cache_dir
marksgraham Aug 2, 2023
cfb24b4
Reverts to warivto for modelhub and changes perceptual test decorator…
marksgraham Aug 3, 2023
ebfe2ef
Addresses comments from mingxin-zheng
marksgraham Aug 3, 2023
840e5cc
Fixes codeformat list comprehension error
marksgraham Aug 3, 2023
345d82c
Merge branch 'dev' into 6676_port_generative_losses
wyli Aug 3, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,11 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
- The options are

```
[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr]
[nibabel, skimage, scipy, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, clearml, matplotlib, tensorboardX, tifffile, imagecodecs, pyyaml, fire, jsonschema, ninja, pynrrd, pydicom, h5py, nni, optuna, onnx, onnxruntime, zarr, lpips]
```

which correspond to `nibabel`, `scikit-image`, `scipy`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, and `zarr` respectively.
which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr` and `lpips` respectively.


- `pip install 'monai[all]'` installs all the optional dependencies.
15 changes: 15 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,21 @@ Reconstruction Losses
.. autoclass:: monai.losses.ssim_loss.SSIMLoss
:members:

`PatchAdversarialLoss`
~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: PatchAdversarialLoss
:members:

`PerceptualLoss`
~~~~~~~~~~~~~~~~~
.. autoclass:: PerceptualLoss
:members:

`JukeboxLoss`
~~~~~~~~~~~~~~
.. autoclass:: JukeboxLoss
:members:


Loss Wrappers
-------------
Expand Down
3 changes: 3 additions & 0 deletions monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

from .adversarial_loss import PatchAdversarialLoss
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
from .contrastive import ContrastiveLoss
from .deform import BendingEnergyLoss
Expand All @@ -34,7 +35,9 @@
from .giou_loss import BoxGIoULoss, giou
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
from .multi_scale import MultiScaleLoss
from .perceptual import PerceptualLoss
from .spatial_mask import MaskedLoss
from .spectral_loss import JukeboxLoss
from .ssim_loss import SSIMLoss
from .tversky import TverskyLoss
from .unified_focal_loss import AsymmetricUnifiedFocalLoss
176 changes: 176 additions & 0 deletions monai/losses/adversarial_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings

import torch
from torch.nn.modules.loss import _Loss

from monai.networks.layers.utils import get_act_layer
from monai.utils import LossReduction
from monai.utils.enums import StrEnum


class AdversarialCriterions(StrEnum):
BCE = "bce"
HINGE = "hinge"
LEAST_SQUARE = "least_squares"


class PatchAdversarialLoss(_Loss):
"""
Calculates an adversarial loss on a Patch Discriminator or a Multi-scale Patch Discriminator.
Warning: due to the possibility of using different criterions, the output of the discrimination
mustn't be passed to a final activation layer. That is taken care of internally within the loss.

Args:
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.

- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.

criterion: which criterion (hinge, least_squares or bce) you want to use on the discriminators outputs.
Depending on the criterion, a different activation layer will be used. Make sure you don't run the outputs
through an activation layer prior to calling the loss.
no_activation_leastsq: if True, the activation layer in the case of least-squares is removed.
"""

def __init__(
self,
reduction: LossReduction | str = LossReduction.MEAN,
criterion: str = AdversarialCriterions.LEAST_SQUARE.value,
marksgraham marked this conversation as resolved.
Show resolved Hide resolved
no_activation_leastsq: bool = False,
) -> None:
super().__init__(reduction=LossReduction(reduction).value)

if criterion.lower() not in [m.value for m in AdversarialCriterions]:
raise ValueError(
"Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
% ", ".join([m.value for m in AdversarialCriterions])
)

# Depending on the criterion, a different activation layer is used.
self.real_label = 1.0
self.fake_label = 0.0
self.loss_fct: _Loss
if criterion == AdversarialCriterions.BCE.value:
self.activation = get_act_layer("SIGMOID")
self.loss_fct = torch.nn.BCELoss(reduction=reduction)
elif criterion == AdversarialCriterions.HINGE.value:
self.activation = get_act_layer("TANH")
self.fake_label = -1.0
elif criterion == AdversarialCriterions.LEAST_SQUARE.value:
if no_activation_leastsq:
self.activation = None
else:
self.activation = get_act_layer(name=("LEAKYRELU", {"negative_slope": 0.05}))
self.loss_fct = torch.nn.MSELoss(reduction=reduction)

self.criterion = criterion
self.reduction = reduction

def get_target_tensor(self, input: torch.Tensor, target_is_real: bool) -> torch.Tensor:
"""
Gets the ground truth tensor for the discriminator depending on whether the input is real or fake.

Args:
input: input tensor from the discriminator (output of discriminator, or output of one of the multi-scale
discriminator). This is used to match the shape.
target_is_real: whether the input is real or wannabe-real (1s) or fake (0s).
Returns:
"""
filling_label = self.real_label if target_is_real else self.fake_label
label_tensor = torch.tensor(1).fill_(filling_label).type(input.type()).to(input[0].device)
label_tensor.requires_grad_(False)
return label_tensor.expand_as(input)

def get_zero_tensor(self, input: torch.Tensor) -> torch.Tensor:
"""
Gets a zero tensor.

Args:
input: tensor which shape you want the zeros tensor to correspond to.
Returns:
"""

zero_label_tensor = torch.tensor(0).type(input[0].type()).to(input[0].device)
zero_label_tensor.requires_grad_(False)
return zero_label_tensor.expand_as(input)

def forward(
self, input: torch.Tensor | list, target_is_real: bool, for_discriminator: bool
) -> torch.Tensor | list[torch.Tensor]:
"""

Args:
input: output of Multi-Scale Patch Discriminator or Patch Discriminator; being a list of tensors
or a tensor; they shouldn't have gone through an activation layer.
target_is_real: whereas the input corresponds to discriminator output for real or fake images
for_discriminator: whereas this is being calculated for discriminator or generator loss. In the last
case, target_is_real is set to True, as the generator wants the input to be dimmed as real.
Returns: if reduction is None, returns a list with the loss tensors of each discriminator if multi-scale
discriminator is active, or the loss tensor if there is just one discriminator. Otherwise, it returns the
summed or mean loss over the tensor and discriminator/s.

"""

if not for_discriminator and not target_is_real:
target_is_real = True # With generator, we always want this to be true!
warnings.warn(
"Variable target_is_real has been set to False, but for_discriminator is set"
"to False. To optimise a generator, target_is_real must be set to True."
)

if type(input) is not list:
input = [input]
target_ = []
for _, disc_out in enumerate(input):
if self.criterion != AdversarialCriterions.HINGE.value:
target_.append(self.get_target_tensor(disc_out, target_is_real))
else:
target_.append(self.get_zero_tensor(disc_out))

# Loss calculation
loss_list = []
for disc_ind, disc_out in enumerate(input):
if self.activation is not None:
disc_out = self.activation(disc_out)
if self.criterion == AdversarialCriterions.HINGE.value and not target_is_real:
loss_ = self.forward_single(-disc_out, target_[disc_ind])
else:
loss_ = self.forward_single(disc_out, target_[disc_ind])
loss_list.append(loss_)

loss: torch.Tensor | list[torch.Tensor]
if loss_list is not None:
if self.reduction == LossReduction.MEAN.value:
loss = torch.mean(torch.stack(loss_list))
elif self.reduction == LossReduction.SUM.value:
loss = torch.sum(torch.stack(loss_list))
else:
loss = loss_list
return loss

def forward_single(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
marksgraham marked this conversation as resolved.
Show resolved Hide resolved
forward: torch.Tensor
if (
self.criterion == AdversarialCriterions.BCE.value
or self.criterion == AdversarialCriterions.LEAST_SQUARE.value
):
forward = self.loss_fct(input, target)
elif self.criterion == AdversarialCriterions.HINGE.value:
minval = torch.min(input - 1, self.get_zero_tensor(input))
forward = -torch.mean(minval)
return forward
Loading