-
Notifications
You must be signed in to change notification settings - Fork 60
/
utils.py
71 lines (61 loc) · 2.11 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.metrics import roc_curve,roc_auc_score
def get_default_device():
    """Return the torch device to run on.

    Uses ``'cuda'`` whenever ``torch.cuda.is_available()`` reports a usable GPU,
    and falls back to ``'cpu'`` otherwise.
    """
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
def to_device(data, device):
    """Recursively transfer *data* to *device*.

    Lists and tuples are walked element by element, and the result is always
    a ``list`` (tuples are not preserved). Any other object must support
    ``.to(device, non_blocking=...)`` (e.g. a torch tensor) and is moved with
    ``non_blocking=True``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return list(map(lambda item: to_device(item, device), data))
def plot_history(history):
    """Plot the per-epoch validation losses stored in *history* and show the figure.

    *history* is iterated once per series. Every element must be a mapping with
    ``'val_loss1'`` and ``'val_loss2'`` entries. Both series are read before
    anything is drawn. They are plotted on the current matplotlib axes with
    'x' markers and labelled "loss1" and "loss2".
    """
    # Collect both series up front so a missing key fails before any plotting.
    series = {
        "loss1": [epoch['val_loss1'] for epoch in history],
        "loss2": [epoch['val_loss2'] for epoch in history],
    }
    for curve_label, values in series.items():
        plt.plot(values, '-x', label=curve_label)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.title('Losses vs. No. of epochs')
    plt.grid()
    plt.show()
def histogram(y_test, y_pred):
    """Show a stacked 20-bin histogram of *y_pred*, split by the label in *y_test*.

    Scores whose label equals 0 are drawn in green (#82E0AA) and scores whose
    label equals 1 in red (#EC7063). The histogram is drawn on a new 12x6 figure.
    Assumes both inputs are array-likes that support boolean-mask indexing
    (e.g. NumPy arrays) — TODO confirm at call sites.
    """
    plt.figure(figsize=(12, 6))
    negatives = y_pred[y_test == 0]
    positives = y_pred[y_test == 1]
    palette = ['#82E0AA', '#EC7063']
    plt.hist([negatives, positives], bins=20, color=palette, stacked=True)
    plt.title("Results", size=20)
    plt.grid()
    plt.show()
def ROC(y_test, y_pred):
    """Plot the ROC curve of *y_pred* against *y_test* and return crossover thresholds.

    Draws the curve, labelled with its AUC, and the dotted line TPR = 1 - FPR.
    The points where the curve crosses that line are marked with red dots.
    At those points sensitivity equals specificity.

    Returns:
        The entries of ``roc_curve``'s thresholds at every index immediately
        before a sign change of ``tpr - (1 - fpr)``. The result may be empty,
        or hold more than one value (e.g. when the difference is exactly zero
        at a point).
    """
    fpr, tpr, thresholds = roc_curve(y_test, y_pred)
    auc = roc_auc_score(y_test, y_pred)
    # Indices where tpr - (1 - fpr) changes sign between i and i + 1.
    crossings = np.flatnonzero(np.diff(np.sign(tpr - (1 - fpr))))

    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.plot(fpr, tpr, label="AUC=" + str(auc))
    plt.plot(fpr, 1 - fpr, 'r:')
    plt.plot(fpr[crossings], tpr[crossings], 'ro')
    plt.legend(loc=4)
    plt.grid()
    plt.show()
    return thresholds[crossings]
def confusion_matrix(target, predicted, perc=False):
    """Show a heatmap of the confusion matrix of *predicted* vs *target* labels.

    Rows are the predicted labels ("Predicted") and columns the actual labels
    ("Actual"). Counts come from ``pd.crosstab``.

    Args:
        target: true labels.
        predicted: predicted labels, with the same length as *target*.
        perc: if True, each cell is shown as a share of all samples
            (formatted '.2%', 'Blues' colormap). Otherwise the raw integer
            counts are shown.
    """
    frame = pd.DataFrame(
        {'y_Actual': target, 'y_Predicted': predicted},
        columns=['y_Predicted', 'y_Actual'],
    )
    # Named `counts` (not `confusion_matrix`) to avoid shadowing this function.
    counts = pd.crosstab(
        frame['y_Predicted'], frame['y_Actual'],
        rownames=['Predicted'], colnames=['Actual'],
    )
    if perc:
        # Normalise by the grand total. np.sum(DataFrame) forwards to
        # DataFrame.sum(axis=None), which depending on the pandas version means
        # column-wise sums (each column summing to 100%) or a scalar. Summing
        # the underlying array is unambiguous.
        total = counts.to_numpy().sum()
        sns.heatmap(counts / total, annot=True, fmt='.2%', cmap='Blues')
    else:
        sns.heatmap(counts, annot=True, fmt='d')
    plt.show()