-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
dice.py
995 lines (853 loc) · 46.3 KB
/
dice.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from collections.abc import Callable, Sequence
from typing import Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from monai.losses.focal_loss import FocalLoss
from monai.losses.spatial_mask import MaskedLoss
from monai.networks import one_hot
from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after
class DiceLoss(_Loss):
    """
    Compute average Dice loss between two tensors. It can support both multi-classes and multi-labels tasks.
    The data `input` (BNHW[D] where N is number of classes) is compared with ground truth `target` (BNHW[D]).

    Note that axis N of `input` is expected to be logits or probabilities for each class, if passing logits as input,
    must set `sigmoid=True` or `softmax=True`, or specifying `other_act`. And the same axis of `target`
    can be 1 or N (one-hot format).

    The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of
    the inter-over-union calculation to smooth results respectively, these values should be small.

    The original paper: Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric
    Medical Image Segmentation, 3DV, 2016.
    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Callable | None = None,
        squared_pred: bool = False,
        jaccard: bool = False,
        reduction: LossReduction | str = LossReduction.MEAN,
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
        batch: bool = False,
    ) -> None:
        """
        Args:
            include_background: if False, channel index 0 (background category) is excluded from the calculation.
                if the non-background segmentations are small compared to the total image size they can get overwhelmed
                by the signal from the background so excluding it in such cases helps convergence.
            to_onehot_y: whether to convert the ``target`` into the one-hot format,
                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
            sigmoid: if True, apply a sigmoid function to the prediction.
            softmax: if True, apply a softmax function to the prediction.
            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
                ``other_act = torch.tanh``.
            squared_pred: use squared versions of targets and predictions in the denominator or not.
            jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction over batch/channels; the per-channel losses are reshaped with
                  trailing singleton axes so that they stay broadcastable against ``input``.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            smooth_nr: a small constant added to the numerator to avoid zero.
            smooth_dr: a small constant added to the denominator to avoid nan.
            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                Defaults to False, a Dice loss value is computed independently from each item in the batch
                before any `reduction`.

        Raises:
            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
                Incompatible values.
        """
        super().__init__(reduction=LossReduction(reduction).value)
        if other_act is not None and not callable(other_act):
            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
            raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.other_act = other_act
        self.squared_pred = squared_pred
        self.jaccard = jaccard
        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)
        self.batch = batch

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

        Raises:
            AssertionError: When input and target (after one hot transform if set)
                have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        Example:
            >>> from monai.losses.dice import *  # NOQA
            >>> import torch
            >>> from monai.losses.dice import DiceLoss
            >>> B, C, H, W = 7, 5, 3, 2
            >>> input = torch.rand(B, C, H, W)
            >>> target_idx = torch.randint(low=0, high=C - 1, size=(B, H, W)).long()
            >>> target = one_hot(target_idx[:, None, ...], num_classes=C)
            >>> self = DiceLoss(reduction='none')
            >>> loss = self(input, target)
            >>> assert np.broadcast_shapes(loss.shape, input.shape) == input.shape
        """
        if self.sigmoid:
            input = torch.sigmoid(input)

        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

        # reducing only spatial dimensions (not batch nor channels); a plain range avoids allocating a tensor
        reduce_axis: list[int] = list(range(2, len(input.shape)))
        if self.batch:
            # reducing spatial dimensions and batch
            reduce_axis = [0] + reduce_axis

        intersection = torch.sum(target * input, dim=reduce_axis)

        if self.squared_pred:
            ground_o = torch.sum(target**2, dim=reduce_axis)
            pred_o = torch.sum(input**2, dim=reduce_axis)
        else:
            ground_o = torch.sum(target, dim=reduce_axis)
            pred_o = torch.sum(input, dim=reduce_axis)

        denominator = ground_o + pred_o

        if self.jaccard:
            # (ground + pred - intersection) is the soft union; doubling it lets the shared
            # "2 * intersection" numerator below evaluate to intersection / union (IoU)
            denominator = 2.0 * (denominator - intersection)

        f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)

        if self.reduction == LossReduction.MEAN.value:
            f = torch.mean(f)  # the batch and channel average
        elif self.reduction == LossReduction.SUM.value:
            f = torch.sum(f)  # sum over the batch and channel dims
        elif self.reduction == LossReduction.NONE.value:
            # If we are not computing voxelwise loss components at least
            # make sure a none reduction maintains a broadcastable shape
            broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2)
            f = f.view(broadcast_shape)
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

        return f
