metric.py
import math
from typing import Any, Dict, List, Optional

import torch
from accelerate import Accelerator
from torch.utils.tensorboard import SummaryWriter

class Metric:
    """Base interface for a scalar statistic tracked during training or eval."""

    def __init__(self):
        pass

    def add(self, val):
        raise NotImplementedError

    def val(self) -> float:
        raise NotImplementedError

    def reset(self):
        raise NotImplementedError

    def compute(self, val: Any):
        # Hook for subclasses to transform a raw value before it is recorded.
        return val

    def __add__(self, other):
        raise NotImplementedError

    def __radd__(self, other):
        # Right-hand addition delegates to __add__.
        return self.__add__(other)


class MeanMetric(Metric):
    """Running mean tracked as numerator / denominator."""

    def __init__(self, num=0, denom=0):
        self.numerator = num
        self.denominator: int = denom

    def add(self, val: Any):
        self.numerator += self.compute(val)
        self.denominator += 1

    def many(self, vals: List[Any], denoms: Optional[List[int]] = None):
        if denoms is None:
            denoms = [1] * len(vals)
        assert len(vals) == len(denoms)
        for v, n in zip(vals, denoms):
            self.numerator += self.compute(v)
            self.denominator += n

    def val(self):
        if self.denominator == 0:
            return 0
        return self.numerator / self.denominator

    def reset(self):
        self.numerator = self.denominator = 0

    def __add__(self, other: 'MeanMetric'):
        # Merging two running means pools their numerators and counts.
        return MeanMetric(self.numerator + other.numerator, self.denominator + other.denominator)
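
# Usage sketch for MeanMetric: `many` treats each value as a pre-summed
# contribution paired with an explicit count, and `+` pools two running means:
#
#   m = MeanMetric()
#   m.many([2.0, 4.0], denoms=[4, 4])  # summed losses over 4 samples each
#   assert m.val() == 0.75             # 6.0 / 8
#   merged = MeanMetric(1.0, 2) + MeanMetric(3.0, 2)
#   assert merged.val() == 1.0         # (1.0 + 3.0) / (2 + 2)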


class SumMetric(Metric):
    """Running total, e.g. a global count of processed examples."""

    def __init__(self, sum_=0):
        self.sum_ = sum_

    def add(self, val):
        self.sum_ += self.compute(val)

    def many(self, vals: List[Any]):
        self.sum_ += sum(self.compute(v) for v in vals)

    def val(self):
        return self.sum_

    def reset(self):
        self.sum_ = 0

    def __add__(self, other: 'SumMetric'):
        return SumMetric(self.sum_ + other.sum_)


class RealtimeMetric(Metric):
    """Keeps only the most recently recorded value, e.g. the current learning rate."""

    def __init__(self, val=0):
        self.v = val

    def add(self, val):
        self.v = self.compute(val)

    def many(self, vals: List[Any]):
        self.add(vals[-1])

    def val(self):
        return self.v

    def reset(self):
        self.v = 0

    def __add__(self, other):
        # Merging keeps this process's latest value.
        return RealtimeMetric(self.v)


class PPLMetric(MeanMetric):
    """Perplexity: exp of the running mean loss tracked by MeanMetric."""

    def val(self):
        try:
            return math.exp(super().val())
        except OverflowError:
            # Fall back to the raw mean loss if exp() overflows.
            return super().val()

    def __add__(self, other):
        return PPLMetric(self.numerator + other.numerator, self.denominator + other.denominator)
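
# Usage sketch for PPLMetric: record summed token negative log-likelihoods
# with token counts; val() returns exp(mean NLL):
#
#   ppl = PPLMetric()
#   ppl.many([69.3, 138.6], denoms=[100, 200])  # NLL sums, token counts
#   ppl.val()  # exp(207.9 / 300) = exp(0.693) ~= 2.0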


class Metrics:
    # Class attribute shared across instances, so train and eval Metrics
    # objects write to the same TensorBoard run.
    tb_writer = None

    def __init__(self, opt: Dict[str, Any], accelerator: Accelerator, mode='train'):
        self.metrics = {}
        self.mode = mode
        self.opt = opt
        self.accelerator = accelerator
        # Only the main process creates the writer; opt is a dict, so use
        # .get() rather than attribute access.
        if Metrics.tb_writer is None and opt.get('logdir') is not None and self.accelerator.is_main_process:
            Metrics.tb_writer = SummaryWriter(opt['logdir'])

    def create_metric(self, metric_name: str, metric_obj: Metric):
        assert metric_name not in self.metrics
        self.metrics[metric_name] = metric_obj

    def record_metric(self, metric_name: str, val: Any):
        self.metrics[metric_name].add(val)

    def record_metric_many(self, metric_name: str, vals: List[Any], counts: Optional[List[int]] = None):
        if counts is None:
            self.metrics[metric_name].many(vals)
        else:
            self.metrics[metric_name].many(vals, counts)

    def reset(self, no_reset=('global_exs',)):
        # A tuple default avoids the shared-mutable-default pitfall; the
        # global example count survives resets so epoch progress accumulates.
        for k, v in self.metrics.items():
            if k not in no_reset:
                v.reset()

    def all_gather_metrics(self):
        with torch.no_grad():
            metrics_tensor = {k: torch.tensor([v.val()], device=self.accelerator.device)
                              for k, v in self.metrics.items()}
            if self.accelerator.use_distributed:
                # gather() returns, per metric, one value from each process.
                gathered_metrics = self.accelerator.gather(metrics_tensor)
                for metric_name, gathered_tensor in gathered_metrics.items():
                    if metric_name == 'global_exs':
                        # Example counts are summed across processes.
                        gathered_metrics[metric_name] = gathered_tensor.sum()
                    else:
                        # All other metrics are averaged across processes.
                        gathered_metrics[metric_name] = gathered_tensor.float().mean()
            else:
                gathered_metrics = metrics_tensor
            gathered_metrics = {k: v.item() for k, v in gathered_metrics.items()}
        return gathered_metrics
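
    # Reduction sketch (assuming two processes): if process 0 records a loss
    # of 0.5 and process 1 records 0.7, gather() yields tensor([0.5, 0.7]),
    # reported as its mean 0.6; a 'global_exs' tensor([128., 128.]) is
    # instead summed to 256.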

    def write_tensorboard(self, global_step, gathered_metrics: Optional[Dict[str, float]] = None):
        results = self.all_gather_metrics() if gathered_metrics is None else gathered_metrics
        if self.tb_writer is not None:
            for k, scalar in results.items():
                # Any non-train mode is logged under the 'eval' tag.
                title = f"{k}/{'train' if self.mode == 'train' else 'eval'}"
                self.tb_writer.add_scalar(tag=title, scalar_value=scalar, global_step=global_step)

    def flush(self):
        if self.tb_writer is not None:
            self.tb_writer.flush()

    def display(self, global_step, data_size=None, gathered_metrics: Optional[Dict[str, float]] = None):
        # NOTE: all_gather_metrics() is a collective call. In distributed runs,
        # gather on every process first and pass the result in, since only the
        # main process gets past the guard below.
        if not self.accelerator.is_main_process:
            return
        results = self.all_gather_metrics() if gathered_metrics is None else gathered_metrics
        log_str = ''
        if data_size is not None and 'global_exs' in results:
            print(f"=========== Step: {global_step}, Epoch: {(results['global_exs'] / data_size):.2f} ===========")
        else:
            print(f'=========== Step: {global_step} ===========')
        for k, value in results.items():
            if isinstance(value, float):
                # Learning rate in scientific notation, everything else fixed-point.
                value = f'{value:.3e}' if k == 'lr' else f'{value:.4f}'
            log_str += f'{k}: {value}\t'
        print(log_str)
        return results
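

if __name__ == '__main__':
    # Smoke-test sketch exercising the metric classes directly; Metrics itself
    # is omitted here because it needs a configured accelerate.Accelerator.
    loss = MeanMetric()
    loss.many([2.0, 4.0], denoms=[4, 4])          # mean loss 0.75
    ppl = PPLMetric()
    ppl.many([69.3, 138.6], denoms=[100, 200])    # perplexity ~2.0
    exs = SumMetric()
    exs.many([32, 32, 64])                        # 128 examples
    lr = RealtimeMetric()
    lr.many([1e-4, 5e-5])                         # keeps the latest value
    print(f"loss={loss.val():.4f} ppl={ppl.val():.4f} "
          f"exs={exs.val()} lr={lr.val():.1e}")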