Skip to content

Commit

Permalink
5782 Flexible interp modes in regunet (#5807)
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>

Fixes #5782

### Description
- adds 'mode' and 'align_corners' options to the blocks and nets
- fixes a few typos

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli authored Jan 5, 2023
1 parent 315d2d2 commit df6bc9c
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [windows-latest, ubuntu-latest]
os: [ubuntu-latest]
python-version: ["3.9"]
runs-on: ${{ matrix.os }}
env:
Expand Down
21 changes: 16 additions & 5 deletions monai/networks/blocks/localnet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:

class LocalNetUpSampleBlock(nn.Module):
"""
A up-sample module that can be used for LocalNet, based on:
An up-sample module that can be used for LocalNet, based on:
`Weakly-supervised convolutional neural networks for multimodal image registration
<https://doi.org/10.1016/j.media.2018.07.002>`_.
`Label-driven weakly-supervised learning for multimodal deformable image registration
Expand All @@ -176,12 +176,21 @@ class LocalNetUpSampleBlock(nn.Module):
DeepReg (https://github.com/DeepRegNet/DeepReg)
"""

def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None:
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
mode: str = "nearest",
align_corners: Optional[bool] = None,
) -> None:
"""
Args:
spatial_dims: number of spatial dimensions.
in_channels: number of input channels.
out_channels: number of output channels.
mode: interpolation mode of the additive upsampling, default to 'nearest'.
align_corners: whether to align corners for the additive upsampling, default to None.
Raises:
ValueError: when ``in_channels != 2 * out_channels``
"""
Expand All @@ -199,9 +208,11 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> No
f"got in_channels={in_channels}, out_channels={out_channels}"
)
self.out_channels = out_channels
self.mode = mode
self.align_corners = align_corners

def addictive_upsampling(self, x, mid) -> torch.Tensor:
x = F.interpolate(x, mid.shape[2:])
def additive_upsampling(self, x, mid) -> torch.Tensor:
x = F.interpolate(x, mid.shape[2:], mode=self.mode, align_corners=self.align_corners)
# [(batch, out_channels, ...), (batch, out_channels, ...)]
x = x.split(split_size=int(self.out_channels), dim=1)
# (batch, out_channels, ...)
Expand All @@ -226,7 +237,7 @@ def forward(self, x, mid) -> torch.Tensor:
"expecting mid spatial dimensions be exactly the double of x spatial dimensions, "
f"got x of shape {x.shape}, mid of shape {mid.shape}"
)
h0 = self.deconv_block(x) + self.addictive_upsampling(x, mid)
h0 = self.deconv_block(x) + self.additive_upsampling(x, mid)
r1 = h0 + mid
r2 = self.conv_block(h0)
out: torch.Tensor = self.residual_block(r2, r1)
Expand Down
10 changes: 9 additions & 1 deletion monai/networks/blocks/regunet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ def __init__(
out_channels: int,
kernel_initializer: Optional[str] = "kaiming_uniform",
activation: Optional[str] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
):
"""
Expand All @@ -211,6 +213,8 @@ def __init__(
out_channels: number of output channels
kernel_initializer: kernel initializer
activation: kernel activation function
mode: feature map interpolation mode, default to "nearest".
align_corners: whether to align corners for feature map interpolation.
"""
super().__init__()
self.extract_levels = extract_levels
Expand All @@ -228,6 +232,8 @@ def __init__(
for d in extract_levels
]
)
self.mode = mode
self.align_corners = align_corners

def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor:
"""
Expand All @@ -240,7 +246,9 @@ def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor:
Tensor of shape (batch, `out_channels`, size1, size2, size3), where (size1, size2, size3) = ``image_size``
"""
feature_list = [
F.interpolate(layer(x[self.max_level - level]), size=image_size)
F.interpolate(
layer(x[self.max_level - level]), size=image_size, mode=self.mode, align_corners=self.align_corners
)
for layer, level in zip(self.layers, self.extract_levels)
]
out: torch.Tensor = torch.mean(torch.stack(feature_list, dim=0), dim=0)
Expand Down
31 changes: 25 additions & 6 deletions monai/networks/nets/regunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,14 +337,23 @@ def build_output_block(self):