class MaskedDiceLoss(DiceLoss):
    """
    A `DiceLoss` that is restricted to a spatial region given by a binary mask.

    The mask takes values in [0, 1]. Wherever it is `1`, the values of `input` and `target` are kept;
    wherever it is `0`, they are set to `0`. The masked tensors are then fed to the regular `DiceLoss`
    computation. As a result, only the masked region contributes to the loss value and therefore to the gradients.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """
        Args follow :py:class:`monai.losses.DiceLoss`.
        """
        super().__init__(*args, **kwargs)
        # the parent's (unmasked) Dice computation is wrapped, so masking happens before it is evaluated
        dice_forward = super().forward
        self.spatial_weighted = MaskedLoss(loss=dice_forward)

    def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD].
            mask: the shape should be B1H[WD] or 11H[WD].
        """
        masked_dice = self.spatial_weighted
        return masked_dice(input=input, target=target, mask=mask)  # type: ignore[no-any-return]
class GeneralizedDiceLoss(_Loss):
    """
    Compute the generalised Dice loss defined in:

        Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning
        loss function for highly unbalanced segmentations. DLMIA 2017.

    Adapted from:
        https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L279
    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Callable | None = None,
        w_type: Weight | str = Weight.SQUARE,
        reduction: LossReduction | str = LossReduction.MEAN,
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
        batch: bool = False,
    ) -> None:
        """
        Args:
            include_background: If False channel index 0 (background category) is excluded from the calculation.
            to_onehot_y: whether to convert the ``target`` into the one-hot format,
                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
            sigmoid: If True, apply a sigmoid function to the prediction.
            softmax: If True, apply a softmax function to the prediction.
            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
                ``other_act = torch.tanh``.
            w_type: {``"square"``, ``"simple"``, ``"uniform"``}
                Type of function to transform ground truth volume to a weight factor. Defaults to ``"square"``.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied; the result is reshaped with trailing singleton
                  axes so that it stays broadcastable against ``input``.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            smooth_nr: a small constant added to the numerator to avoid zero.
            smooth_dr: a small constant added to the denominator to avoid nan.
            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                Defaults to False, intersection over union is computed from each item in the batch.
                If True, the class-weighted intersection and union areas are first summed across the batches.

        Raises:
            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
                Incompatible values.
        """
        super().__init__(reduction=LossReduction(reduction).value)
        if other_act is not None and not callable(other_act):
            raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
        if int(sigmoid) + int(softmax) + int(other_act is not None) > 1:
            raise ValueError("Incompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].")
        self.include_background = include_background
        self.to_onehot_y = to_onehot_y
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.other_act = other_act
        self.w_type = look_up_option(w_type, Weight)
        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)
        self.batch = batch

    def w_func(self, grnd):
        """
        Map per-class ground-truth volumes ``grnd`` to class weights according to ``self.w_type``.

        ``"simple"`` gives ``1 / grnd``, ``"square"`` gives ``1 / grnd**2``, and any other type (``"uniform"``)
        gives a tensor of ones. For the first two types, a class with zero volume gets an ``inf`` weight,
        which ``forward`` replaces afterwards.
        """
        if self.w_type == str(Weight.SIMPLE):
            return torch.reciprocal(grnd)
        if self.w_type == str(Weight.SQUARE):
            return torch.reciprocal(grnd * grnd)
        return torch.ones_like(grnd)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD], or B1H[WD] when ``to_onehot_y=True``.

        Raises:
            AssertionError: When input and target (after one hot transform if set)
                have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
        """
        if self.sigmoid:
            input = torch.sigmoid(input)
        n_pred_ch = input.shape[1]
        if self.softmax:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `softmax=True` ignored.")
            else:
                input = torch.softmax(input, 1)

        if self.other_act is not None:
            input = self.other_act(input)

        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        if not self.include_background:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `include_background=False` ignored.")
            else:
                # if skipping background, removing first channel
                target = target[:, 1:]
                input = input[:, 1:]

        if target.shape != input.shape:
            raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

        # reducing only spatial dimensions (not batch nor channels); a plain range avoids allocating a tensor
        reduce_axis: list[int] = list(range(2, len(input.shape)))
        if self.batch:
            reduce_axis = [0] + reduce_axis
        intersection = torch.sum(target * input, reduce_axis)

        ground_o = torch.sum(target, reduce_axis)
        pred_o = torch.sum(input, reduce_axis)

        denominator = ground_o + pred_o

        w = self.w_func(ground_o.float())
        # classes absent from the ground truth receive an infinite weight from w_func; replace it with the
        # largest finite weight (over all classes when batch=True, per batch item otherwise)
        infs = torch.isinf(w)
        if self.batch:
            w[infs] = 0.0
            w = w + infs * torch.max(w)
        else:
            w[infs] = 0.0
            max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1)
            w = w + infs * max_values

        # the class-weighted terms are summed over the channel axis, which is axis 0 once the batch axis is reduced
        final_reduce_dim = 0 if self.batch else 1
        numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
        denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
        f: torch.Tensor = 1.0 - (numer / denom)

        if self.reduction == LossReduction.MEAN.value:
            f = torch.mean(f)  # the batch and channel average
        elif self.reduction == LossReduction.SUM.value:
            f = torch.sum(f)  # sum over the batch and channel dims
        elif self.reduction == LossReduction.NONE.value:
            # If we are not computing voxelwise loss components at least
            # make sure a none reduction maintains a broadcastable shape
            broadcast_shape = list(f.shape[0:2]) + [1] * (len(input.shape) - 2)
            f = f.view(broadcast_shape)
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

        return f
