-
Notifications
You must be signed in to change notification settings - Fork 13
/
run_eval.py
71 lines (51 loc) · 2.53 KB
/
run_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# %% Create Eval Instance
import os
import time
from datetime import datetime, timedelta
import importlib  # replaces `imp`, which is deprecated and was removed in Python 3.12

import evaluation_wrapper
import evaluation
import constants
import matplotlib.pyplot as plt

# Re-execute the project modules so edits made to them since the interactive
# session started are picked up when this cell is re-run. `evaluation` is
# reloaded first: if evaluation_wrapper imports names from it (assumed from the
# naming — confirm), reloading the wrapper afterwards re-binds the fresh versions.
importlib.reload(evaluation)
importlib.reload(evaluation_wrapper)

# `%matplotlib inline` is IPython-only syntax and makes this file a SyntaxError
# under plain `python`; issue the same magic through the IPython API instead.
try:
    get_ipython().run_line_magic("matplotlib", "inline")
except NameError:
    pass  # not running under IPython: no inline backend to configure

# The argument is presumably the name of a saved training run (it encodes the
# hyper-parameters); swap in the commented-out line to evaluate another run.
# NOTE(review): `eval` shadows the builtin eval(); kept because later cells use it.
eval = evaluation_wrapper.EvaluationWrapper('2021-04-03-18h-11m_batchsize_16_hg_4_loss_keras_mse_aug_medium_sigma4_learningrate_5.0e-03_opt_adam_gt-4kp_activ_sigmoid')
# eval = evaluation_wrapper.EvaluationWrapper('2021-04-01-21h-59m_batchsize_16_hg_4_loss_weighted_mse_aug_light_sigma4_learningrate_5.0e-03_opt_rmsProp_gt-4kp_activ_sigmoid_subset_0.50_lrfix')
# %% Run OKS
# Epochs to evaluate: every epoch from 1 to 5, then every 10th epoch up to 70.
epochs = list(range(1, 6)) + list(range(10, 80, 10))

# OKS on the validation generator using the plain predictions (no flip averaging).
start = time.time()
eval.calculateMetric(constants.Metrics.oks, epochs, constants.Generator.val_gen, average_flip_prediction=False)
elapsed = time.time() - start
print("Total OKS time (no flip averaging): {}".format(timedelta(seconds=elapsed)))

# OKS again, this time averaging the normal and flipped predictions.
start = time.time()
eval.calculateMetric(constants.Metrics.oks, epochs, constants.Generator.val_gen, average_flip_prediction=True)
elapsed = time.time() - start
print("Total OKS average normal & flip time: {}".format(timedelta(seconds=elapsed)))
# %% Run PCK
# Uses `epochs` defined in the "Run OKS" cell, so run that cell first.
# Each variant is timed on its own (as in the OKS cell) so the extra cost of
# flip averaging is visible instead of being folded into one total.
start = time.time()
eval.calculateMetric(constants.Metrics.pck, epochs, constants.Generator.val_gen, average_flip_prediction=False)
elapsed = time.time() - start
print("Total PCK time (no flip averaging): {}".format(timedelta(seconds=elapsed)))

start = time.time()
eval.calculateMetric(constants.Metrics.pck, epochs, constants.Generator.val_gen, average_flip_prediction=True)
elapsed = time.time() - start
print("Total PCK average normal & flip time: {}".format(timedelta(seconds=elapsed)))
# %% Visualize representative batch heatmaps
# Render heatmaps for the representative-set generator and report how long
# the rendering took as an H:MM:SS duration.
start = time.time()
eval.visualizeHeatmaps(constants.Generator.representative_set_gen)
elapsed = time.time() - start
print("Total heatmap time: {}".format(
    timedelta(seconds=elapsed)))
# %% Visualize representative batch keypoints
# Run keypoint visualization on the representative-set generator, then print
# the wall-clock duration of that call.
start = time.time()
eval.visualizeKeypoints(constants.Generator.representative_set_gen)
elapsed = time.time() - start
print("Total keypoint time: {}".format(
    timedelta(seconds=elapsed)))
# %% Run on arbitrary images
img_name = 'IMG_3320.JPG'
# splitext removes only the final extension, so a name such as 'scan.v2.JPG'
# keeps its full stem 'scan.v2' (split('.')[0] would have returned 'scan').
name_no_extension = os.path.splitext(img_name)[0]
user_img_dir = 'user_img'
img_path = os.path.join(constants.DEFAULT_DATA_BASE_DIR, user_img_dir, img_name)
# Predict on the single image twice: without and with flip-averaged predictions.
eval.predict_on_path(img_path, average_flip_prediction=False)
eval.predict_on_path(img_path, average_flip_prediction=True)
# %% Run on the user image directory
# Uses `user_img_dir` from the previous cell. Here predict_on_path receives the
# directory itself; presumably it predicts on each image inside (confirm in
# EvaluationWrapper.predict_on_path).
user_img_path = os.path.join(constants.DEFAULT_DATA_BASE_DIR, user_img_dir)
eval.predict_on_path(user_img_path, average_flip_prediction=False)
eval.predict_on_path(user_img_path, average_flip_prediction=True)
# %%