Skip to content

Commit

Permalink
Update nacl_loss.py
Browse files Browse the repository at this point in the history
Call contiguous after permute to avoid reshaping issue.
  • Loading branch information
Bala93 authored Aug 15, 2024
1 parent db9daeb commit f579763
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions monai/losses/nacl_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:
rmask: torch.Tensor

if self.dim == 2:
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 3, 1, 2).float()
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 3, 1, 2).contiguous().float()
rmask = self.svls_layer(oh_labels)

if self.dim == 3:
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).contiguous().permute(0, 4, 1, 2, 3).float()
oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 4, 1, 2, 3).contiguous().float()
rmask = self.svls_layer(oh_labels)

return rmask
Expand Down

0 comments on commit f579763

Please sign in to comment.