-
Notifications
You must be signed in to change notification settings - Fork 20
/
loss.py
664 lines (547 loc) · 20.8 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
"""
torch_ecg.models.loss
=====================
Custom loss functions for ECG analysis, as a complement to
built-in loss functions in PyTorch.
.. contents:: torch_ecg.models
:depth: 1
:local:
:backlinks: top
.. currentmodule:: torch_ecg.models.loss
.. autosummary::
:toctree: generated/
:recursive:
WeightedBCELoss
BCEWithLogitsWithClassWeightLoss
MaskedBCEWithLogitsLoss
FocalLoss
AsymmetricLoss
setup_criterion
"""
from numbers import Real
from typing import Any, Literal, Optional
import torch
import torch.nn.functional as F
from torch import Tensor, nn
__all__ = [
"WeightedBCELoss",
"BCEWithLogitsWithClassWeightLoss",
"MaskedBCEWithLogitsLoss",
"FocalLoss",
"AsymmetricLoss",
"setup_criterion",
]
def weighted_binary_cross_entropy(
sigmoid_x: Tensor,
targets: Tensor,
pos_weight: Tensor,
weight: Optional[Tensor] = None,
size_average: bool = True,
reduce: bool = True,
) -> Tensor:
"""Weighted Binary Cross Entropy Loss function.
This implementation is based on [#wbce]_.
Parameters
----------
sigmoid_x : torch.Tensor
Predicted probability of size ``[N, C]``, N sample and C Class.
Eg. Must be in range of ``[0, 1]``,
i.e. output from :class:`~torch.nn.Sigmoid`.
targets : torch.Tensor
True value, one-hot-like vector of size ``[N, C]``.
pos_weight : torch.Tensor
Weight for postive sample.
weight : torch.Tensor, optional
Weight for each class, of size ``[1, C]``.
size_average : bool, default True
If True, the losses are averaged
over each loss element in the batch.
Valid only if `reduce` is True.
reduce : bool, default True
If True, the losses are averaged or summed
over observations for each minibatch.
Returns
-------
loss : torch.Tensor
Weighted Binary Cross Entropy Loss.
References
----------
.. [#wbce] https://github.com/pytorch/pytorch/issues/5660#issuecomment-403770305
"""
if not (targets.size() == sigmoid_x.size()):
raise ValueError(f"Target size ({targets.size()}) must be the same as input size ({sigmoid_x.size()})")
loss = -pos_weight * targets * sigmoid_x.log() - (1 - targets) * (1 - sigmoid_x).log()
if weight is not None:
loss = loss * weight
if not reduce:
return loss
elif size_average:
return loss.mean()
else:
return loss.sum()
class WeightedBCELoss(nn.Module):
"""Weighted Binary Cross Entropy Loss class.
This implementation is based on [#wbce]_.
Parameters
----------
pos_weight : torch.Tensor
Weight for postive sample.
weight : torch.Tensor, optional
Weight for each class, of size ``[1, C]``.
PosWeightIsDynamic : bool, default False
If True, the pos_weight is computed on each batch.
If `pos_weight` is None, then it remains None.
WeightIsDynamic : bool, default False
If True, the weight is computed on each batch.
If `weight` is None, then it remains None.
size_average : bool, default True
If True, the losses are averaged
over each loss element in the batch.
Valid only if `reduce` is True.
reduce : bool, default True
If True, the losses are averaged or summed
over observations for each minibatch.
References
----------
.. [#wbce] https://github.com/pytorch/pytorch/issues/5660#issuecomment-403770305
"""
__name__ = "WeightedBCELoss"
def __init__(
self,
pos_weight: Tensor,
weight: Optional[Tensor] = None,
PosWeightIsDynamic: bool = False,
WeightIsDynamic: bool = False,
size_average: bool = True,
reduce: bool = True,
) -> None:
super().__init__()
self.register_buffer("pos_weight", pos_weight)
if weight is None:
weight = torch.ones_like(pos_weight)
self.register_buffer("weight", weight)
self.size_average = size_average
self.reduce = reduce
self.PosWeightIsDynamic = PosWeightIsDynamic
def forward(self, input: Tensor, target: Tensor) -> Tensor:
"""Forward pass.
Parameters
----------
input : torch.Tensor
The predicted probability tensor,
of shape ``(batch_size, ..., n_classes)``.
target : torch.Tensor
The target tensor,
of shape ``(batch_size, ..., n_classes)``.
Returns
-------
loss : torch.Tensor
The weighted binary cross entropy loss.
"""
if self.PosWeightIsDynamic:
positive_counts = target.sum(dim=0, keepdim=True)
nBatch = len(target)
self.pos_weight = (nBatch - positive_counts) / (positive_counts + 1e-7)
return weighted_binary_cross_entropy(
input,
target,
pos_weight=self.pos_weight,
weight=self.weight,
size_average=self.size_average,
reduce=self.reduce,
)
class BCEWithLogitsWithClassWeightLoss(nn.BCEWithLogitsLoss):
"""Class-weighted Binary Cross Entropy Loss class.
Parameters
----------
class_weight : torch.Tensor
Class weight, of shape ``(1, n_classes)``.
"""
__name__ = "BCEWithLogitsWithClassWeightLoss"
def __init__(self, class_weight: Tensor) -> None:
super().__init__(reduction="none")
self.register_buffer("class_weight", class_weight)
def forward(self, input: Tensor, target: Tensor) -> Tensor:
"""Forward pass.
Parameters
----------
input : torch.Tensor
The predicted value tensor (before sigmoid),
of shape ``(batch_size, ..., n_classes)``.
target : torch.Tensor
The target tensor,
of shape ``(batch_size, ..., n_classes)``.
Returns
-------
torch.Tensor
The class-weighted binary cross entropy loss.
"""
loss = super().forward(input, target)
loss = torch.mean(loss * self.class_weight)
return loss
class MaskedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
"""Masked Binary Cross Entropy Loss class.
This loss is used mainly for the segmentation task, where
there are some regions that are of much higher importance,
for example, the onsets and offsets of some particular events
(e.g. paroxysmal atrial fibrillation (AF) episodes).
This loss is proposed in [#mbce]_, with a reference to the loss
function used in the U-Net paper [#unet]_.
References
----------
.. [#mbce] Wen, Hao, and Jingsu Kang. "A comparative study on neural networks for
paroxysmal atrial fibrillation events detection from electrocardiography."
Journal of Electrocardiology 75 (2022): 19-27.
.. [#unet] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional
networks for biomedical image segmentation." International Conference on
Medical image computing and computer-assisted intervention. Springer, Cham,
2015.
"""
__name__ = "MaskedBCEWithLogitsLoss"
def __init__(self) -> None:
super().__init__(reduction="none")
def forward(self, input: Tensor, target: Tensor, weight_mask: Tensor) -> Tensor:
"""Forward pass.
Parameters
----------
input : torch.Tensor
The predicted value tensor (before sigmoid),
of shape ``(batch_size, sig_len, n_classes)``.
target : torch.Tensor
The target tensor,
of shape ``(batch_size, sig_len, n_classes)``.
weight_mask: torch.Tensor
The weight mask tensor,
of shape ``(batch_size, sig_len, n_classes)``.
Returns
-------
torch.Tensor
The masked binary cross entropy loss.
NOTE
----
`input`, `target`, and `weight_mask` should be
3-D tensors of the same shape.
"""
loss = super().forward(input, target)
loss = torch.mean(loss * weight_mask)
return loss
class FocalLoss(nn.modules.loss._WeightedLoss):
"""Focal loss class.
The focal loss is proposed in [1]_, and this implementation is
based on [2]_, [3]_, and [4]_. The focal loss is computed as follows:
.. math::
\\operatorname{FL}(p_t) = -\\alpha_t (1 - p_t)^{\\gamma} \\, \\log(p_t)
Where:
- :math:`p_t` is the model's estimated probability for each class.
Parameters
----------
gamma : float, default 2.0
The gamma parameter of focal loss.
weight : torch.Tensor, optional
If `multi_label` is True,
is a manual rescaling weight given to the loss of each batch element,
of size ``batch_size``;
if `multi_label` is False,
is a weight for each class, of size ``n_classes``.
class_weight : torch.Tensor, optional
The class weight, of shape ``(1, n_classes)``.
size_average : bool, optional
Not used, to keep in accordance with PyTorch native loss.
reduce : bool, optional
Not used, to keep in accordance with PyTorch native loss.
reduction: {"none", "mean", "sum"}, optional
Specifies the reduction to apply to the output, by default "mean".
multi_label : bool, default True
If True, the loss is computed for multi-label classification.
References
----------
.. [1] Lin, Tsung-Yi, et al. "Focal loss for dense object detection."
Proceedings of the IEEE international conference on computer vision. 2017.
.. [2] https://github.com/kornia/kornia/blob/master/kornia/losses/focal.py
.. [3] https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
.. [4] https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327
"""
__name__ = "FocalLoss"
def __init__(
self,
gamma: float = 2.0,
weight: Optional[Tensor] = None,
class_weight: Optional[Tensor] = None, # alpha
size_average: Optional[bool] = None,
reduce: Optional[bool] = None,
reduction: str = "mean",
multi_label: bool = True,
**kwargs: Any,
) -> None:
if multi_label or weight is not None:
w = weight
else:
w = class_weight
if not multi_label and w.ndim == 2:
w = w.squeeze(0)
super().__init__(weight=w, size_average=size_average, reduce=reduce, reduction=reduction)
# In practice `alpha` may be set by inverse class frequency or treated as a hyperparameter
# the `class_weight` are usually inverse class frequencies
# self.alpha = alpha
self.gamma = gamma
if multi_label:
self.entropy_func = F.binary_cross_entropy_with_logits
# for `binary_cross_entropy_with_logits`,
# its parameter `weight` is a manual rescaling weight given to the loss of each batch element
self.register_buffer("class_weight", class_weight)
else:
self.entropy_func = F.cross_entropy
# for `cross_entropy`,
# its parameter `weight` is a manual rescaling weight given to each class
self.class_weight = None
@property
def alpha(self) -> Tensor:
return self.class_weight
def forward(self, input: Tensor, target: Tensor) -> Tensor:
"""Forward pass.
Parameters
----------
input : torch.Tensor
The predicted value tensor (before sigmoid),
of shape ``(batch_size, n_classes)``.
target : torch.Tensor
Multi-label binarized vector of shape ``(batch_size, n_classes)``,
or single label binarized vector of shape ``(batch_size,)``.
Returns
-------
torch.Tensor
The focal loss.
"""
entropy = self.entropy_func(
input,
target,
weight=self.weight,
reduction="none",
)
p_t = torch.exp(-entropy)
fl = torch.pow(1 - p_t, self.gamma) * entropy
if self.class_weight is not None:
fl = fl * self.class_weight
if self.reduction == "mean":
fl = fl.mean()
elif self.reduction == "sum":
fl = fl.sum()
return fl
class AsymmetricLoss(nn.Module):
"""Asymmetric loss class.
The asymmetric loss is proposed in [#al]_, with official
implementation in [#al_code]_. The asymmetric loss is defined as
.. math::
ASL = \\begin{cases}
L_+ := (1-p)^{\\gamma_+} \\log(p) \\
L_- := (p_m)^{\\gamma_-} \\log(1-p_m)
\\end{cases}
where :math:`p_m = \\max(p-m, 0)` is the shifted probability,
with probability margin :math:`m`.
The loss on one label of one sample is
.. math::
L = -yL_+ - (1-y)L_-
Parameters
----------
gamma_neg : numbers.Real, default 4
Exponent of the multiplier to the negative loss.
gamma_pos : numbers.Real, default 1
Exponent of the multiplier to the positive loss.
prob_margin : float, default 0.05
The probability margin
disable_torch_grad_focal_loss : bool, default False
If True, disable :func:`torch.grad` for asymmetric focal loss computing.
reduction : {"none", "mean", "sum"}, default "mean"
Specifies the reduction to apply to the output.
implementation : {"alibaba-miil", "deep-psp"}, default "alibaba-miil"
Implementation by Alibaba-MIIL, or by `DeepPSP`, case insensitive.
NOTE
----
Since :class:`AsymmetricLoss` aims at emphasizing the contribution of positive samples,
`gamma_neg` is usually greater than `gamma_pos`.
TODO
----
1. Evaluate the settings that `gamma_neg`, `gamma_pos` are tensors,
of shape ``(1, n_classes)``, in which case we would have one ratio
of positive to negative for each class.
References
----------
.. [#al] Ridnik, Tal, et al. "Asymmetric Loss for Multi-Label Classification."
Proceedings of the IEEE/CVF International Conference on Computer Vision. 2021.
.. [#al_code] https://github.com/Alibaba-MIIL/ASL/
"""
__name__ = "AsymmetricLoss"
def __init__(
self,
gamma_neg: Real = 4,
gamma_pos: Real = 1,
prob_margin: float = 0.05,
disable_torch_grad_focal_loss: bool = False,
reduction: Literal["none", "mean", "sum"] = "mean",
implementation: Literal["alibaba-miil", "deep-psp"] = "alibaba-miil",
) -> None:
super().__init__()
self.implementation = implementation.lower()
assert self.implementation in [
"alibaba-miil",
"deep-psp",
"deeppsp",
]
self.gamma_neg = gamma_neg
self.gamma_pos = gamma_pos
self.prob_margin = prob_margin
if self.prob_margin < 0:
raise ValueError("`prob_margin` must be non-negative")
self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
self.eps = 1e-8
self.reduction = reduction.lower()
if self.implementation == "alibaba-miil":
self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None
elif self.implementation in [
"deep-psp",
"deeppsp",
]:
self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.loss = self.loss_pos = self.loss_neg = None
def forward(self, input: Tensor, target: Tensor) -> Tensor:
"""Forward pass.
Parameters
----------
input : torch.Tensor
The predicted value tensor,
of shape ``(batch_size, n_classes)``.
target : torch.Tensor
The target tensor,
of shape ``(batch_size, n_classes)``.
Returns
-------
torch.Tensor
The asymmetric loss.
"""
if self.implementation == "alibaba-miil":
return self._forward_alibaba_miil(input, target)
else:
return self._forward_deep_psp(input, target)
def _forward_deep_psp(self, input: Tensor, target: Tensor) -> Tensor:
"""Forward pass of DeepPSP implementation.
Parameters
----------
input : torch.Tensor
The predicted value tensor,
of shape ``(batch_size, n_classes)``.
target : torch.Tensor
The target tensor,
of shape ``(batch_size, n_classes)``.
Returns
-------
torch.Tensor
The asymmetric loss.
"""
self.targets = target
self.anti_targets = 1 - target
# Calculating Probabilities
self.xs_pos = torch.sigmoid(input)
self.xs_neg = 1.0 - self.xs_pos
# Asymmetric Clipping
if self.prob_margin > 0:
self.xs_neg.add_(self.prob_margin).clamp_(max=1)
# Basic CE calculation
self.loss_pos = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
self.loss_neg = self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps))
# Asymmetric Focusing
if self.gamma_neg > 0 or self.gamma_pos > 0:
if self.disable_torch_grad_focal_loss:
prev = torch.is_grad_enabled()
torch.set_grad_enabled(False)
self.loss_pos *= torch.pow(1 - self.xs_pos, self.gamma_pos)
self.loss_neg *= torch.pow(self.xs_pos, self.gamma_neg)
if self.disable_torch_grad_focal_loss:
torch.set_grad_enabled(prev)
self.loss = -self.loss_pos - self.loss_neg
if self.reduction == "mean":
self.loss = self.loss.mean()
elif self.reduction == "sum":
self.loss = self.loss.sum()
return self.loss
def _forward_alibaba_miil(self, input: Tensor, target: Tensor) -> Tensor:
"""Forward pass of Alibaba MIIL implementation.
Parameters
----------
input : torch.Tensor
The predicted value tensor,
of shape ``(batch_size, n_classes)``.
target : torch.Tensor
The target tensor,
of shape ``(batch_size, n_classes)``.
Returns
-------
torch.Tensor
The asymmetric loss.
"""
self.targets = target
self.anti_targets = 1 - target
# Calculating Probabilities
self.xs_pos = torch.sigmoid(input)
self.xs_neg = 1.0 - self.xs_pos
# Asymmetric Clipping
if self.prob_margin > 0:
self.xs_neg.add_(self.prob_margin).clamp_(max=1)
# Basic CE calculation
# loss = y * log(p) + (1-y) * log(1-p)
self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))
# Asymmetric Focusing
if self.gamma_neg > 0 or self.gamma_pos > 0:
if self.disable_torch_grad_focal_loss:
prev = torch.is_grad_enabled()
torch.set_grad_enabled(False)
self.xs_pos = self.xs_pos * self.targets # p * y
self.xs_neg = self.xs_neg * self.anti_targets # (1-p) * (1-y)
self.asymmetric_w = torch.pow(
1 - self.xs_pos - self.xs_neg,
self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets,
)
if self.disable_torch_grad_focal_loss:
torch.set_grad_enabled(prev)
self.loss *= self.asymmetric_w
if self.reduction == "mean":
self.loss = -self.loss.mean()
elif self.reduction == "sum":
self.loss = -self.loss.sum()
else:
self.loss = -self.loss
return self.loss
def setup_criterion(name: str, **kwargs: Any) -> nn.Module:
"""Setup the criterion (loss function).
Parameters
----------
name : str
The name of the criterion.
**kwargs : Any
The keyword arguments for the criterion.
Returns
-------
nn.Module
The criterion (loss function).
"""
if name == "WeightedBCELoss":
criterion = WeightedBCELoss(**kwargs)
elif name == "BCEWithLogitsWithClassWeightLoss":
criterion = BCEWithLogitsWithClassWeightLoss(**kwargs)
elif name == "MaskedBCEWithLogitsLoss":
criterion = MaskedBCEWithLogitsLoss(**kwargs)
elif name == "FocalLoss":
criterion = FocalLoss(**kwargs)
elif name == "AsymmetricLoss":
criterion = AsymmetricLoss(**kwargs)
elif name.startswith("nn."):
criterion = eval(name)(**kwargs)
elif name in nn.modules.loss.__all__:
criterion = getattr(nn, name)(**kwargs)
else:
raise NotImplementedError(
f"loss `{name}` not implemented! "
"Please use one of the following: `BCEWithLogitsWithClassWeightLoss`, "
"`MaskedBCEWithLogitsLoss`, `FocalLoss`, `AsymmetricLoss`, "
"or any loss function in `torch.nn`."
)
return criterion