From 3b94d62b0235934538e695668e893b4901e70741 Mon Sep 17 00:00:00 2001 From: Mingxin Zheng Date: Mon, 19 Jun 2023 16:22:59 +0800 Subject: [PATCH 1/3] Fix test_image_filter test error Signed-off-by: Mingxin Zheng --- monai/transforms/utility/array.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 6c917a9f0a..68480ddcd5 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1641,7 +1641,22 @@ def _check_filter_format(self, filter: str | NdarrayOrTensor | nn.Module, filter "`class 'torch.nn.modules.module.Module'`, `class 'monai.transforms.Transform'`" ) - def _check_kwargs_are_present(self, filter, **kwargs): + def _check_kwargs_are_present(self, filter: NdarrayOrTensor | str, **kwargs: Any) -> None: + """ + Perform sanity checks on the kwargs if the filter contains the required keys. + If the filter is ``gauss``, kwargs should contain ``sigma``. + If the filter is ``savitzky_golay``, kwargs should contain ``order``. + + Args: + filter: a number array in tensor/numpy or a string indicating the filter type, e.g. gauss/savitzky_golay. + kwargs: additional arguments defining the filter. + + Raises: + KeyError if the filter doesn't contain the requirement key. + """ + + if not isinstance(filter, str): + return if filter == "gauss" and "sigma" not in kwargs.keys(): raise KeyError("`filter='gauss', requires the additional keyword argument `sigma`") if filter == "savitzky_golay" and "order" not in kwargs.keys(): From 9fd49be56a64b96b41870b3a5168238c21b9d323 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Jun 2023 08:24:30 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/transforms/utility/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 68480ddcd5..c067970d10 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1650,7 +1650,7 @@ def _check_kwargs_are_present(self, filter: NdarrayOrTensor | str, **kwargs: Any Args: filter: a number array in tensor/numpy or a string indicating the filter type, e.g. gauss/savitzky_golay. kwargs: additional arguments defining the filter. - + Raises: KeyError if the filter doesn't contain the requirement key. """ From 1be64a406cfc828b063477b1e1a5487570c15b4a Mon Sep 17 00:00:00 2001 From: Mingxin <18563433+mingxin-zheng@users.noreply.github.com> Date: Mon, 19 Jun 2023 08:29:14 +0000 Subject: [PATCH 3/3] Fix mypy Signed-off-by: Mingxin <18563433+mingxin-zheng@users.noreply.github.com> --- monai/transforms/utility/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index c067970d10..75b8199314 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1641,14 +1641,14 @@ def _check_filter_format(self, filter: str | NdarrayOrTensor | nn.Module, filter "`class 'torch.nn.modules.module.Module'`, `class 'monai.transforms.Transform'`" ) - def _check_kwargs_are_present(self, filter: NdarrayOrTensor | str, **kwargs: Any) -> None: + def _check_kwargs_are_present(self, filter: str | NdarrayOrTensor | nn.Module, **kwargs: Any) -> None: """ Perform sanity checks on the kwargs if the filter contains the required keys. If the filter is ``gauss``, kwargs should contain ``sigma``. If the filter is ``savitzky_golay``, kwargs should contain ``order``. Args: - filter: a number array in tensor/numpy or a string indicating the filter type, e.g. gauss/savitzky_golay. + filter: A string specifying the filter, a custom filter as ``torch.Tenor`` or ``np.ndarray`` or a ``nn.Module``. kwargs: additional arguments defining the filter. Raises: