-
Notifications
You must be signed in to change notification settings - Fork 0
/
Metrics.py
85 lines (66 loc) · 2.39 KB
/
Metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from typing import Iterable
import torch
from sklearn.metrics import confusion_matrix
class Metric:
    """Abstract base for streaming metrics.

    Subclasses accumulate observations via ``__call__`` and report the
    aggregated value via ``get_metric``.
    """

    def __call__(self, *args, **kwargs):
        """Accumulate one batch of observations; subclasses must override."""
        raise NotImplementedError

    def get_metric(self, reset: bool = False) -> float:
        """Return the value accumulated so far; subclasses must override."""
        raise NotImplementedError

    @staticmethod
    def detach_tensors(*tensors: torch.Tensor) -> Iterable[torch.Tensor]:
        """Yield each argument detached from the autograd graph.

        Non-tensor arguments pass through unchanged, so callers may mix
        tensors and plain Python values.
        """
        return (t.detach() if isinstance(t, torch.Tensor) else t for t in tensors)
class Average(Metric):
    """Streaming arithmetic mean of scalar values (e.g. per-batch losses)."""

    def __init__(self):
        # Number of values observed and their running sum (plain Python numbers).
        self._count = 0
        self._sum = 0

    def __call__(self, value: torch.Tensor):
        """Record one scalar value.

        Tensors are detached and converted to a Python float so the running
        sum does not retain the autograd graph or device memory, and so
        ``get_metric`` returns a float as the base-class contract declares
        (previously it returned a 0-dim tensor when fed tensors).
        """
        (value,) = self.detach_tensors(value)
        self._count += 1
        # .item() assumes a single-element tensor — Average is a scalar metric.
        self._sum += value.item() if isinstance(value, torch.Tensor) else value

    def get_metric(self, reset: bool = False) -> float:
        """Return the mean of all recorded values (0 if none); optionally reset."""
        value = self._sum / self._count if self._count else 0
        if reset:
            self._count = 0
            self._sum = 0
        return value
class Accuracy(Metric):
    """Streaming classification accuracy that skips positions labeled ``ignored_label``."""

    def __init__(self, ignored_label: int = -1):
        # Counts are kept as plain Python ints.
        self._total = 0
        self._correct = 0
        self._ignored_label = ignored_label

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
        """Accumulate one batch of predictions against gold targets.

        Positions where ``target == ignored_label`` (e.g. padding) count
        toward neither the numerator nor the denominator.
        """
        prediction, target = self.detach_tensors(prediction, target)
        assert len(prediction) == len(target)
        mask = target.ne(self._ignored_label)
        self._total += int(mask.sum())
        # Bug fix: restrict matches to non-ignored positions. Previously a
        # prediction equal to the ignored label was counted as correct while
        # the position was excluded from the total, so accuracy could exceed 1.
        self._correct += int(((prediction == target) & mask).sum())

    def get_metric(self, reset: bool = False) -> float:
        """Return accuracy over all non-ignored positions (0 if none); optionally reset."""
        value = self._correct / self._total if self._total else 0
        if reset:
            self._total = 0
            self._correct = 0
        return value
class ConfusionMatrix(Metric):
    """Accumulates predictions/targets and reports a sklearn confusion matrix."""

    def __init__(self, ignored_label: int = -1):
        # Flat lists of every prediction and target seen since the last reset.
        self._prediction = []
        self._target = []
        # NOTE(review): stored but never consulted in this class — confirm intent.
        self._ignored_label = ignored_label

    def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
        """Append one batch of predictions and gold labels (detached, as Python lists)."""
        prediction, target = self.detach_tensors(prediction, target)
        assert len(prediction) == len(target)
        self._prediction += prediction.tolist()
        self._target += target.tolist()

    def get_metric(self, reset: bool = False):
        """Return the confusion matrix over everything seen so far.

        Returns the empty string when nothing has been accumulated; optionally
        clears the buffers afterwards.
        """
        if self._prediction:
            value = confusion_matrix(y_true=self._target, y_pred=self._prediction)
        else:
            value = ""
        if reset:
            self._prediction = []
            self._target = []
        return value