Skip to content

Commit

Permalink
improve algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
qingpeng9802 committed May 23, 2023
1 parent a97cf1f commit 5db9e58
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions monai/losses/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ def softmax_focal_loss(
where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
s_j is the unnormalized score for class j.
"""
pt = input.softmax(1)
loss: torch.Tensor = - (1 - pt).pow(gamma) * input.log_softmax(1) * target
input_ls = input.log_softmax(1)
loss: torch.Tensor = - (1 - input_ls.exp()).pow(gamma) * input_ls * target

if alpha is not None:
# (1-alpha) for the background class and alpha for the other classes
Expand Down

0 comments on commit 5db9e58

Please sign in to comment.