Skip to content

Commit

Permalink
Add customisable sampling weights
Browse files Browse the repository at this point in the history
  • Loading branch information
tuanchien committed Mar 16, 2023
1 parent 4dbb3c5 commit 8d3e5dd
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
29 changes: 27 additions & 2 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import warnings
from collections.abc import Callable, Mapping, Sequence
from typing import Any
from typing import Any, List, Optional

import numpy as np

Expand Down Expand Up @@ -383,6 +383,7 @@ class SomeOf(Compose):
max_num_transforms: maximum number of transforms to sample. Defaults to `3`.
fixed: whether to sample exactly `max_num_transforms` transforms, or up to it. Defaults to `False`.
replace: whether to sample with replacement. Defaults to `False`.
weights: weights to use for sampling transforms. Will be normalized to sum to 1. Default: None (uniform).
"""

def __init__(
Expand All @@ -395,6 +396,7 @@ def __init__(
max_num_transforms: int = 3,
fixed: bool = True,
replace: bool = False,
weights: Optional[List[int]] = None,
) -> None:
super().__init__(transforms, map_items, unpack_items, log_stats)
if transforms is None:
Expand All @@ -404,6 +406,29 @@ def __init__(
self.max_num_transforms = min(self.n_transforms, max_num_transforms)
self.fixed = fixed
self.replace = replace
self.weights = self._normalize_probabilities(weights)

def _normalize_probabilities(self, weights):
    """
    Validate the sampling weights and rescale them so that they sum to 1.

    Args:
        weights: one non-negative, finite weight per transform, or `None`.

    Returns:
        `None` when `weights` is `None` or when there are no transforms
        (`self.n_transforms == 0`); otherwise a tuple of floats that sums to 1.

    Raises:
        ValueError: when `weights` is not a flat sequence with exactly one entry per
            transform, contains NaN/infinite or negative values, or is all zeros.
    """
    if weights is None or self.n_transforms == 0:
        return None

    # Coerce to float up front so that integer, list and tuple inputs all behave the same.
    weights = np.asarray(weights, dtype=float)

    # A scalar or nested input cannot be matched one-to-one with the transforms.
    if weights.ndim != 1:
        raise ValueError(f"Weights must be a one-dimensional sequence, got an array of shape {weights.shape}.")

    n_weights = len(weights)
    if n_weights != self.n_transforms:
        raise ValueError(
            f"The number of weights specified must be equal to the number of transforms provided. Expected: {self.n_transforms}, got: {n_weights}."
        )

    # NaN and inf slip past the sign/zero checks below and would normalise to NaN
    # probabilities, so reject them here rather than failing later at sampling time.
    if not np.all(np.isfinite(weights)):
        raise ValueError(f"Probabilities must be finite numbers, got {weights}.")

    if np.any(weights < 0):
        raise ValueError(f"Probabilities must be greater than or equal to zero, got {weights}.")

    if np.all(weights == 0):
        raise ValueError(f"At least one probability must be greater than zero, got {weights}.")

    weights = weights / weights.sum()

    return tuple(weights.tolist())

def __call__(self, data):
if self.n_transforms == 0:
Expand All @@ -415,7 +440,7 @@ def __call__(self, data):
else self.R.randint(self.min_num_transforms, self.max_num_transforms + 1)
)

applied_order = self.R.choice(self.n_transforms, sample_size, replace=self.replace).tolist()
applied_order = self.R.choice(self.n_transforms, sample_size, replace=self.replace, p=self.weights).tolist()
for i in applied_order:
data = apply_transform(self.transforms[i], data, self.map_items, self.unpack_items, self.log_stats)

Expand Down
16 changes: 16 additions & 0 deletions tests/test_some_of.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ def test_inverse(self, transform, invertible, use_metatensor):
# if not invertible, should not change the data
self.assertDictEqual(fwd_data[i], _fwd_inv_data)

def test_normalize_weights(self):
    """Weights are rescaled to sum to one, and are `None` when there are no transforms."""
    options = {"fixed": True, "max_num_transforms": 1, "weights": (1, 2, 1)}

    weighted = SomeOf((A(), B(), C()), **options)
    self.assertTupleEqual(weighted.weights, (0.25, 0.5, 0.25))

    # With an empty transform list the supplied weights are not kept.
    without_transforms = SomeOf((), **options)
    self.assertIsNone(without_transforms.weights)

def test_no_weights_arg(self):
    """Leaving out `weights` results in a `None` attribute."""
    candidates = (A(), B(), C(), D())
    sampler = SomeOf(candidates, fixed=True, max_num_transforms=1)
    self.assertIsNone(sampler.weights)

def test_bad_weights(self):
    """Wrong-length, all-zero and negative weights are each rejected with `ValueError`."""
    invalid_weights = (
        (1, 2),  # fewer weights than transforms
        (0, 0, 0),  # nothing to sample from
        (-1, 1, 1),  # negative entry
    )
    for candidate in invalid_weights:
        # Build the transforms outside the assertRaises scope so only SomeOf is under test.
        transforms = (A(), B(), C())
        with self.assertRaises(ValueError):
            SomeOf(transforms, fixed=True, max_num_transforms=1, weights=candidate)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8d3e5dd

Please sign in to comment.