reformatted from black
Signed-off-by: Han Wang <freddie.wanah@gmail.com>
freddiewanah committed Apr 18, 2024
1 parent 3a48ead commit 00b1465
Showing 9 changed files with 169 additions and 101 deletions.
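Every file below changes in the same way: a hand-wrapped @parameterized.expand([...]) decorator is replaced by black's output, which explodes the call so that each test case sits on its own line whenever the list carries a trailing comma and exceeds the line length (where the list has no trailing comma, as in one case in tests/test_ultrasound_confidence_map_transform.py, black collapses it onto a single line instead). A minimal, self-contained sketch of the resulting style is shown here; the class, test, and case names are illustrative only and are not taken from this commit.

import unittest

import torch
from parameterized import parameterized


class TestBlackStyleSketch(unittest.TestCase):

    @parameterized.expand(
        [
            # each case: (name, 2x3 affine matrix)
            ("identity", torch.eye(3)[:2]),
            ("shift", torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]])),
        ]
    )
    def test_affine_is_2x3(self, name, affine):
        # Only the matrix shape is checked in this sketch; the real tests in
        # the diff below compare transformed outputs against expected arrays.
        self.assertEqual(tuple(affine.shape), (2, 3))


if __name__ == "__main__":
    unittest.main()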
26 changes: 19 additions & 7 deletions tests/test_affine_transform.py
@@ -133,13 +133,25 @@ def test_to_norm_affine_ill(self, affine, src_size, dst_size, align_corners):

class TestAffineTransform(unittest.TestCase):

-    @parameterized.expand([
-        ("shift", torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]]), [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]]),
-        ("shift_1", torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]),
-         [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]]),
-        ("shift_2", torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]),
-         [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]]),
-    ])
+    @parameterized.expand(
+        [
+            (
+                "shift",
+                torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]]),
+                [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]],
+            ),
+            (
+                "shift_1",
+                torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]),
+                [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]],
+            ),
+            (
+                "shift_2",
+                torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]),
+                [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]],
+            ),
+        ]
+    )
    def test_affine_transforms(self, name, affine, expected):
        image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])
        out = AffineTransform(align_corners=False)(image, affine)
30 changes: 19 additions & 11 deletions tests/test_global_mutual_information_loss.py
@@ -117,22 +117,30 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.

class TestGlobalMutualInformationLossIll(unittest.TestCase):

-    @parameterized.expand([
-        ("mismatched_simple_dims", torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)),
-        ("mismatched_advanced_dims", torch.ones((1, 3, 3), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)),
-        # You can add more test cases as needed
-    ])
+    @parameterized.expand(
+        [
+            ("mismatched_simple_dims", torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)),
+            (
+                "mismatched_advanced_dims",
+                torch.ones((1, 3, 3), dtype=torch.float),
+                torch.ones((1, 3), dtype=torch.float),
+            ),
+            # You can add more test cases as needed
+        ]
+    )
    def test_ill_shape(self, name, input1, input2):
        loss = GlobalMutualInformationLoss()
        with self.assertRaises(ValueError):
            loss.forward(input1, input2)

-    @parameterized.expand([
-        ("num_bins_zero", 0, "mean", ValueError, ""),
-        ("num_bins_negative", -1, "mean", ValueError, ""),
-        ("reduction_unknown", 64, "unknown", ValueError, ""),
-        ("reduction_none", 64, None, ValueError, ""),
-    ])
+    @parameterized.expand(
+        [
+            ("num_bins_zero", 0, "mean", ValueError, ""),
+            ("num_bins_negative", -1, "mean", ValueError, ""),
+            ("reduction_unknown", 64, "unknown", ValueError, ""),
+            ("reduction_none", 64, None, ValueError, ""),
+        ]
+    )
    def test_ill_opts(self, name, num_bins, reduction, expected_exception, expected_message):
        pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device)
        target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device)
24 changes: 14 additions & 10 deletions tests/test_hausdorff_loss.py
@@ -219,11 +219,13 @@ def test_ill_opts(self):
        with self.assertRaisesRegex(ValueError, ""):
            HausdorffDTLoss(reduction=None)(chn_input, chn_target)

-    @parameterized.expand([
-        (False, False, False),
-        (False, True, False),
-        (False, False, True),
-    ])
+    @parameterized.expand(
+        [
+            (False, False, False),
+            (False, True, False),
+            (False, False, True),
+        ]
+    )
    def test_input_warnings(self, include_background, softmax, to_onehot_y):
        chn_input = torch.ones((1, 1, 1, 3))
        chn_target = torch.ones((1, 1, 1, 3))
@@ -255,11 +257,13 @@ def test_ill_opts(self):
        with self.assertRaisesRegex(ValueError, ""):
            LogHausdorffDTLoss(reduction=None)(chn_input, chn_target)

-    @parameterized.expand([
-        (False, False, False),
-        (False, True, False),
-        (False, False, True),
-    ])
+    @parameterized.expand(
+        [
+            (False, False, False),
+            (False, True, False),
+            (False, False, True),
+        ]
+    )
    def test_input_warnings(self, include_background, softmax, to_onehot_y):
        chn_input = torch.ones((1, 1, 1, 3))
        chn_target = torch.ones((1, 1, 1, 3))
18 changes: 12 additions & 6 deletions tests/test_multi_scale.py
@@ -58,12 +58,18 @@ def test_shape(self, input_param, input_data, expected_val):
        result = MultiScaleLoss(**input_param).forward(**input_data)
        np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)

-    @parameterized.expand([
-        ("kernel_none", {"loss": dice_loss, "kernel": "none"}, None, None),
-        ("scales_negative", {"loss": dice_loss, "scales": [-1]}, torch.ones((1, 1, 3)), torch.ones((1, 1, 3))),
-        ("scales_negative_reduction_none", {"loss": dice_loss, "scales": [-1], "reduction": "none"},
-         torch.ones((1, 1, 3)), torch.ones((1, 1, 3))),
-    ])
+    @parameterized.expand(
+        [
+            ("kernel_none", {"loss": dice_loss, "kernel": "none"}, None, None),
+            ("scales_negative", {"loss": dice_loss, "scales": [-1]}, torch.ones((1, 1, 3)), torch.ones((1, 1, 3))),
+            (
+                "scales_negative_reduction_none",
+                {"loss": dice_loss, "scales": [-1], "reduction": "none"},
+                torch.ones((1, 1, 3)),
+                torch.ones((1, 1, 3)),
+            ),
+        ]
+    )
    def test_ill_opts(self, name, kwargs, input, target):
        if input is None and target is None:
            with self.assertRaisesRegex(ValueError, ""):
35 changes: 21 additions & 14 deletions tests/test_prepare_batch_default.py
@@ -28,20 +28,27 @@ def forward(self, x: torch.Tensor):

class TestPrepareBatchDefault(unittest.TestCase):

-    @parameterized.expand([
-        ("dict_content", [
-            {
-                "image": torch.tensor([1, 2]),
-                "label": torch.tensor([3, 4]),
-                "extra1": torch.tensor([5, 6]),
-                "extra2": 16,
-                "extra3": "test",
-            }
-        ], TestNet(), True),
-        ("tensor_content", [torch.tensor([1, 2])], torch.nn.Identity(), True),
-        ("pair_content", [(torch.tensor([1, 2]), torch.tensor([3, 4]))], torch.nn.Identity(), True),
-        ("empty_data", [], TestNet(), False),
-    ])
+    @parameterized.expand(
+        [
+            (
+                "dict_content",
+                [
+                    {
+                        "image": torch.tensor([1, 2]),
+                        "label": torch.tensor([3, 4]),
+                        "extra1": torch.tensor([5, 6]),
+                        "extra2": 16,
+                        "extra3": "test",
+                    }
+                ],
+                TestNet(),
+                True,
+            ),
+            ("tensor_content", [torch.tensor([1, 2])], torch.nn.Identity(), True),
+            ("pair_content", [(torch.tensor([1, 2]), torch.tensor([3, 4]))], torch.nn.Identity(), True),
+            ("empty_data", [], TestNet(), False),
+        ]
+    )
    def test_prepare_batch(self, name, dataloader, network, should_run):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        evaluator = SupervisedEvaluator(
12 changes: 7 additions & 5 deletions tests/test_tversky_loss.py
@@ -165,11 +165,13 @@ def test_ill_shape(self):
        with self.assertRaisesRegex(ValueError, ""):
            TverskyLoss(reduction=None)(chn_input, chn_target)

-    @parameterized.expand([
-        (False, False, False),
-        (False, True, False),
-        (False, False, True),
-    ])
+    @parameterized.expand(
+        [
+            (False, False, False),
+            (False, True, False),
+            (False, False, True),
+        ]
+    )
    def test_input_warnings(self, include_background, softmax, to_onehot_y):
        chn_input = torch.ones((1, 1, 3))
        chn_target = torch.ones((1, 1, 3))
61 changes: 32 additions & 29 deletions tests/test_ultrasound_confidence_map_transform.py
@@ -536,12 +536,14 @@ def test_parameters(self):
        with self.assertRaises(ValueError):
            UltrasoundConfidenceMapTransform(sink_mode="unknown")

-    @parameterized.expand([
-        ("all", SINK_ALL_OUTPUT),
-        ("mid", SINK_MID_OUTPUT),
-        ("min", SINK_MIN_OUTPUT),
-        ("mask", SINK_MASK_OUTPUT, True),
-    ])
+    @parameterized.expand(
+        [
+            ("all", SINK_ALL_OUTPUT),
+            ("mid", SINK_MID_OUTPUT),
+            ("min", SINK_MIN_OUTPUT),
+            ("mask", SINK_MASK_OUTPUT, True),
+        ]
+    )
    def test_ultrasound_confidence_map_transform(self, sink_mode, expected_output, use_mask=False):
        # RGB image
        input_img_rgb = np.expand_dims(np.repeat(self.input_img_np, 3, axis=0), axis=0)
@@ -561,12 +563,14 @@ def test_ultrasound_confidence_map_transform(self, sink_mode, expected_output, u
        self.assertIsInstance(result_np, np.ndarray)
        assert_allclose(result_np, expected_output, rtol=1e-4, atol=1e-4)

-    @parameterized.expand([
-        ("all", SINK_ALL_OUTPUT),
-        ("mid", SINK_MID_OUTPUT),
-        ("min", SINK_MIN_OUTPUT),
-        ("mask", SINK_MASK_OUTPUT, True),  # Adding a flag for mask cases
-    ])
+    @parameterized.expand(
+        [
+            ("all", SINK_ALL_OUTPUT),
+            ("mid", SINK_MID_OUTPUT),
+            ("min", SINK_MIN_OUTPUT),
+            ("mask", SINK_MASK_OUTPUT, True),  # Adding a flag for mask cases
+        ]
+    )
    def test_multi_channel_2d(self, sink_mode, expected_output, use_mask=False):
        input_img_rgb = np.expand_dims(np.repeat(self.input_img_np, 17, axis=0), axis=0)
        input_img_rgb_torch = torch.from_numpy(input_img_rgb)
@@ -585,12 +589,14 @@ def test_multi_channel_2d(self, sink_mode, expected_output, use_mask=False):
        self.assertIsInstance(result_np, np.ndarray)
        assert_allclose(result_np, expected_output, rtol=1e-4, atol=1e-4)

-    @parameterized.expand([
-        ("all",),
-        ("mid",),
-        ("min",),
-        ("mask",),
-    ])
+    @parameterized.expand(
+        [
+            ("all",),
+            ("mid",),
+            ("min",),
+            ("mask",),
+        ]
+    )
    def test_non_one_first_dim(self, sink_mode):
        transform = UltrasoundConfidenceMapTransform(sink_mode=sink_mode)
        input_img_rgb = np.repeat(self.input_img_np, 3, axis=0)
@@ -607,12 +613,7 @@ def test_non_one_first_dim(self, sink_mode):
        with self.assertRaises(ValueError):
            transform(input_img_rgb)

-    @parameterized.expand([
-        ("all",),
-        ("mid",),
-        ("min",),
-        ("mask",)
-    ])
+    @parameterized.expand([("all",), ("mid",), ("min",), ("mask",)])
    def test_no_first_dim(self, sink_mode):
        input_img_rgb = self.input_img_np[0]
        input_img_rgb_torch = torch.from_numpy(input_img_rgb)
@@ -630,11 +631,13 @@ def test_no_first_dim(self, sink_mode):
        with self.assertRaises(ValueError):
            transform(input_img_rgb, self.input_mask_np)

-    @parameterized.expand([
-        ("all",),
-        ("mid",),
-        ("min",),
-    ])
+    @parameterized.expand(
+        [
+            ("all",),
+            ("mid",),
+            ("min",),
+        ]
+    )
    def test_sink_mode(self, mode):
        transform = UltrasoundConfidenceMapTransform(sink_mode=mode)

31 changes: 22 additions & 9 deletions tests/test_vit.py
@@ -69,15 +69,28 @@ def test_shape(self, input_param, input_shape, expected_shape):
        result, _ = net(torch.randn(input_shape))
        self.assertEqual(result.shape, expected_shape)

-    @parameterized.expand([
-        (1, (128, 128, 128), (16, 16, 16), 128, 3072, 12, 12, "conv", False, 5.0),
-        (1, (32, 32, 32), (64, 64, 64), 512, 3072, 12, 8, "perceptron", False, 0.3),
-        (1, (96, 96, 96), (8, 8, 8), 512, 3072, 12, 14, "conv", False, 0.3),
-        (1, (97, 97, 97), (4, 4, 4), 768, 3072, 12, 8, "perceptron", True, 0.3),
-        (4, (96, 96, 96), (16, 16, 16), 768, 3072, 12, 12, "perc", False, 0.3),
-    ])
-    def test_ill_arg(self, in_channels, img_size, patch_size, hidden_size, mlp_dim, num_layers, num_heads, pos_embed,
-                     classification, dropout_rate):
+    @parameterized.expand(
+        [
+            (1, (128, 128, 128), (16, 16, 16), 128, 3072, 12, 12, "conv", False, 5.0),
+            (1, (32, 32, 32), (64, 64, 64), 512, 3072, 12, 8, "perceptron", False, 0.3),
+            (1, (96, 96, 96), (8, 8, 8), 512, 3072, 12, 14, "conv", False, 0.3),
+            (1, (97, 97, 97), (4, 4, 4), 768, 3072, 12, 8, "perceptron", True, 0.3),
+            (4, (96, 96, 96), (16, 16, 16), 768, 3072, 12, 12, "perc", False, 0.3),
+        ]
+    )
+    def test_ill_arg(
+        self,
+        in_channels,
+        img_size,
+        patch_size,
+        hidden_size,
+        mlp_dim,
+        num_layers,
+        num_heads,
+        pos_embed,
+        classification,
+        dropout_rate,
+    ):
        with self.assertRaises(ValueError):
            ViT(
                in_channels=in_channels,
33 changes: 23 additions & 10 deletions tests/test_vitautoenc.py
@@ -82,16 +82,29 @@ def test_shape(self, input_param, input_shape, expected_shape):
        result, _ = net(torch.randn(input_shape))
        self.assertEqual(result.shape, expected_shape)

-    @parameterized.expand([
-        ("img_size_too_large_for_patch_size", 1, (32, 32, 32), (64, 64, 64), 512, 3072, 12, 8, "perceptron", 0.3),
-        ("num_heads_out_of_bound", 1, (96, 96, 96), (8, 8, 8), 512, 3072, 12, 14, "conv", 0.3),
-        ("img_size_not_divisible_by_patch_size", 1, (97, 97, 97), (4, 4, 4), 768, 3072, 12, 8, "perceptron", 0.3),
-        ("invalid_pos_embed", 4, (96, 96, 96), (16, 16, 16), 768, 3072, 12, 12, "perc", 0.3),
-        ("patch_size_not_divisible", 4, (96, 96, 96), (9, 9, 9), 768, 3072, 12, 12, "perc", 0.3),
-        # Add more test cases as needed
-    ])
-    def test_ill_arg(self, name, in_channels, img_size, patch_size, hidden_size, mlp_dim, num_layers, num_heads,
-                     pos_embed, dropout_rate):
+    @parameterized.expand(
+        [
+            ("img_size_too_large_for_patch_size", 1, (32, 32, 32), (64, 64, 64), 512, 3072, 12, 8, "perceptron", 0.3),
+            ("num_heads_out_of_bound", 1, (96, 96, 96), (8, 8, 8), 512, 3072, 12, 14, "conv", 0.3),
+            ("img_size_not_divisible_by_patch_size", 1, (97, 97, 97), (4, 4, 4), 768, 3072, 12, 8, "perceptron", 0.3),
+            ("invalid_pos_embed", 4, (96, 96, 96), (16, 16, 16), 768, 3072, 12, 12, "perc", 0.3),
+            ("patch_size_not_divisible", 4, (96, 96, 96), (9, 9, 9), 768, 3072, 12, 12, "perc", 0.3),
+            # Add more test cases as needed
+        ]
+    )
+    def test_ill_arg(
+        self,
+        name,
+        in_channels,
+        img_size,
+        patch_size,
+        hidden_size,
+        mlp_dim,
+        num_layers,
+        num_heads,
+        pos_embed,
+        dropout_rate,
+    ):
        with self.assertRaises(ValueError):
            ViTAutoEnc(
                in_channels=in_channels,
