From 5db9e58b8b6cf0e2323e938fc8ef125df3d8a5d0 Mon Sep 17 00:00:00 2001 From: Qingpeng Li Date: Wed, 24 May 2023 03:26:11 +0800 Subject: [PATCH] improve algorithm --- monai/losses/focal_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index 039cc8b05f..5c9bfd5cd1 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -212,8 +212,8 @@ def softmax_focal_loss( where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and s_j is the unnormalized score for class j. """ - pt = input.softmax(1) - loss: torch.Tensor = - (1 - pt).pow(gamma) * input.log_softmax(1) * target + input_ls = input.log_softmax(1) + loss: torch.Tensor = - (1 - input_ls.exp()).pow(gamma) * input_ls * target if alpha is not None: # (1-alpha) for the background class and alpha for the other classes