-
Notifications
You must be signed in to change notification settings - Fork 0
/
visualization.py
99 lines (76 loc) · 2.89 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
Implements functions for displaying keypoints on images.
"""
import torch
import matplotlib.pyplot as plt
import numpy as np
__author__ = "Victor Mawusi Ayi <ayivima@hotmail.com>"

# Pick the compute device and matching legacy float tensor type once,
# at import time: CUDA when a GPU is visible to torch, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
    fltTensor = torch.cuda.FloatTensor
else:
    device = torch.device("cpu")
    fltTensor = torch.FloatTensor
def plotter(
    model,
    test_loader,
    plotrows=5,
    plotcolumns=8,
    showactual=False,
    figsize=(17, 10),
    markersize=20,
):
    """Display images with an overlay of predicted and, optionally, actual keypoints.

    Each row of the plot grid is filled from one batch drawn from
    ``test_loader``; each column shows one image of that batch.

    Arguments
    ---------
    :model: A trained (or untrained) model. It is moved to ``device`` and
        temporarily switched to evaluation mode while predicting; its
        previous training/evaluation mode is restored afterwards.
    :test_loader: An iterable yielding ``(images, keypoints)`` batches.
    :plotrows: The number of rows for image plotting (one batch per row).
        If the loader runs out of batches, remaining rows are left blank.
    :plotcolumns: The number of columns for image plotting. Columns beyond
        the number of images in a batch are left blank.
    :showactual: Specifies whether actual keypoints should be plotted
        in addition to predicted keypoints.
    :figsize: The size of the plot.
    :markersize: The size of markers used for keypoint coordinates.

    Predicted keypoints are drawn in magenta and actual keypoints in green.
    Model output is reshaped to ``(batch, 68, -1)``, and both predicted and
    actual points are mapped back with ``pts * 50.0 + 100``. This presumably
    undoes the normalization applied during preprocessing; TODO: confirm
    against the data transforms.
    """
    # squeeze=False always yields a 2-D array of axes, even for 1x1 / 1xN grids.
    _, axs = plt.subplots(plotrows, plotcolumns, figsize=figsize, squeeze=False)
    model = model.to(device)
    batches = iter(test_loader)

    def pointsplot(axiz, pts, color):
        # Columns 0 and 1 of ``pts`` are plotted as x and y coordinates.
        axiz.scatter(pts[:, 0], pts[:, 1], s=markersize, marker='.', c=color)

    # Predict in eval mode (disables dropout / uses running batchnorm stats),
    # but give the caller back the model in whatever mode it was in.
    was_training = model.training
    model.eval()
    try:
        for row in axs:
            batch = next(batches, None)
            if batch is None:
                # Loader exhausted: leave this (and any later) row blank.
                for ax in row:
                    ax.axis('off')
                continue
            images, gt_pts = batch

            images = images.to(device=device, dtype=torch.float32)
            with torch.no_grad():  # inference only; no autograd graph needed
                output_pts = model(images)
            output_pts = output_pts.view(output_pts.size(0), 68, -1)

            # Convert the whole batch to NumPy once rather than once per image.
            # (C, H, W) -> (H, W, C) for matplotlib.
            images_np = np.transpose(images.cpu().numpy(), (0, 2, 3, 1))
            predictions = output_pts.cpu().numpy() * 50.0 + 100
            if showactual:
                actuals = (
                    torch.as_tensor(gt_pts).detach().cpu().numpy() * 50.0 + 100
                )

            for i, ax in enumerate(row):
                if i < len(images_np):
                    ax.imshow(np.squeeze(images_np[i]), cmap='gray')
                    pointsplot(ax, predictions[i], 'm')
                    if showactual:
                        pointsplot(ax, actuals[i], 'g')
                # Also hides axes left empty because the batch was too small.
                ax.axis('off')
    finally:
        model.train(was_training)

    plt.show()