-
Notifications
You must be signed in to change notification settings - Fork 62
/
draw_plot.py
34 lines (30 loc) · 1.1 KB
/
draw_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import sklearn.metrics
import matplotlib
# Use the non-interactive 'Agg' backend so this program can run on a remote server without a display
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
# Directory that main() reads the per-model '<model>_x.npy' / '<model>_y.npy'
# arrays from, and into which it saves the 'pr_curve' figure.
result_dir = './test_result'
def main():
    """Plot the curves of the models named on the command line.

    For every name in ``sys.argv[1:]`` the arrays ``<model>_x.npy`` and
    ``<model>_y.npy`` are loaded from ``result_dir``. ``x`` is drawn on the
    axis labelled "Recall" and ``y`` on the axis labelled "Precision". For
    each model the area under the curve, the maximum F1 and the values of
    ``y`` at indices 100, 200 and 300 (reported as P@100/200/300) are
    printed. All curves are drawn on one figure, which is saved as
    ``pr_curve`` inside ``result_dir``.

    A curve with 300 or fewer points no longer aborts the whole run with an
    ``IndexError``: cut-offs beyond the end of ``y`` are printed as ``n/a``,
    and so is the mean when any cut-off is missing.
    """
    models = sys.argv[1:]
    cutoffs = (100, 200, 300)
    for model in models:
        x = np.load(os.path.join(result_dir, model + '_x' + '.npy'))
        y = np.load(os.path.join(result_dir, model + '_y' + '.npy'))

        # The tiny epsilon keeps the division defined where x + y == 0.
        f1 = (2 * x * y / (x + y + 1e-20)).max()
        auc = sklearn.metrics.auc(x=x, y=y)
        plt.plot(x, y, lw=2, label=model)
        print(model + ' : ' + 'auc = ' + str(auc) + ' | ' + 'max F1 = ' + str(f1))

        # NOTE(review): P@k is read from y[k], as before. If index 0 holds
        # the top-1 precision, P@k would be y[k - 1] -- confirm against the
        # code that writes these arrays before changing it.
        n_points = len(y)
        p_at = [y[k] if k < n_points else None for k in cutoffs]
        if None in p_at:
            mean = 'n/a'
            p_at = ['n/a' if p is None else p for p in p_at]
        else:
            mean = (p_at[0] + p_at[1] + p_at[2]) / 3
        print(' P@100: {} | P@200: {} | P@300: {} | Mean: {}'.format(p_at[0], p_at[1], p_at[2], mean))

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    # Fixed axis limits: only this window of the curves is shown.
    plt.ylim([0.3, 1.0])
    plt.xlim([0.0, 0.4])
    plt.title('Precision-Recall')
    plt.legend(loc="upper right")
    plt.grid(True)
    plt.savefig(os.path.join(result_dir, 'pr_curve'))
# Run the plotting routine only when executed as a script, not on import.
if __name__ == '__main__':
    main()