class GeneralizedWassersteinDiceLoss(_Loss):
    """
    Compute the generalized Wasserstein Dice Loss defined in:

        Fidon L. et al. (2017) Generalised Wasserstein Dice Score for Imbalanced Multi-class
        Segmentation using Holistic Convolutional Networks. BrainLes 2017.

    Or its variant (use the option weighting_mode="GDL") defined in the Appendix of:

        Tilborghs, S. et al. (2020) Comparative study of deep learning methods for the automatic
        segmentation of lung, lesion and lesion type in CT scans of COVID-19 patients.
        arXiv preprint arXiv:2007.15546

    Adapted from:
        https://github.com/LucasFidon/GeneralizedWassersteinDiceLoss
    """

    def __init__(
        self,
        dist_matrix: np.ndarray | torch.Tensor,
        weighting_mode: str = "default",
        reduction: LossReduction | str = LossReduction.MEAN,
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
    ) -> None:
        """
        Args:
            dist_matrix: 2d tensor or 2d numpy array; matrix of distances between the classes.
                It must have dimension C x C where C is the number of classes.
                If its maximum is not 1, it is divided by its maximum.
            weighting_mode: {``"default"``, ``"GDL"``}
                Specifies how to weight the class-specific sum of errors.
                Default to ``"default"``.

                - ``"default"``: (recommended) use the original weighting method as in:
                    Fidon L. et al. (2017) Generalised Wasserstein Dice Score for Imbalanced Multi-class
                    Segmentation using Holistic Convolutional Networks. BrainLes 2017.
                - ``"GDL"``: use a GDL-like weighting method as in the Appendix of:
                    Tilborghs, S. et al. (2020) Comparative study of deep learning methods for the automatic
                    segmentation of lung, lesion and lesion type in CT scans of COVID-19 patients.
                    arXiv preprint arXiv:2007.15546
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction; one loss value per batch item, shaped B1[1...] so that it
                  stays broadcastable against ``input``.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            smooth_nr: a small constant added to the numerator to avoid zero.
            smooth_dr: a small constant added to the denominator to avoid nan.

        Raises:
            ValueError: When ``dist_matrix`` is not a 2D square matrix.
            ValueError: When ``weighting_mode`` is not one of ``"default"``, ``"GDL"``.

        Example:
            .. code-block:: python

                import torch
                import numpy as np
                from monai.losses import GeneralizedWassersteinDiceLoss

                # Example with 3 classes (including the background: label 0).
                # The distance between the background class (label 0) and the other classes is the maximum, equal to 1.
                # The distance between class 1 and class 2 is 0.5.
                dist_mat = np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.5], [1.0, 0.5, 0.0]], dtype=np.float32)
                wass_loss = GeneralizedWassersteinDiceLoss(dist_matrix=dist_mat)

                pred_score = torch.tensor([[1000, 0, 0], [0, 1000, 0], [0, 0, 1000]], dtype=torch.float32)
                grnd = torch.tensor([0, 1, 2], dtype=torch.int64)
                wass_loss(pred_score, grnd)  # 0
        """
        super().__init__(reduction=LossReduction(reduction).value)

        if dist_matrix.ndim != 2:
            raise ValueError(f"dist_matrix must be a 2D C x C matrix, got {dist_matrix.ndim} dimensions.")
        if dist_matrix.shape[0] != dist_matrix.shape[1]:
            raise ValueError(f"dist_matrix must be C x C, got {dist_matrix.shape[0]} x {dist_matrix.shape[1]}.")

        if weighting_mode not in ["default", "GDL"]:
            raise ValueError(f"weighting_mode must be either 'default' or 'GDL', got {weighting_mode}.")

        self.m = dist_matrix
        if isinstance(self.m, np.ndarray):
            self.m = torch.from_numpy(self.m)
        if torch.max(self.m) != 1:
            self.m = self.m / torch.max(self.m)
        self.alpha_mode = weighting_mode
        self.num_classes = self.m.size(0)
        self.smooth_nr = float(smooth_nr)
        self.smooth_dr = float(smooth_dr)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD]; unnormalized scores (a softmax over axis 1 is applied here).
            target: class-index label map (not one-hot), e.g. of shape BH[WD] or B1H[WD];
                it is flattened per batch item and its values are used as class indices.

        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
        """
        # Aggregate spatial dimensions: (B, C, N) scores and (B, N) integer labels
        flat_input = input.reshape(input.size(0), input.size(1), -1)
        flat_target = target.reshape(target.size(0), -1).long()

        # Apply the softmax to the input scores map
        probs = F.softmax(flat_input, dim=1)

        # Compute the Wasserstein distance map, shape (B, N)
        wass_dist_map = self.wasserstein_distance_map(probs, flat_target)

        # Compute the values of alpha to use
        alpha = self._compute_alpha_generalized_true_positives(flat_target)

        # Compute the generalised number of true positives (same formula for both weighting modes)
        true_pos = self._compute_generalized_true_positive(alpha, flat_target, wass_dist_map)

        if self.alpha_mode == "GDL":
            # use GDL-style alpha weights (i.e. normalize by the volume of each class)
            # contrary to the original definition we also use alpha in the "generalized all error".
            denom = self._compute_denominator(alpha, flat_target, wass_dist_map)
        else:  # default: as in the original paper
            # (i.e. alpha=1 for all foreground classes and 0 for the background).
            all_error = torch.sum(wass_dist_map, dim=1)
            denom = 2 * true_pos + all_error

        # Compute the final loss, one value per batch item
        wass_dice: torch.Tensor = (2.0 * true_pos + self.smooth_nr) / (denom + self.smooth_dr)
        wass_dice_loss: torch.Tensor = 1.0 - wass_dice

        if self.reduction == LossReduction.MEAN.value:
            wass_dice_loss = torch.mean(wass_dice_loss)  # the batch average
        elif self.reduction == LossReduction.SUM.value:
            wass_dice_loss = torch.sum(wass_dice_loss)  # sum over the batch
        elif self.reduction == LossReduction.NONE.value:
            # classes are aggregated inside the score, so there is a single value per batch item:
            # keep the batch axis and use singleton axes elsewhere to stay broadcastable against `input`
            broadcast_shape = (input.shape[0],) + (1,) * (len(input.shape) - 1)
            wass_dice_loss = wass_dice_loss.view(broadcast_shape)
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

        return wass_dice_loss

    def wasserstein_distance_map(self, flat_proba: torch.Tensor, flat_target: torch.Tensor) -> torch.Tensor:
        """
        Compute the voxel-wise Wasserstein distance between the
        flattened prediction and the flattened labels (ground_truth) with respect
        to the distance matrix on the label space M.
        This corresponds to eq. 6 in:

            Fidon L. et al. (2017) Generalised Wasserstein Dice Score for Imbalanced Multi-class
            Segmentation using Holistic Convolutional Networks. BrainLes 2017.

        Args:
            flat_proba: the probabilities of input(predicted) tensor, shape (B, C, N) as produced by ``forward``.
            flat_target: the integer target tensor, shape (B, N) as produced by ``forward``.

        Returns:
            the distance map, shape (B, N).
        """
        # Turn the distance matrix to a map of identical matrix
        m = torch.clone(torch.as_tensor(self.m)).to(flat_proba.device)
        m_extended = torch.unsqueeze(m, dim=0)
        m_extended = torch.unsqueeze(m_extended, dim=3)
        m_extended = m_extended.expand((flat_proba.size(0), m_extended.size(1), m_extended.size(2), flat_proba.size(2)))

        # Expand the feature dimensions of the target
        flat_target_extended = torch.unsqueeze(flat_target, dim=1)
        flat_target_extended = flat_target_extended.expand(
            (flat_target.size(0), m_extended.size(1), flat_target.size(1))
        )
        flat_target_extended = torch.unsqueeze(flat_target_extended, dim=1)

        # Extract the vector of class distances for the ground-truth label at each voxel
        m_extended = torch.gather(m_extended, dim=1, index=flat_target_extended)
        m_extended = torch.squeeze(m_extended, dim=1)

        # Compute the wasserstein distance map
        wasserstein_map = m_extended * flat_proba

        # Sum over the classes
        wasserstein_map = torch.sum(wasserstein_map, dim=1)
        return wasserstein_map

    def _gather_alpha_per_voxel(self, alpha: torch.Tensor, flat_target: torch.Tensor) -> torch.Tensor:
        """
        Select, for every voxel, the alpha weight of its ground-truth class.

        Args:
            alpha: per-class weights, shape (B, C).
            flat_target: the integer target tensor, shape (B, N).

        Returns:
            the per-voxel weights, shape (B, N).
        """
        alpha_extended = torch.unsqueeze(alpha, dim=2)
        alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
        flat_target_extended = torch.unsqueeze(flat_target, dim=1)
        # gather yields (B, 1, N); drop the class axis so it lines up with the (B, N) distance map.
        # Keeping it would broadcast (B, 1, N) * (B, N) into (B, B, N) and mix batch items together.
        return torch.gather(alpha_extended, index=flat_target_extended, dim=1).squeeze(dim=1)

    def _compute_generalized_true_positive(
        self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            alpha: generalised number of true positives of target class, shape (B, C).
            flat_target: the target tensor, shape (B, N).
            wasserstein_distance_map: the map obtained from the above function, shape (B, N).

        Returns:
            the generalised number of true positives for each batch item, shape (B,).
        """
        alpha_per_voxel = self._gather_alpha_per_voxel(alpha, flat_target)
        return torch.sum(alpha_per_voxel * (1.0 - wasserstein_distance_map), dim=1)

    def _compute_denominator(
        self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            alpha: generalised number of true positives of target class, shape (B, C).
            flat_target: the target tensor, shape (B, N).
            wasserstein_distance_map: the map obtained from the above function, shape (B, N).

        Returns:
            the GDL-style denominator for each batch item, shape (B,).
        """
        alpha_per_voxel = self._gather_alpha_per_voxel(alpha, flat_target)
        return torch.sum(alpha_per_voxel * (2.0 - wasserstein_distance_map), dim=1)

    def _compute_alpha_generalized_true_positives(self, flat_target: torch.Tensor) -> torch.Tensor:
        """
        Args:
            flat_target: the target tensor, shape (B, N).

        Returns:
            the per-class alpha weights, shape (B, C), as float32.
        """
        if self.alpha_mode == "GDL":  # GDL style
            # Define alpha like in the generalized dice loss
            # i.e. the inverse of the volume of each class.
            one_hot_f = F.one_hot(flat_target, num_classes=self.num_classes).permute(0, 2, 1).float()
            volumes = torch.sum(one_hot_f, dim=2)
            return 1.0 / (volumes + 1.0)

        # default, i.e. like in the original paper
        # alpha weights are 0 for the background and 1 the other classes
        alpha: torch.Tensor = torch.ones(
            (flat_target.size(0), self.num_classes), dtype=torch.float32, device=flat_target.device
        )
        alpha[:, 0] = 0.0
        return alpha
class DiceCELoss(_Loss):
"""
Compute both Dice loss and Cross Entropy Loss, and return the weighted sum of these two losses.
The details of Dice loss is shown in ``monai.losses.DiceLoss``.
The details of Cross Entropy Loss is shown in ``torch.nn.CrossEntropyLoss`` and ``torch.nn.BCEWithLogitsLoss()``.
In this implementation, two deprecated parameters ``size_average`` and ``reduce``, and the parameter ``ignore_index`` are
not supported.
"""
def __init__(
self,
include_background: bool = True,
to_onehot_y: bool = False,
sigmoid: bool = False,
softmax: bool = False,
other_act: Callable | None = None,
squared_pred: bool = False,
jaccard: bool = False,
reduction: str = "mean",
smooth_nr: float = 1e-5,
smooth_dr: float = 1e-5,
batch: bool = False,
ce_weight: torch.Tensor | None = None,
lambda_dice: float = 1.0,
lambda_ce: float = 1.0,
) -> None:
"""
Args:
``ce_weight`` and ``lambda_ce`` are only used for cross entropy loss.
``reduction`` is used for both losses and other parameters are only used for dice loss.
include_background: if False channel index 0 (background category) is excluded from the calculation.
to_onehot_y: whether to convert the ``target`` into the one-hot format,
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
don't need to specify activation function for `CrossEntropyLoss` and `BCEWithLogitsLoss`.
softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
don't need to specify activation function for `CrossEntropyLoss` and `BCEWithLogitsLoss`.
other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
``other_act = torch.tanh``. only used by the `DiceLoss`, not for the `CrossEntropyLoss` and `BCEWithLogitsLoss`.
squared_pred: use squared versions of targets and predictions in the denominator or not.
jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
reduction: {``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``. The dice loss should
as least reduce the spatial dimensions, which is different from cross entropy loss, thus here
the ``none`` option cannot be used.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.
smooth_nr: a small constant added to the numerator to avoid zero.
smooth_dr: a small constant added to the denominator to avoid nan.
batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
Defaults to False, a Dice loss value is computed independently from each item in the batch
before any `reduction`.
ce_weight: a rescaling weight given to each class for cross entropy loss for `CrossEntropyLoss`.
or a rescaling weight given to the loss of each batch element for `BCEWithLogitsLoss`.
See ``torch.nn.CrossEntropyLoss()`` or ``torch.nn.BCEWithLogitsLoss()`` for more information.
lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
Defaults to 1.0.
lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
Defaults to 1.0.
"""
super().__init__()
reduction = look_up_option(reduction, DiceCEReduction).value
self.dice = DiceLoss(
include_background=include_background,
to_onehot_y=to_onehot_y,
sigmoid=sigmoid,
softmax=softmax,
other_act=other_act,
squared_pred=squared_pred,
jaccard=jaccard,
reduction=reduction,
smooth_nr=smooth_nr,
smooth_dr=smooth_dr,
batch=batch,
)
self.cross_entropy = nn.CrossEntropyLoss(weight=ce_weight, reduction=reduction)
self.binary_cross_entropy = nn.BCEWithLogitsLoss(weight=ce_weight, reduction=reduction)
if lambda_dice < 0.0:
raise ValueError("lambda_dice should be no less than 0.0.")
if lambda_ce < 0.0:
raise ValueError("lambda_ce should be no less than 0.0.")
self.lambda_dice = lambda_dice
self.lambda_ce = lambda_ce
self.old_pt_ver = not pytorch_after(1, 10)
def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Compute CrossEntropy loss for the input logits and target.
Will remove the channel dim according to PyTorch CrossEntropyLoss:
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss.
"""
n_pred_ch, n_target_ch = input.shape[1], target.shape[1]
if n_pred_ch != n_target_ch and n_target_ch == 1:
target = torch.squeeze(target, dim=1)
target = target.long()
elif self.old_pt_ver:
warnings.warn(
f"Multichannel targets are not supported in this older Pytorch version {torch.__version__}. "
"Using argmax (as a workaround) to convert target to a single channel."
)
target = torch.argmax(target, dim=1)
elif not torch.is_floating_point(target):
target = target.to(dtype=input.dtype)
return self.cross_entropy(input, target) # type: ignore[no-any-return]
def bce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Binary cross-entropy (with logits) term used when the prediction has a single channel.

    A target that is not floating point is first cast to ``input.dtype`` so that
    ``self.binary_cross_entropy`` receives a floating-point target.
    """
    float_target = target if torch.is_floating_point(target) else target.to(dtype=input.dtype)
    return self.binary_cross_entropy(input, float_target)  # type: ignore[no-any-return]
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Compute ``lambda_dice * dice_loss + lambda_ce * ce_loss``.

    The cross-entropy term uses ``self.ce`` (``CrossEntropyLoss``) when ``input`` has
    more than one channel and ``self.bce`` (``BCEWithLogitsLoss``) when it has exactly one.

    Args:
        input: the shape should be BNH[WD].
        target: the shape should be BNH[WD] or B1H[WD].

    Raises:
        ValueError: When number of dimensions for input and target are different.
        ValueError: When number of channels for target is neither 1 nor the same as input.
    """
    if input.dim() != target.dim():
        raise ValueError(
            "the number of dimensions for input and target should be the same, "
            f"got shape {input.shape} and {target.shape}."
        )
    n_pred_ch, n_target_ch = input.shape[1], target.shape[1]
    # Enforce the channel contract documented above up front, so a mismatch fails
    # here with a clear ValueError.
    if n_target_ch not in (1, n_pred_ch):
        raise ValueError(
            "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
            f"got shape {input.shape} and {target.shape}."
        )

    dice_loss = self.dice(input, target)
    ce_loss = self.ce(input, target) if n_pred_ch != 1 else self.bce(input, target)
    total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss
    return total_loss
class DiceFocalLoss(_Loss):
    """
    Compute both Dice loss and Focal Loss, and return the weighted sum of these two losses.
    The details of Dice loss are shown in ``monai.losses.DiceLoss``.
    The details of Focal Loss are shown in ``monai.losses.FocalLoss``.

    ``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for the focal loss.
    ``include_background``, ``to_onehot_y`` and ``reduction`` are used for both losses
    (``to_onehot_y`` is applied once to the target before either loss is computed),
    and other parameters are only used for dice loss.
    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Callable | None = None,
        squared_pred: bool = False,
        jaccard: bool = False,
        reduction: str = "mean",
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
        batch: bool = False,
        gamma: float = 2.0,
        focal_weight: Sequence[float] | float | int | torch.Tensor | None = None,
        lambda_dice: float = 1.0,
        lambda_focal: float = 1.0,
    ) -> None:
        """
        Args:
            include_background: if False channel index 0 (background category) is excluded from the calculation.
            to_onehot_y: whether to convert the ``target`` into the one-hot format,
                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
                The conversion is done in ``forward`` and the converted target is used by both losses.
            sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
                don't need to specify activation function for `FocalLoss`.
            softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
                don't need to specify activation function for `FocalLoss`.
            other_act: callable function to execute other activation layers, Defaults to ``None``.
                for example: `other_act = torch.tanh`. only used by the `DiceLoss`, not for `FocalLoss`.
            squared_pred: use squared versions of targets and predictions in the denominator or not.
            jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            smooth_nr: a small constant added to the numerator to avoid zero.
            smooth_dr: a small constant added to the denominator to avoid nan.
            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                Defaults to False, a Dice loss value is computed independently from each item in the batch
                before any `reduction`.
            gamma: value of the exponent gamma in the definition of the Focal loss.
            focal_weight: weights to apply to the voxels of each class. If None no weights are applied.
                The input can be a single value (same weight for all classes), a sequence of values (the length
                of the sequence should be the same as the number of classes).
            lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
                Defaults to 1.0.
            lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0.
                Defaults to 1.0.

        Raises:
            ValueError: When ``lambda_dice`` or ``lambda_focal`` is less than 0.0.
        """
        super().__init__()
        # Both sub-losses get ``to_onehot_y=False``: the one-hot conversion is done once
        # in ``forward`` and the converted target is shared by the two losses.
        self.dice = DiceLoss(
            include_background=include_background,
            to_onehot_y=False,
            sigmoid=sigmoid,
            softmax=softmax,
            other_act=other_act,
            squared_pred=squared_pred,
            jaccard=jaccard,
            reduction=reduction,
            smooth_nr=smooth_nr,
            smooth_dr=smooth_dr,
            batch=batch,
        )
        self.focal = FocalLoss(
            include_background=include_background,
            to_onehot_y=False,
            gamma=gamma,
            weight=focal_weight,
            reduction=reduction,
        )
        if lambda_dice < 0.0:
            raise ValueError("lambda_dice should be no less than 0.0.")
        if lambda_focal < 0.0:
            raise ValueError("lambda_focal should be no less than 0.0.")
        self.lambda_dice = lambda_dice
        self.lambda_focal = lambda_focal
        self.to_onehot_y = to_onehot_y

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute ``lambda_dice * dice_loss + lambda_focal * focal_loss``.

        Args:
            input: the shape should be BNH[WD]. The input should be the original logits
                due to the restriction of ``monai.losses.FocalLoss``.
            target: the shape should be BNH[WD] or B1H[WD].

        Raises:
            ValueError: When number of dimensions for input and target are different.
            ValueError: When number of channels for target is neither 1 nor the same as input.
        """
        if input.dim() != target.dim():
            raise ValueError(
                "the number of dimensions for input and target should be the same, "
                f"got shape {input.shape} and {target.shape}."
            )
        n_pred_ch = input.shape[1]
        # Enforce the documented channel contract before any target conversion.
        if target.shape[1] not in (1, n_pred_ch):
            raise ValueError(
                "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
                f"got shape {input.shape} and {target.shape}."
            )
        if self.to_onehot_y:
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)

        dice_loss = self.dice(input, target)
        focal_loss = self.focal(input, target)
        total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss
        return total_loss
class GeneralizedDiceFocalLoss(torch.nn.modules.loss._Loss):
    """Compute both Generalized Dice Loss and Focal Loss, and return their weighted sum. The details of Generalized Dice Loss
    and Focal Loss are available at ``monai.losses.GeneralizedDiceLoss`` and ``monai.losses.FocalLoss``.

    Args:
        include_background (bool, optional): if False channel index 0 (background category) is excluded from the calculation.
            Defaults to True.
        to_onehot_y: whether to convert the ``target`` into the one-hot format,
            using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
        sigmoid (bool, optional): if True, apply a sigmoid function to the prediction. Defaults to False.
        softmax (bool, optional): if True, apply a softmax function to the prediction. Defaults to False.
        other_act (Optional[Callable], optional): callable function to execute other activation layers,
            Defaults to ``None``. for example: `other_act = torch.tanh`.
            only used by the `GeneralizedDiceLoss`, not for the `FocalLoss`.
        w_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
            ground-truth volume to a weight factor. Defaults to ``"square"``.
        reduction (Union[LossReduction, str], optional): {``"none"``, ``"mean"``, ``"sum"``}. Specifies the reduction to
            apply to the output. Defaults to ``"mean"``.

            - ``"none"``: no reduction will be applied.
            - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
            - ``"sum"``: the output will be summed.

        smooth_nr (float, optional): a small constant added to the numerator to avoid zero. Defaults to 1e-5.
        smooth_dr (float, optional): a small constant added to the denominator to avoid nan. Defaults to 1e-5.
        batch (bool, optional): whether to sum the intersection and union areas over the batch dimension before the dividing.
            Defaults to False, i.e., the areas are computed for each item in the batch.
        gamma (float, optional): value of the exponent gamma in the definition of the Focal loss. Defaults to 2.0.
        focal_weight (Optional[Union[Sequence[float], float, int, torch.Tensor]], optional): weights to apply to
            the voxels of each class. If None no weights are applied. The input can be a single value
            (same weight for all classes), a sequence of values (the length of the sequence should be the same as
            the number of classes). Defaults to None.
        lambda_gdl (float, optional): the trade-off weight value for Generalized Dice Loss. The value should be
            no less than 0.0. Defaults to 1.0.
        lambda_focal (float, optional): the trade-off weight value for Focal Loss. The value should be no less
            than 0.0. Defaults to 1.0.

    Raises:
        ValueError: if either `lambda_gdl` or `lambda_focal` is less than 0.
    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Callable | None = None,
        w_type: Weight | str = Weight.SQUARE,
        reduction: LossReduction | str = LossReduction.MEAN,
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
        batch: bool = False,
        gamma: float = 2.0,
        focal_weight: Sequence[float] | float | int | torch.Tensor | None = None,
        lambda_gdl: float = 1.0,
        lambda_focal: float = 1.0,
    ) -> None:
        super().__init__()
        # Unlike DiceFocalLoss, ``to_onehot_y`` is forwarded to each sub-loss here,
        # so each one converts the target itself.
        self.generalized_dice = GeneralizedDiceLoss(
            include_background=include_background,
            to_onehot_y=to_onehot_y,
            sigmoid=sigmoid,
            softmax=softmax,
            other_act=other_act,
            w_type=w_type,
            reduction=reduction,
            smooth_nr=smooth_nr,
            smooth_dr=smooth_dr,
            batch=batch,
        )
        self.focal = FocalLoss(
            include_background=include_background,
            to_onehot_y=to_onehot_y,
            gamma=gamma,
            weight=focal_weight,
            reduction=reduction,
        )
        if lambda_gdl < 0.0:
            raise ValueError("lambda_gdl should be no less than 0.0.")
        if lambda_focal < 0.0:
            raise ValueError("lambda_focal should be no less than 0.0.")
        self.lambda_gdl = lambda_gdl
        self.lambda_focal = lambda_focal

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute ``lambda_gdl * gdl_loss + lambda_focal * focal_loss``.

        Args:
            input (torch.Tensor): the shape should be BNH[WD]. The input should be the original logits
                due to the restriction of ``monai.losses.FocalLoss``.
            target (torch.Tensor): the shape should be BNH[WD] or B1H[WD].

        Raises:
            ValueError: When the input and target tensors have different numbers of dimensions, or the target
                channel isn't either one-hot encoded or categorical with the same shape of the input.

        Returns:
            torch.Tensor: value of the loss.
        """
        if input.dim() != target.dim():
            raise ValueError(
                f"Input - {input.shape} - and target - {target.shape} - must have the same number of dimensions."
            )
        # Enforce the documented channel contract: categorical (1 channel) or matching the input.
        if target.shape[1] not in (1, input.shape[1]):
            raise ValueError(
                "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
                f"got shape {input.shape} and {target.shape}."
            )

        gdl_loss = self.generalized_dice(input, target)
        focal_loss = self.focal(input, target)
        total_loss: torch.Tensor = self.lambda_gdl * gdl_loss + self.lambda_focal * focal_loss
        return total_loss
# Shorter alias names for the loss classes.
Dice = DiceLoss
generalized_dice = GeneralizedDiceLoss
generalized_wasserstein_dice = GeneralizedWassersteinDiceLoss
dice_ce = DiceCELoss
dice_focal = DiceFocalLoss
generalized_dice_focal = GeneralizedDiceFocalLoss