forked from BrandonSmithJ/MDN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
metrics.py
216 lines (172 loc) · 5.75 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
from .utils import ignore_warnings
from scipy import stats
import numpy as np
import functools
def validate_shape(func):
    '''
    Decorator that flattens the positional array arguments of `func` and
    checks that they all share one shape.

    Positional arguments with a `.flatten()` method are replaced by their
    flattened copy; anything else (and every keyword argument) is passed
    through unchanged. Raises AssertionError, naming `func` and listing the
    original shapes, when the flattened shapes differ.
    '''
    @functools.wraps(func)
    def helper(*args, **kwargs):
        flattened = tuple(
            arg.flatten() if hasattr(arg, 'flatten') else arg
            for arg in args
        )
        original_shapes = [arg.shape for arg in args if hasattr(arg, 'shape')]
        flat_shapes = [arr.shape for arr in flattened if hasattr(arr, 'shape')]
        same = all(shape == flat_shapes[0] for shape in flat_shapes)
        assert same, f'Shapes mismatch in {func.__name__}: {original_shapes}'
        return func(*flattened, **kwargs)
    return helper
def only_finite(func):
    '''
    Decorator that drops every sample which is non-finite (NaN or +/-inf)
    in any positional input array.

    The positional arguments (already flattened and shape-checked by
    `validate_shape`) are stacked one row per argument; only the columns in
    which every row is finite are forwarded to `func`, again one array per
    argument. Keyword arguments are forwarded untouched. Raises
    AssertionError when no sample survives the filter.
    '''
    @validate_shape
    @functools.wraps(func)
    def helper(*args, **kwargs):
        rows = np.vstack(args)
        keep = np.isfinite(rows).all(axis=0)
        assert keep.sum(), f'No valid samples exist for {func.__name__} metric'
        return func(*rows[:, keep], **kwargs)
    return helper
def only_positive(func):
    '''
    Decorator that drops every sample which is not strictly positive in any
    positional input array (zero, negative and NaN values all fail `> 0`).

    Inputs are flattened and shape-checked by `validate_shape`, stacked one
    row per argument, and only the columns where every row is > 0 are passed
    on to `func`. Keyword arguments are forwarded untouched. Raises
    AssertionError when no sample survives the filter.
    '''
    @validate_shape
    @functools.wraps(func)
    def helper(*args, **kwargs):
        rows = np.vstack(args)
        keep = (rows > 0).all(axis=0)
        assert keep.sum(), f'No valid samples exist for {func.__name__} metric'
        return func(*rows[:, keep], **kwargs)
    return helper
def label(name):
    '''
    Decorator factory: set the decorated function's `__name__` to `name`
    (used as the metric's printed label) and return it wrapped with
    `ignore_warnings`.
    '''
    def apply_label(func):
        setattr(func, '__name__', name)
        return ignore_warnings(func)
    return apply_label
# ============================================================================
'''
When executing a function, decorator order starts with the
outermost decorator and works its way down the stack; e.g.
    @dec1
    @dec2
    def foo(): pass
is equivalent to
    def bar(): pass
    foo = dec1(dec2(bar))
So, calling foo will run dec1's wrapper first, then dec2's
wrapper, then the original function.
Below, in rmsle (for example), we have:
rmsle = only_finite( only_positive( label(rmsle) ) )
This means only_positive() will get the input arrays only
after only_finite() removes any nan samples. As well, both
only_positive() and only_finite() will have access to the
function __name__ assigned by label().
For all functions below, y=true and y_hat=estimate
'''
@only_finite
@label('RMSE')
def rmse(y, y_hat):
    ''' Root Mean Squared Error: sqrt(mean((y - y_hat)^2)) over finite samples. '''
    squared_error = (y - y_hat) ** 2
    return squared_error.mean() ** .5
@only_finite
@only_positive
@label('RMSLE')
def rmsle(y, y_hat):
    '''
    Root Mean Squared Logarithmic Error (natural log):
    sqrt(mean(|ln(y) - ln(y_hat)|^2)), computed only on samples that are
    finite and strictly positive in both inputs.
    '''
    log_diff = np.log(y) - np.log(y_hat)
    return (np.abs(log_diff) ** 2).mean() ** 0.5
@only_finite
@label('NRMSE')
def nrmse(y, y_hat):
    ''' Normalized Root Mean Squared Error: RMSE divided by the mean of the true values. '''
    rms_error = np.mean((y - y_hat) ** 2) ** .5
    return rms_error / np.mean(y)
@only_finite
@label('MAE')
def mae(y, y_hat):
    ''' Mean Absolute Error: mean(|y - y_hat|) over finite samples. '''
    return np.abs(y - y_hat).mean()
@only_finite
@label('MAPE')
def mape(y, y_hat):
    '''
    Mean Absolute Percentage Error: 100 * mean(|(y - y_hat) / y|).
    No positivity filter is applied, so zeros in `y` yield inf/nan.
    '''
    relative_error = np.abs((y - y_hat) / y)
    return 100 * relative_error.mean()
@only_finite
@label('<=0')
def leqz(y, y_hat=None):
    '''
    Count of estimates that are less than or equal to zero. When `y_hat` is
    omitted, the values of `y` are counted instead. Positional inputs are
    first restricted to finite samples by `only_finite`.
    '''
    values = y if y_hat is None else y_hat
    return (values <= 0).sum()
@validate_shape
@label('<=0|NaN')
def leqznan(y, y_hat=None):
    '''
    Count of estimates that are NaN or less than or equal to zero. When
    `y_hat` is omitted, the values of `y` are counted instead. No finiteness
    filter is applied (only flattening/shape checks), so NaNs are counted.
    '''
    values = y if y_hat is None else y_hat
    invalid = np.isnan(values) | (values <= 0)
    return invalid.sum()