class AdditiveUpSampleBlock(nn.Module):
def __init__(self, spatial_dims: int, in_channels: int, out_channels: int):
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
mode: str = "nearest",
align_corners: Optional[bool] = None,
):
super().__init__()
self.deconv = get_deconv_block(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels)
self.mode = mode
self.align_corners = align_corners

def forward(self, x: torch.Tensor) -> torch.Tensor:
output_size = [size * 2 for size in x.shape[2:]]
deconved = self.deconv(x)
resized = F.interpolate(x, output_size)
resized = F.interpolate(x, output_size, mode=self.mode, align_corners=self.align_corners)
resized = torch.sum(torch.stack(resized.split(split_size=resized.shape[1] // 2, dim=1), dim=-1), dim=-1)
out: torch.Tensor = deconved + resized
return out
Expand Down Expand Up @@ -372,8 +381,10 @@ def __init__(
out_activation: Optional[str] = None,
out_channels: int = 3,
pooling: bool = True,
use_addictive_sampling: bool = True,
use_additive_sampling: bool = True,
concat_skip: bool = False,
mode: str = "nearest",
align_corners: Optional[bool] = None,
):
"""
Args:
Expand All @@ -385,10 +396,14 @@ def __init__(
out_channels: number of channels for the output
extract_levels: list, which levels from net to extract. The maximum level must equal to ``depth``
pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv3d
use_addictive_sampling: whether use additive up-sampling layer for decoding.
use_additive_sampling: whether use additive up-sampling layer for decoding.
concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
mode: mode for interpolation when use_additive_sampling, default is "nearest".
align_corners: align_corners for interpolation when use_additive_sampling, default is None.
"""
self.use_additive_upsampling = use_addictive_sampling
self.use_additive_upsampling = use_additive_sampling
self.mode = mode
self.align_corners = align_corners
super().__init__(
spatial_dims=spatial_dims,
in_channels=in_channels,
Expand All @@ -412,7 +427,11 @@ def build_bottom_block(self, in_channels: int, out_channels: int):
def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module:
if self.use_additive_upsampling:
return AdditiveUpSampleBlock(
spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels
spatial_dims=self.spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
mode=self.mode,
align_corners=self.align_corners,
)

return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels)
2 changes: 2 additions & 0 deletions tests/test_localnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
"extract_levels": (0, 1),
"pooling": False,
"concat_skip": True,
"mode": "bilinear",
"align_corners": True,
},
(1, 2, 16, 16),
(1, 2, 16, 16),
Expand Down
12 changes: 11 additions & 1 deletion tests/test_localnet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,17 @@
[{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 4, "kernel_size": 3}] for spatial_dims in [2, 3]
]

TEST_CASE_UP_SAMPLE = [[{"spatial_dims": spatial_dims, "in_channels": 4, "out_channels": 2}] for spatial_dims in [2, 3]]
TEST_CASE_UP_SAMPLE = [
[
{
"spatial_dims": spatial_dims,
"in_channels": 4,
"out_channels": 2,
"mode": "bilinear" if spatial_dims == 2 else "trilinear",
}
]
for spatial_dims in [2, 3]
]

TEST_CASE_EXTRACT = [
[{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 3, "act": act, "initializer": initializer}]
Expand Down
1 change: 1 addition & 0 deletions tests/test_regunet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"out_channels": 1,
"kernel_initializer": "zeros",
"activation": "sigmoid",
"mode": "trilinear",
},
[(1, 3, 2, 2, 2), (1, 2, 4, 4, 4), (1, 1, 8, 8, 8)],
(3, 3, 3),
Expand Down

0 comments on commit df6bc9c

Please sign in to comment.