-
Notifications
You must be signed in to change notification settings - Fork 56
/
shapenet_metadata.py
86 lines (73 loc) · 2.82 KB
/
shapenet_metadata.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import json
# Root of the ShapeNet data. taxonomy.json is looked up here, and each
# category's files live in a sub-directory named by its zero-padded synset id.
DATASET_DIRECTORY = "data/shapenet/"
# Root categories with fewer instances than this are not given a label.
MIN_SAMPLES_PER_CATEGORY = 2000
from util import device
class ShapenetCategory:
    """A single synset (category) node of the ShapeNet taxonomy tree.

    Attributes:
        name: Human-readable category name.
        id: Integer synset id.
        count: Instance count supplied by the caller.
        children: Child ``ShapenetCategory`` nodes; starts empty.
        is_root: ``True`` until a caller marks the node as another node's child.
        label: Class label index, or ``None`` until one is assigned.
    """

    def __init__(self, name, id, count):
        self.name = name
        self.id = id
        self.is_root = True  # callers clear this when attaching the node as a child
        self.children = []
        self.count = count
        self.label = None

    def print(self, depth=0):
        """Write this node and, recursively, its subtree to stdout.

        Each node is printed on its own line as ``<indent><name>(<count>)``.
        The indent is one space per tree level, starting at ``depth``.
        """
        indent = ' ' * depth
        line = indent + self.name + '({:d})'.format(self.count)
        # Inside a method this name resolves to the builtin print, not to this
        # method, because class attributes are not part of the method's scope.
        print(line)
        next_depth = depth + 1
        for sub in self.children:
            sub.print(depth=next_depth)

    def get_directory(self):
        """Return this category's data directory under ``DATASET_DIRECTORY``.

        The directory name is the synset id left-padded with zeros to at
        least 8 characters.
        """
        padded_id = '{:0>8}'.format(str(self.id))
        return os.path.join(DATASET_DIRECTORY, padded_id)
class ShapenetMetadata:
    """Category metadata for the ShapeNet dataset.

    On construction, reads the synset taxonomy and keeps, as labelled
    classes, the root categories with at least ``MIN_SAMPLES_PER_CATEGORY``
    instances.

    Attributes:
        clip_sdf, rescale_sdf: Boolean flags, both ``True``. They are not read
            in this class (presumably consumed by SDF-loading code elsewhere;
            TODO confirm).
        categories: Selected root categories, sorted by synset id.
        categories_by_id: Mapping from synset id to selected category.
        label_count: Number of selected categories.
        labels: Initialised to ``None``; not otherwise used in this class.
    """

    # (label, RGB colour) pairs used by get_color; components are in [0, 1].
    _LABEL_COLORS = (
        (0, (0.01, 0.6, 0.9)),    # cyan
        (1, (0.8, 0.7, 0.1)),     # yellow
        (2, (0.9, 0.1, 0.14)),    # red
        (3, (0.9, 0.1, 0.673)),   # magenta
        (4, (0.46, 0.1, 0.9)),    # purple
        (5, (0.1, 0.2, 0.9)),     # blue
        (6, (0.05, 0.5, 0.05)),   # green
    )
    # Colour returned for any label not listed above.
    _DEFAULT_COLOR = (0.7, 0.7, 0.7)  # grey

    def __init__(self):
        self.clip_sdf = True
        self.rescale_sdf = True
        self.load_categories()
        self.labels = None

    def load_categories(self):
        """Read the taxonomy and populate the category/label attributes.

        Reads ``taxonomy.json`` from ``DATASET_DIRECTORY``. If that file does
        not exist, it falls back to ``examples/shapenet_taxonomy.json``,
        resolved relative to the current working directory. Each taxonomy
        entry must provide ``synsetId``, ``name``, ``numInstances`` and
        ``children``.

        Side effects: sets ``categories``, ``categories_by_id`` and
        ``label_count``, and assigns ``label`` 0..label_count-1 to the
        selected categories in ascending synset-id order.

        Raises:
            KeyError: if an entry lists a child synset id that has no entry
                of its own.
        """
        taxonomy_filename = os.path.join(DATASET_DIRECTORY, "taxonomy.json")
        if not os.path.isfile(taxonomy_filename):
            taxonomy_filename = 'examples/shapenet_taxonomy.json'
        # Context manager so the file handle is closed deterministically.
        with open(taxonomy_filename, encoding='utf-8') as taxonomy_file:
            taxonomy = json.load(taxonomy_file)

        # First pass: one node per synset, keyed by integer id.
        categories = {}
        for item in taxonomy:
            synset_id = int(item['synsetId'])
            categories[synset_id] = ShapenetCategory(
                item['name'], synset_id, item['numInstances'])

        # Second pass: link the tree. Any synset listed as a child is not a root.
        for item in taxonomy:
            parent = categories[int(item['synsetId'])]
            for child_str_id in item['children']:
                child = categories[int(child_str_id)]
                parent.children.append(child)
                child.is_root = False

        self.categories = sorted(
            (category for category in categories.values()
             if category.is_root and category.count >= MIN_SAMPLES_PER_CATEGORY),
            key=lambda category: category.id)
        self.categories_by_id = {category.id: category for category in self.categories}
        self.label_count = len(self.categories)
        for label, category in enumerate(self.categories):
            category.label = label

    def get_color(self, label):
        """Return the RGB display colour (components in [0, 1]) for ``label``.

        Labels 0-6 have distinct colours; any other value maps to grey.
        """
        # Compare with ``==`` rather than using a dict lookup. This way, label
        # objects that compare equal to an int but may not hash like one
        # (e.g. 0-d tensors) still match.
        for value, color in self._LABEL_COLORS:
            if label == value:
                return color
        return self._DEFAULT_COLOR
# Module-level instance: importing this module reads the taxonomy once.
shapenet = ShapenetMetadata()

if __name__ == "__main__":
    # List the selected categories, largest instance count first, with labels.
    # sorted() is stable with reverse=True, so ties keep their existing order.
    by_size = sorted(shapenet.categories, key=lambda cat: cat.count, reverse=True)
    for cat in by_size:
        print('{:d}: {:s} - {:d}'.format(cat.label, cat.name, cat.count))