@only_finite
@only_positive
@label('MdSA')
def mdsa(y, y_hat):
    '''
    Median Symmetric Accuracy, in percent:
    100 * (exp(median(|ln(y_hat / y)|)) - 1).
    '''
    # https://agupubs.onlinelibrary.wiley.com/doi/full/10.1002/2017SW001669
    log_ratio = np.log(y_hat / y)
    median_abs = np.median(np.abs(log_ratio))
    return 100 * (np.exp(median_abs) - 1)
@only_finite
@only_positive
@label('MSA')
def msa(y, y_hat):
    '''
    Mean Symmetric Accuracy, in percent:
    100 * (exp(mean(|ln(y_hat / y)|)) - 1).
    '''
    # https://agupubs.onlinelibrary.wiley.com/doi/full/10.1002/2017SW001669
    log_ratio = np.log(y_hat / y)
    mean_abs = np.abs(log_ratio).mean()
    return 100 * (np.exp(mean_abs) - 1)
@only_finite
@only_positive
@label('SSPB')
def sspb(y, y_hat):
    '''
    Symmetric Signed Percentage Bias, in percent:
    100 * sign(M) * (exp(|M|) - 1), where M = median(ln(y_hat / y)).
    '''
    # https://agupubs.onlinelibrary.wiley.com/doi/full/10.1002/2017SW001669
    median_log = np.median(np.log(y_hat / y))
    magnitude = np.exp(np.abs(median_log)) - 1
    return 100 * np.sign(median_log) * magnitude
@only_finite
@label('Bias')
def bias(y, y_hat):
    ''' Mean Bias: mean(y_hat - y); positive when estimates are too high on average. '''
    return (y_hat - y).mean()
@only_finite
@only_positive
@label('R^2')
def r_squared(y, y_hat):
    '''
    Logarithmic R^2: the squared correlation coefficient of a least-squares
    linear fit of log10(y_hat) against log10(y).
    '''
    fit = stats.linregress(np.log10(y), np.log10(y_hat))
    return fit.rvalue ** 2
@only_finite
@only_positive
@label('Slope')
def slope(y, y_hat):
    '''
    Logarithmic slope: slope of a least-squares linear fit of log10(y_hat)
    against log10(y).
    '''
    fit = stats.linregress(np.log10(y), np.log10(y_hat))
    return fit.slope
@only_finite
@only_positive
@label('Intercept')
def intercept(y, y_hat):
    '''
    Logarithmic intercept: intercept of a least-squares linear fit of
    log10(y_hat) against log10(y).
    '''
    fit = stats.linregress(np.log10(y), np.log10(y_hat))
    return fit.intercept
@validate_shape
@label('MWR')
def mwr(y, y_hat, y_bench):
    '''
    Model Win Rate - fraction (0 to 1) of samples in which the model has a
    closer estimate than the benchmark.
    y: true, y_hat: model, y_bench: benchmark

    Negative values in any input are treated as missing (NaN). Samples whose
    true value is missing or non-finite are excluded entirely. Among the
    remaining samples the model wins when its absolute error is strictly
    smaller than the benchmark's, wins automatically when only the benchmark
    is missing, and loses automatically when its own estimate is missing.
    '''
    # Work on float copies: assigning NaN into an integer array would raise.
    y, y_hat, y_bench = (np.array(a, dtype=float) for a in (y, y_hat, y_bench))
    for values in (y, y_hat, y_bench):
        values[values < 0] = np.nan

    truth_ok = np.isfinite(y)
    model_ok = np.isfinite(y_hat)
    bench_ok = np.isfinite(y_bench)

    # Named `wins` (not `stats`) so the scipy `stats` module is not shadowed.
    wins = np.zeros(len(y))
    both_ok = model_ok & bench_ok
    wins[both_ok] = (np.abs(y[both_ok] - y_hat[both_ok])
                     < np.abs(y[both_ok] - y_bench[both_ok]))
    wins[~bench_ok] = 1
    wins[~model_ok] = 0

    # Restrict the numerator to the same samples as the denominator (those
    # with a usable true value) so the rate cannot exceed 1.
    return wins[truth_ok].sum() / truth_ok.sum()
def performance(key, y, y_hat, metrics=(mdsa, sspb, slope, msa, rmsle, mae, leqznan), csv=False):
    '''
    Return a one-line string summarising how well `y_hat` estimates `y`.

    key     : label placed at the start of the line.
    y       : true values (any object with `.flatten()`, e.g. a numpy array).
    y_hat   : estimated values, flattened the same way.
    metrics : iterable of metric callables taking (y, y_hat); each one's
              `__name__` is used as its label. The default is an immutable
              tuple so it cannot be mutated and shared across calls.
    csv     : if True, return `key,name:value,...`; otherwise a fixed-width,
              human-readable line with each value formatted to 3 decimals.

    If computing or formatting any metric raises, the whole line is replaced
    by `key | Exception: <message>` instead of propagating the error.
    '''
    y = y.flatten()
    y_hat = y_hat.flatten()
    try:
        if csv:
            return f'{key},' + ','.join(f'{f.__name__}:{f(y, y_hat)}' for f in metrics)
        return f'{key:>12} | ' + ' '.join(f'{f.__name__}: {f(y, y_hat):>6.3f}' for f in metrics)
    except Exception as e:
        # Deliberate best-effort reporting: surface the failure inline.
        return f'{key:>12} | Exception: {e}'