Commit 42aa088

Signed-off-by: KumoLiu <yunl@nvidia.com>
KumoLiu committed Oct 24, 2023
1 parent e7feedf commit 42aa088
Showing 2 changed files with 16 additions and 13 deletions.
monai/losses/dice.py (8 additions, 7 deletions)

@@ -111,8 +111,9 @@ def __init__(
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
-        self.weight = weight
-        self.register_buffer("class_weight", torch.ones(1))
+        if weight is not None:
+            weight = torch.as_tensor(weight)
+        self.register_buffer("class_weight", weight)

     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -189,13 +190,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

         f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)

-        if self.weight is not None and target.shape[1] != 1:
+        num_of_classes = target.shape[1]
+        if self.class_weight is not None and num_of_classes != 1:
             # make sure the lengths of weights are equal to the number of classes
-            num_of_classes = target.shape[1]
-            if isinstance(self.weight, (float, int)):
-                self.class_weight = torch.as_tensor([self.weight] * num_of_classes)
+            if self.class_weight.ndim == 0:
+                self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
             else:
-                self.class_weight = torch.as_tensor(self.weight)
+                self.class_weight = torch.as_tensor(self.class_weight)
             if self.class_weight.shape[0] != num_of_classes:
                 raise ValueError(
                     """the length of the `weight` sequence should be the same as the number of classes.
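Both files make the same move: the user-supplied weight is converted with torch.as_tensor once in __init__ and registered as a buffer named class_weight (register_buffer accepts None, so the unset case needs no special handling). The point of a buffer is that it follows the module across devices and is serialized with it, which a plain attribute does not do. A minimal sketch of that behavior, using an illustrative WeightedLoss class that is not part of MONAI:

import torch
from torch import nn


class WeightedLoss(nn.Module):
    """Toy module that stores per-class weights the way this commit does."""

    def __init__(self, weight=None):
        super().__init__()
        if weight is not None:
            weight = torch.as_tensor(weight)
        # A registered buffer moves with .to()/.cuda() and shows up in
        # state_dict(); register_buffer also accepts None for the unset case.
        self.register_buffer("class_weight", weight)


loss = WeightedLoss(weight=[1.0, 2.0, 0.5])
device = "cuda" if torch.cuda.is_available() else "cpu"
loss = loss.to(device)
print(loss.class_weight.device)             # matches the module's device
print("class_weight" in loss.state_dict())  # True: buffers are serialized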
monai/losses/focal_loss.py (8 additions, 6 deletions)

@@ -113,7 +113,9 @@ def __init__(
         self.alpha = alpha
         self.weight = weight
         self.use_softmax = use_softmax
-        self.register_buffer("class_weight", torch.ones(1))
+        if weight is not None:
+            weight = torch.as_tensor(weight)
+        self.register_buffer("class_weight", weight)

     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -162,13 +164,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         else:
             loss = sigmoid_focal_loss(input, target, self.gamma, self.alpha)

-        if self.weight is not None:
+        num_of_classes = target.shape[1]
+        if self.class_weight is not None and num_of_classes != 1:
             # make sure the lengths of weights are equal to the number of classes
-            num_of_classes = target.shape[1]
-            if isinstance(self.weight, (float, int)):
-                self.class_weight = torch.as_tensor([self.weight] * num_of_classes)
+            if self.class_weight.ndim == 0:
+                self.class_weight = torch.as_tensor([self.class_weight] * num_of_classes)
             else:
-                self.class_weight = torch.as_tensor(self.weight)
+                self.class_weight = torch.as_tensor(self.class_weight)
             if self.class_weight.shape[0] != num_of_classes:
                 raise ValueError(
                     """the length of the `weight` sequence should be the same as the number of classes.
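Under the new forward logic, a 0-dim (scalar) weight is broadcast to one entry per class, while a sequence must already match target.shape[1]. A short usage sketch of those semantics with FocalLoss (the 3-class shapes and weight values here are illustrative):

import torch
from monai.losses import FocalLoss

# logits and a one-hot target with 3 classes: shape (batch, classes, H, W)
logits = torch.randn(2, 3, 32, 32)
labels = torch.randint(0, 3, (2, 32, 32))
target = torch.nn.functional.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()

# scalar weight: expanded inside forward to [0.5, 0.5, 0.5]
print(FocalLoss(weight=0.5)(logits, target))

# per-class weights: the length must equal target.shape[1], else ValueError
print(FocalLoss(weight=[1.0, 2.0, 0.5])(logits, target))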
