From 6f404a93c6731d587c9ad73c335b6d9f20d73ed2 Mon Sep 17 00:00:00 2001 From: binliu Date: Sun, 12 Feb 2023 15:42:56 +0000 Subject: [PATCH 1/2] fix the pixelshuffle upsample shape mismatch problem. Signed-off-by: binliu --- monai/networks/nets/flexible_unet.py | 2 +- tests/test_flexible_unet.py | 37 +++++++++++++++------------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/monai/networks/nets/flexible_unet.py b/monai/networks/nets/flexible_unet.py index fdb3376748..a880cafdc3 100644 --- a/monai/networks/nets/flexible_unet.py +++ b/monai/networks/nets/flexible_unet.py @@ -309,7 +309,7 @@ def __init__( bias=decoder_bias, upsample=upsample, interp_mode=interp_mode, - pre_conv=None, + pre_conv="default", align_corners=None, is_pad=is_pad, ) diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py index aae0cf729a..9251749fde 100644 --- a/tests/test_flexible_unet.py +++ b/tests/test_flexible_unet.py @@ -173,29 +173,32 @@ def make_shape_cases( num_classes=10, input_shape=64, norm=("batch", {"eps": 1e-3, "momentum": 0.01}), + upsample=["nontrainable", "deconv", "pixelshuffle"], ): ret_tests = [] for spatial_dim in spatial_dims: # selected spatial_dims for batch in batches: # check single batch as well as multiple batch input for model in models: # selected models for is_pretrained in pretrained: # pretrained or not pretrained - if ("resnet" in model) and is_pretrained: - continue - kwargs = { - "in_channels": in_channels, - "out_channels": num_classes, - "backbone": model, - "pretrained": is_pretrained, - "spatial_dims": spatial_dim, - "norm": norm, - } - ret_tests.append( - [ - kwargs, - (batch, in_channels) + (input_shape,) * spatial_dim, - (batch, num_classes) + (input_shape,) * spatial_dim, - ] - ) + for upsample_method in upsample: + if ("resnet" in model) and is_pretrained: + continue + kwargs = { + "in_channels": in_channels, + "out_channels": num_classes, + "backbone": model, + "pretrained": is_pretrained, + "spatial_dims": spatial_dim, + "norm": norm, + "upsample": upsample_method, + } + ret_tests.append( + [ + kwargs, + (batch, in_channels) + (input_shape,) * spatial_dim, + (batch, num_classes) + (input_shape,) * spatial_dim, + ] + ) return ret_tests From 3c6e752c3a45a4c89eaa89414a2521ee9e53ba11 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 13 Feb 2023 20:55:35 +0000 Subject: [PATCH 2/2] fixes flake8 errors Signed-off-by: Wenqi Li --- .pre-commit-config.yaml | 3 ++- tests/test_flexible_unet.py | 2 +- tests/utils.py | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d8ca946430..1269e18978 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -54,7 +54,8 @@ repos: exclude: | (?x)^( monai/__init__.py| - docs/source/conf.py + docs/source/conf.py| + tests/utils.py )$ - repo: https://github.com/hadialqattan/pycln diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py index 9251749fde..1218ce6e85 100644 --- a/tests/test_flexible_unet.py +++ b/tests/test_flexible_unet.py @@ -173,7 +173,7 @@ def make_shape_cases( num_classes=10, input_shape=64, norm=("batch", {"eps": 1e-3, "momentum": 0.01}), - upsample=["nontrainable", "deconv", "pixelshuffle"], + upsample=("nontrainable", "deconv", "pixelshuffle"), ): ret_tests = [] for spatial_dim in spatial_dims: # selected spatial_dims diff --git a/tests/utils.py b/tests/utils.py index 2f4b6d81ac..e0c061f755 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -49,7 +49,7 @@ from monai.utils.type_conversion import convert_data_type nib, _ = optional_import("nibabel") -http_error, has_requests = optional_import("requests", name="HTTPError") +http_error, has_req = optional_import("requests", name="HTTPError") quick_test_var = "QUICKTEST" _tf32_enabled = None @@ -126,7 +126,7 @@ def assert_allclose( def skip_if_downloading_fails(): try: yield - except (ContentTooShortError, HTTPError, ConnectionError) + (http_error,) if has_requests else () as e: + except (ContentTooShortError, HTTPError, ConnectionError) + (http_error,) if has_req else () as e: # noqa: B030 raise unittest.SkipTest(f"error while downloading: {e}") from e except ssl.SSLError as ssl_e: if "decryption failed" in str(ssl_e):