forked from NVIDIA-AI-IOT/trt_pose
-
Notifications
You must be signed in to change notification settings - Fork 0
/
preprocess_coco_person.py
95 lines (73 loc) · 2.51 KB
/
preprocess_coco_person.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import json
import tqdm
def remove_link(cat, kp_a_name, kp_b_name):
skeleton = cat["skeleton"]
keypoints = cat["keypoints"]
new_skeleton = []
for link in skeleton:
if (keypoints[link[0] - 1] == kp_a_name) and (
keypoints[link[1] - 1] == kp_b_name
):
continue # skip
if (keypoints[link[0] - 1] == kp_b_name) and (
keypoints[link[1] - 1] == kp_a_name
):
continue # skip
new_skeleton.append(link)
cat["skeleton"] = new_skeleton
def add_link(cat, kp_a_name, kp_b_name):
keypoints = cat["keypoints"]
cat["skeleton"].append(
[keypoints.index(kp_a_name) + 1, keypoints.index(kp_b_name) + 1]
)
def append_neck_keypoint(ann, cat):
keypoints = cat["keypoints"]
kps = ann["keypoints"]
l_idx = 3 * keypoints.index("left_shoulder")
r_idx = 3 * keypoints.index("right_shoulder")
x_neck = round((kps[l_idx] + kps[r_idx]) / 2.0)
y_neck = round((kps[l_idx + 1] + kps[r_idx + 1]) / 2.0)
v_l = kps[l_idx + 2]
v_r = kps[r_idx + 2]
if v_l == 0 or v_r == 0:
v_neck = 0
elif v_l == 1 or v_r == 1:
v_neck = 1
else:
v_neck = 2
kps += [x_neck, y_neck, v_neck]
def get_cat(data, cat_name):
return [c for c in data["categories"] if c["name"] == cat_name][0]
def get_anns(data, cat_id):
return [a for a in data["annotations"] if a["category_id"] == cat_id]
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"input_annotation_file", type=str, help="Path to COCO annotation file"
)
parser.add_argument(
"output_annotation_file", type=str, help="Path to COCO annotation file"
)
args = parser.parse_args()
print("Loading...")
with open(args.input_annotation_file, "r") as f:
data = json.load(f)
print("Preprocessing...")
cat = get_cat(data, "person")
cat_id = cat["id"]
anns = get_anns(data, cat_id)
for a in anns:
append_neck_keypoint(a, cat)
cat["keypoints"].append("neck")
remove_link(cat, "left_shoulder", "right_shoulder")
remove_link(cat, "left_shoulder", "left_hip")
remove_link(cat, "right_shoulder", "right_hip")
add_link(cat, "neck", "nose")
add_link(cat, "neck", "left_shoulder")
add_link(cat, "neck", "right_shoulder")
add_link(cat, "neck", "left_hip")
add_link(cat, "neck", "right_hip")
print("Saving...")
with open(args.output_annotation_file, "w") as f:
json.dump(data, f)