forked from michaelwiest/ucsd-cse-253-hw-3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
39 lines (31 loc) · 912 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import matplotlib.pyplot as plt
# Human-readable class names, ordered so that an integer class index can be
# used directly as a position in this list (plot_images relies on that).
# NOTE(review): these look like the ten CIFAR-10 categories -- confirm the
# ordering matches the dataset loader's label encoding.
label_names = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
def plot_images(images, cls_true, cls_pred=None):
    """Show a 3x3 grid of images labelled with their class names.

    Each subplot's x-label is the true class name followed by its numeric
    index or, when ``cls_pred`` is given, the true and predicted class names
    on separate lines. Axis ticks are hidden and the figure is displayed
    with ``plt.show()``.

    Args:
        images: Exactly 9 images, indexed as ``images[i, :, :, :]``, so the
            container must support 4-axis indexing.
            NOTE(review): presumably a NumPy array laid out (N, H, W, C) as
            ``imshow`` expects -- confirm against the caller.
        cls_true: Exactly 9 integer indices into ``label_names``.
        cls_pred: Optional predicted indices into ``label_names``. Only the
            first 9 entries are read.

    Raises:
        ValueError: If ``images`` or ``cls_true`` does not hold exactly 9
            items, or ``cls_pred`` is given with fewer than 9 items.
    """
    grid_side = 3
    n_cells = grid_side * grid_side

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not (len(images) == len(cls_true) == n_cells):
        raise ValueError(
            "plot_images expects exactly {0} images and {0} true labels, "
            "got {1} images and {2} labels".format(
                n_cells, len(images), len(cls_true)
            )
        )
    # Fail before drawing anything rather than with an IndexError mid-loop.
    if cls_pred is not None and len(cls_pred) < n_cells:
        raise ValueError(
            "plot_images expects at least {0} predicted labels, "
            "got {1}".format(n_cells, len(cls_pred))
        )

    # Create figure with sub-plots; only the axes are needed here.
    _, axes = plt.subplots(grid_side, grid_side)

    for i, ax in enumerate(axes.flat):
        # Plot the image.
        ax.imshow(images[i, :, :, :], interpolation='spline16')

        # Get its equivalent class name.
        cls_true_name = label_names[cls_true[i]]
        if cls_pred is None:
            xlabel = "{0} ({1})".format(cls_true_name, cls_true[i])
        else:
            cls_pred_name = label_names[cls_pred[i]]
            xlabel = "True: {0}\nPred: {1}".format(cls_true_name, cls_pred_name)

        ax.set_xlabel(xlabel)
        # Tick marks carry no meaning for image thumbnails.
        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()