Skip to content

Commit

Permalink
Merge similar test components with parameterized
Browse files Browse the repository at this point in the history
Signed-off-by: Han Wang <freddie.wanah@gmail.com>
  • Loading branch information
freddiewanah committed Apr 18, 2024
1 parent 16d4e2f commit 3a48ead
Show file tree
Hide file tree
Showing 15 changed files with 251 additions and 472 deletions.
27 changes: 8 additions & 19 deletions tests/test_affine_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,28 +133,17 @@ def test_to_norm_affine_ill(self, affine, src_size, dst_size, align_corners):

class TestAffineTransform(unittest.TestCase):

def test_affine_shift(self):
affine = torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]])
@parameterized.expand([
("shift", torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]]), [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]]),
("shift_1", torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]),
[[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]]),
("shift_2", torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]),
[[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]]),
])
def test_affine_transforms(self, name, affine, expected):
image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])
out = AffineTransform(align_corners=False)(image, affine)
out = out.detach().cpu().numpy()
expected = [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]]
np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)

def test_affine_shift_1(self):
affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]])
image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])
out = AffineTransform(align_corners=False)(image, affine)
out = out.detach().cpu().numpy()
expected = [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]]
np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)

def test_affine_shift_2(self):
affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])
out = AffineTransform(align_corners=False)(image, affine)
out = out.detach().cpu().numpy()
expected = [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]]
np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)

def test_zoom(self):
Expand Down
46 changes: 26 additions & 20 deletions tests/test_compute_f_beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import numpy as np
import torch
from parameterized import parameterized

from monai.metrics import FBetaScore
from tests.utils import assert_allclose
Expand All @@ -33,26 +34,31 @@ def test_expecting_success_and_device(self):
assert_allclose(result, torch.Tensor([0.714286]), atol=1e-6, rtol=1e-6)
np.testing.assert_equal(result.device, y_pred.device)

def test_expecting_success2(self):
metric = FBetaScore(beta=0.5)
metric(
y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
)
assert_allclose(metric.aggregate()[0], torch.Tensor([0.609756]), atol=1e-6, rtol=1e-6)

def test_expecting_success3(self):
metric = FBetaScore(beta=2)
metric(
y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
)
assert_allclose(metric.aggregate()[0], torch.Tensor([0.862069]), atol=1e-6, rtol=1e-6)

def test_denominator_is_zero(self):
metric = FBetaScore(beta=2)
metric(
y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
)
assert_allclose(metric.aggregate()[0], torch.Tensor([0.0]), atol=1e-6, rtol=1e-6)
@parameterized.expand(
[
(
"success_beta_0_5",
FBetaScore(beta=0.5),
torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]),
torch.Tensor([0.609756]),
),
(
"success_beta_2",
FBetaScore(beta=2),
torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]),
torch.Tensor([0.862069]),
),
(
"denominator_zero",
FBetaScore(beta=2),
torch.Tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]]),
torch.Tensor([0.0]),
),
]
)
def test_success_and_zero(self, name, metric, y, expected_score):
metric(y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=y)
assert_allclose(metric.aggregate()[0], expected_score, atol=1e-6, rtol=1e-6)

def test_number_of_dimensions_less_than_2_should_raise_error(self):
metric = FBetaScore()
Expand Down
34 changes: 19 additions & 15 deletions tests/test_global_mutual_information_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import numpy as np
import torch
from parameterized import parameterized

from monai import transforms
from monai.losses.image_dissimilarity import GlobalMutualInformationLoss
Expand Down Expand Up @@ -116,24 +117,27 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.

class TestGlobalMutualInformationLossIll(unittest.TestCase):

def test_ill_shape(self):
@parameterized.expand([
("mismatched_simple_dims", torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)),
("mismatched_advanced_dims", torch.ones((1, 3, 3), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)),
# You can add more test cases as needed
])
def test_ill_shape(self, name, input1, input2):
loss = GlobalMutualInformationLoss()
with self.assertRaisesRegex(ValueError, ""):
loss.forward(torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float, device=device))
with self.assertRaisesRegex(ValueError, ""):
loss.forward(torch.ones((1, 3, 3), dtype=torch.float), torch.ones((1, 3), dtype=torch.float, device=device))

