Skip to content

Commit

Permalink
Port MONAI Generative utils (Project-MONAI#7134)
Browse files Browse the repository at this point in the history
Towards completing Project-MONAI#6676 .

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Mark Graham <markgraham539@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
marksgraham and pre-commit-ci[bot] authored Oct 30, 2023
1 parent 1c17f0e commit 83f5091
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,8 @@ State Cacher
------------
.. automodule:: monai.utils.state_cacher
:members:

Component store
---------------
.. autoclass:: monai.utils.component_store.ComponentStore
:members:
6 changes: 6 additions & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather
from .enums import (
AdversarialIterationEvents,
AdversarialKeys,
AlgoKeys,
Average,
BlendMode,
Expand Down Expand Up @@ -47,6 +49,8 @@
MetricReduction,
NdimageMode,
NumpyPadMode,
OrderingTransformations,
OrderingType,
PatchKeys,
PostFix,
ProbMapKeys,
Expand Down Expand Up @@ -95,6 +99,8 @@
str2bool,
str2list,
to_tuple_of_dictionaries,
unsqueeze_left,
unsqueeze_right,
zip_with,
)
from .module import (
Expand Down
65 changes: 65 additions & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@

import random
from enum import Enum
from typing import TYPE_CHECKING

from monai.config import IgniteInfo
from monai.utils import deprecated
from monai.utils.module import min_version, optional_import

__all__ = [
"StrEnum",
Expand Down Expand Up @@ -88,6 +91,14 @@ def __repr__(self):
return self.value


if TYPE_CHECKING:
from ignite.engine import EventEnum
else:
EventEnum, _ = optional_import(
"ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base"
)


class NumpyPadMode(StrEnum):
"""
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
Expand Down Expand Up @@ -692,3 +703,57 @@ class AlgoKeys(StrEnum):
ALGO = "algo_instance"
IS_TRAINED = "is_trained"
SCORE = "best_metric"


class AdversarialKeys(StrEnum):
"""
Keys used by the AdversarialTrainer.
`REALS` are real images from the batch.
`FAKES` are fake images generated by the generator. Are the same as PRED.
`REAL_LOGITS` are logits of the discriminator for the real images.
`FAKE_LOGIT` are logits of the discriminator for the fake images.
`RECONSTRUCTION_LOSS` is the loss value computed by the reconstruction loss function.
`GENERATOR_LOSS` is the loss value computed by the generator loss function. It is the
discriminator loss for the fake images. That is backpropagated through the generator only.
`DISCRIMINATOR_LOSS` is the loss value computed by the discriminator loss function. It is the
discriminator loss for the real images and the fake images. That is backpropagated through the
discriminator only.
"""

REALS = "reals"
REAL_LOGITS = "real_logits"
FAKES = "fakes"
FAKE_LOGITS = "fake_logits"
RECONSTRUCTION_LOSS = "reconstruction_loss"
GENERATOR_LOSS = "generator_loss"
DISCRIMINATOR_LOSS = "discriminator_loss"


class AdversarialIterationEvents(EventEnum):
"""
Keys used to define events as used in the AdversarialTrainer.
"""

RECONSTRUCTION_LOSS_COMPLETED = "reconstruction_loss_completed"
GENERATOR_FORWARD_COMPLETED = "generator_forward_completed"
GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED = "generator_discriminator_forward_completed"
GENERATOR_LOSS_COMPLETED = "generator_loss_completed"
GENERATOR_BACKWARD_COMPLETED = "generator_backward_completed"
GENERATOR_MODEL_COMPLETED = "generator_model_completed"
DISCRIMINATOR_REALS_FORWARD_COMPLETED = "discriminator_reals_forward_completed"
DISCRIMINATOR_FAKES_FORWARD_COMPLETED = "discriminator_fakes_forward_completed"
DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed"
DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed"
DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed"


class OrderingType(StrEnum):
RASTER_SCAN = "raster_scan"
S_CURVE = "s_curve"
RANDOM = "random"


class OrderingTransformations(StrEnum):
ROTATE_90 = "rotate_90"
TRANSPOSE = "transpose"
REFLECT = "reflect"
10 changes: 10 additions & 0 deletions monai/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,3 +888,13 @@ def is_sqrt(num: Sequence[int] | int) -> bool:
sqrt_num = [int(math.sqrt(_num)) for _num in num]
ret = [_i * _j for _i, _j in zip(sqrt_num, sqrt_num)]
return ensure_tuple(ret) == num


def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor:
"""Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
return arr[(...,) + (None,) * (ndim - arr.ndim)]


def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor:
"""Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
return arr[(None,) * (ndim - arr.ndim)]
71 changes: 71 additions & 0 deletions tests/test_squeeze_unsqueeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.utils import unsqueeze_left, unsqueeze_right

RIGHT_CASES = [
(np.random.rand(3, 4).astype(np.float32), 5, (3, 4, 1, 1, 1)),
(torch.rand(3, 4).type(torch.float32), 5, (3, 4, 1, 1, 1)),
(np.random.rand(3, 4).astype(np.float64), 5, (3, 4, 1, 1, 1)),
(torch.rand(3, 4).type(torch.float64), 5, (3, 4, 1, 1, 1)),
(np.random.rand(3, 4).astype(np.int32), 5, (3, 4, 1, 1, 1)),
(torch.rand(3, 4).type(torch.int32), 5, (3, 4, 1, 1, 1)),
]


LEFT_CASES = [
(np.random.rand(3, 4).astype(np.float32), 5, (1, 1, 1, 3, 4)),
(torch.rand(3, 4).type(torch.float32), 5, (1, 1, 1, 3, 4)),
(np.random.rand(3, 4).astype(np.float64), 5, (1, 1, 1, 3, 4)),
(torch.rand(3, 4).type(torch.float64), 5, (1, 1, 1, 3, 4)),
(np.random.rand(3, 4).astype(np.int32), 5, (1, 1, 1, 3, 4)),
(torch.rand(3, 4).type(torch.int32), 5, (1, 1, 1, 3, 4)),
]
ALL_CASES = [
(np.random.rand(3, 4), 2, (3, 4)),
(np.random.rand(3, 4), 0, (3, 4)),
(np.random.rand(3, 4), -1, (3, 4)),
(np.array(3), 4, (1, 1, 1, 1)),
(np.array(3), 0, ()),
(np.random.rand(3, 4).astype(np.int32), 2, (3, 4)),
(np.random.rand(3, 4).astype(np.int32), 0, (3, 4)),
(np.random.rand(3, 4).astype(np.int32), -1, (3, 4)),
(np.array(3).astype(np.int32), 4, (1, 1, 1, 1)),
(np.array(3).astype(np.int32), 0, ()),
(torch.rand(3, 4), 2, (3, 4)),
(torch.rand(3, 4), 0, (3, 4)),
(torch.rand(3, 4), -1, (3, 4)),
(torch.tensor(3), 4, (1, 1, 1, 1)),
(torch.tensor(3), 0, ()),
(torch.rand(3, 4).type(torch.int32), 2, (3, 4)),
(torch.rand(3, 4).type(torch.int32), 0, (3, 4)),
(torch.rand(3, 4).type(torch.int32), -1, (3, 4)),
(torch.tensor(3).type(torch.int32), 4, (1, 1, 1, 1)),
(torch.tensor(3).type(torch.int32), 0, ()),
]


class TestUnsqueeze(unittest.TestCase):
@parameterized.expand(RIGHT_CASES + ALL_CASES)
def test_unsqueeze_right(self, arr, ndim, shape):
self.assertEqual(unsqueeze_right(arr, ndim).shape, shape)

@parameterized.expand(LEFT_CASES + ALL_CASES)
def test_unsqueeze_left(self, arr, ndim, shape):
self.assertEqual(unsqueeze_left(arr, ndim).shape, shape)

0 comments on commit 83f5091

Please sign in to comment.