Skip to content

Commit

Permalink
Allow for defining reference grid on non-integer coordinates (#7032)
Browse files Browse the repository at this point in the history
This is a non-breaking change that is disabled by default. It allows for
defining the reference identity grid on non-integer values. This can be
beneficial for registration applications as shown in: B. Likar and F.
Pernus. A heirarchical approach to elastic registration based on mutual
information. Image and Vision Computing, 19:33-44, 2001.

cc: @nvahmadi

---------

Signed-off-by: Mikael Brudfors <mbrudfors@nvidia.com>
  • Loading branch information
brudfors authored Sep 22, 2023
1 parent bfabad7 commit 3bea5cf
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions monai/networks/blocks/warp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Warp(nn.Module):
Warp an image with given dense displacement field (DDF).
"""

def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.BORDER.value):
def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePadMode.BORDER.value, jitter=False):
"""
For pytorch native APIs, the possible values are:
Expand All @@ -47,6 +47,11 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa
- padding_mode: ``"zeros"``, ``"border"``, ``"reflection"``, 0, 1, ...
See also: :py:class:`monai.networks.layers.grid_pull`
- jitter: bool, default=False
Define reference grid on non-integer values
Reference: B. Likar and F. Pernus. A heirarchical approach to elastic registration
based on mutual information. Image and Vision Computing, 19:33-44, 2001.
"""
super().__init__()
# resolves _interp_mode for different methods
Expand Down Expand Up @@ -84,8 +89,9 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa
self._padding_mode = GridSamplePadMode(padding_mode).value

self.ref_grid = None
self.jitter = jitter

def get_reference_grid(self, ddf: torch.Tensor) -> torch.Tensor:
def get_reference_grid(self, ddf: torch.Tensor, jitter: bool = False, seed: int = 0) -> torch.Tensor:
if (
self.ref_grid is not None
and self.ref_grid.shape[0] == ddf.shape[0]
Expand All @@ -96,6 +102,11 @@ def get_reference_grid(self, ddf: torch.Tensor) -> torch.Tensor:
grid = torch.stack(meshgrid_ij(*mesh_points), dim=0) # (spatial_dims, ...)
grid = torch.stack([grid] * ddf.shape[0], dim=0) # (batch, spatial_dims, ...)
self.ref_grid = grid.to(ddf)
if jitter:
# Define reference grid on non-integer values
with torch.random.fork_rng(enabled=seed):
torch.random.manual_seed(seed)
grid += torch.rand_like(grid)
self.ref_grid.requires_grad = False
return self.ref_grid

Expand All @@ -117,7 +128,7 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor):
f"Given input {spatial_dims}-d image shape {image.shape}, the input DDF shape must be {ddf_shape}, "
f"Got {ddf.shape} instead."
)
grid = self.get_reference_grid(ddf) + ddf
grid = self.get_reference_grid(ddf, jitter=self.jitter) + ddf
grid = grid.permute([0] + list(range(2, 2 + spatial_dims)) + [1]) # (batch, ..., spatial_dims)

if not USE_COMPILED: # pytorch native grid_sample
Expand Down

0 comments on commit 3bea5cf

Please sign in to comment.