def test_ill_opts(self):
with self.assertRaises(ValueError):
loss.forward(input1, input2)

@parameterized.expand([
("num_bins_zero", 0, "mean", ValueError, ""),
("num_bins_negative", -1, "mean", ValueError, ""),
("reduction_unknown", 64, "unknown", ValueError, ""),
("reduction_none", 64, None, ValueError, ""),
])
def test_ill_opts(self, name, num_bins, reduction, expected_exception, expected_message):
pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device)
target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device)
with self.assertRaisesRegex(ValueError, ""):
GlobalMutualInformationLoss(num_bins=0)(pred, target)
with self.assertRaisesRegex(ValueError, ""):
GlobalMutualInformationLoss(num_bins=-1)(pred, target)
with self.assertRaisesRegex(ValueError, ""):
GlobalMutualInformationLoss(reduction="unknown")(pred, target)
with self.assertRaisesRegex(ValueError, ""):
GlobalMutualInformationLoss(reduction=None)(pred, target)
with self.assertRaisesRegex(expected_exception, expected_message):
GlobalMutualInformationLoss(num_bins=num_bins, reduction=reduction)(pred, target)


if __name__ == "__main__":
Expand Down
30 changes: 14 additions & 16 deletions tests/test_hausdorff_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,17 +219,16 @@ def test_ill_opts(self):
with self.assertRaisesRegex(ValueError, ""):
HausdorffDTLoss(reduction=None)(chn_input, chn_target)

def test_input_warnings(self):
@parameterized.expand([
(False, False, False),
(False, True, False),
(False, False, True),
])
def test_input_warnings(self, include_background, softmax, to_onehot_y):
chn_input = torch.ones((1, 1, 1, 3))
chn_target = torch.ones((1, 1, 1, 3))
with self.assertWarns(Warning):
loss = HausdorffDTLoss(include_background=False)
loss.forward(chn_input, chn_target)
with self.assertWarns(Warning):
loss = HausdorffDTLoss(softmax=True)
loss.forward(chn_input, chn_target)
with self.assertWarns(Warning):
loss = HausdorffDTLoss(to_onehot_y=True)
loss = HausdorffDTLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)
loss.forward(chn_input, chn_target)


Expand All @@ -256,17 +255,16 @@ def test_ill_opts(self):
with self.assertRaisesRegex(ValueError, ""):
LogHausdorffDTLoss(reduction=None)(chn_input, chn_target)

def test_input_warnings(self):
@parameterized.expand([
(False, False, False),
(False, True, False),
(False, False, True),
])
def test_input_warnings(self, include_background, softmax, to_onehot_y):
chn_input = torch.ones((1, 1, 1, 3))
chn_target = torch.ones((1, 1, 1, 3))
with self.assertWarns(Warning):
loss = LogHausdorffDTLoss(include_background=False)
loss.forward(chn_input, chn_target)
with self.assertWarns(Warning):
loss = LogHausdorffDTLoss(softmax=True)
loss.forward(chn_input, chn_target)
with self.assertWarns(Warning):
loss = LogHausdorffDTLoss(to_onehot_y=True)
loss = LogHausdorffDTLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)
loss.forward(chn_input, chn_target)


Expand Down
30 changes: 14 additions & 16 deletions tests/test_median_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,25 @@

import numpy as np
import torch
from parameterized import parameterized

from monai.networks.layers import MedianFilter


class MedianFilterTestCase(unittest.TestCase):

def test_3d_big(self):
a = torch.ones(1, 1, 2, 3, 5)
g = MedianFilter([1, 2, 4]).to(torch.device("cpu:0"))

expected = a.numpy()
out = g(a).cpu().numpy()
np.testing.assert_allclose(out, expected, rtol=1e-5)

def test_3d(self):
a = torch.ones(1, 1, 4, 3, 4)
g = MedianFilter(1).to(torch.device("cpu:0"))

expected = a.numpy()
out = g(a).cpu().numpy()
np.testing.assert_allclose(out, expected, rtol=1e-5)
@parameterized.expand(
[
("3d_big", torch.ones(1, 1, 2, 3, 5), MedianFilter([1, 2, 4])),
("3d", torch.ones(1, 1, 4, 3, 4), MedianFilter(1)),
]
)
def test_3d(self, name, input_tensor, filter):
filter = filter.to(torch.device("cpu:0"))

expected = input_tensor.numpy()
output = filter(input_tensor).cpu().numpy()

np.testing.assert_allclose(output, expected, rtol=1e-5)

def test_3d_radii(self):
a = torch.ones(1, 1, 4, 3, 2)
Expand Down
24 changes: 13 additions & 11 deletions tests/test_multi_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,19 @@ def test_shape(self, input_param, input_data, expected_val):
result = MultiScaleLoss(**input_param).forward(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)

def test_ill_opts(self):
with self.assertRaisesRegex(ValueError, ""):
MultiScaleLoss(loss=dice_loss, kernel="none")
with self.assertRaisesRegex(ValueError, ""):
MultiScaleLoss(loss=dice_loss, scales=[-1])(
torch.ones((1, 1, 3), device=device), torch.ones((1, 1, 3), device=device)
)
with self.assertRaisesRegex(ValueError, ""):
MultiScaleLoss(loss=dice_loss, scales=[-1], reduction="none")(
torch.ones((1, 1, 3), device=device), torch.ones((1, 1, 3), device=device)
)
@parameterized.expand([
("kernel_none", {"loss": dice_loss, "kernel": "none"}, None, None),
("scales_negative", {"loss": dice_loss, "scales": [-1]}, torch.ones((1, 1, 3)), torch.ones((1, 1, 3))),
("scales_negative_reduction_none", {"loss": dice_loss, "scales": [-1], "reduction": "none"},
torch.ones((1, 1, 3)), torch.ones((1, 1, 3))),
])
def test_ill_opts(self, name, kwargs, input, target):
if input is None and target is None:
with self.assertRaisesRegex(ValueError, ""):
MultiScaleLoss(**kwargs)
else:
with self.assertRaisesRegex(ValueError, ""):
MultiScaleLoss(**kwargs)(input, target)

def test_script(self):
input_param, input_data, expected_val = TEST_CASES[0]
Expand Down
27 changes: 8 additions & 19 deletions tests/test_optional_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,20 @@

import unittest

from parameterized import parameterized

from monai.utils import OptionalImportError, exact_version, optional_import


class TestOptionalImport(unittest.TestCase):

def test_default(self):
my_module, flag = optional_import("not_a_module")
@parameterized.expand(["not_a_module", "torch.randint"])
def test_default(self, import_module):
my_module, flag = optional_import(import_module)
self.assertFalse(flag)
with self.assertRaises(OptionalImportError):
my_module.test

my_module, flag = optional_import("torch.randint")
with self.assertRaises(OptionalImportError):
self.assertFalse(flag)
print(my_module.test)

def test_import_valid(self):
my_module, flag = optional_import("torch")
self.assertTrue(flag)
Expand All @@ -47,18 +45,9 @@ def test_import_wrong_number(self):
self.assertTrue(flag)
print(my_module.randint(1, 2, (1, 2)))

def test_import_good_number(self):
my_module, flag = optional_import("torch", "0")
my_module.nn
self.assertTrue(flag)
print(my_module.randint(1, 2, (1, 2)))

my_module, flag = optional_import("torch", "0.0.0.1")
my_module.nn
self.assertTrue(flag)
print(my_module.randint(1, 2, (1, 2)))

my_module, flag = optional_import("torch", "1.1.0")
@parameterized.expand(["0", "0.0.0.1", "1.1.0"])
def test_import_good_number(self, version_number):
my_module, flag = optional_import("torch", version_number)
my_module.nn
self.assertTrue(flag)
print(my_module.randint(1, 2, (1, 2)))
Expand Down
8 changes: 3 additions & 5 deletions tests/test_perceptual_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,10 @@ def test_1d(self):
with self.assertRaises(NotImplementedError):
PerceptualLoss(spatial_dims=1)

def test_medicalnet_on_2d_data(self):
@parameterized.expand(["medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets"])
def test_medicalnet_on_2d_data(self, network_type):
with self.assertRaises(ValueError):
PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet10_23datasets")

with self.assertRaises(ValueError):
PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet50_23datasets")
PerceptualLoss(spatial_dims=2, network_type=network_type)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 3a48ead

Please sign in to comment.