-
Notifications
You must be signed in to change notification settings - Fork 1
/
gen_tid.py
51 lines (40 loc) · 1.78 KB
/
gen_tid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import csv
import os
import numpy as np
import json
# Share of distinct image indices (the first number parsed from each file
# name) that go to the training split; the remainder go to validation.
train_ratio = 0.8

# Input: space-separated "<score> <file name>" rows. Layout assumed from
# TID2013's mos_with_names.txt; replace the placeholder path before running.
labels_file_name = '/path/to/TID2013/mos_with_names.txt'

# Outputs: one JSON list of metadata records per split.
train_file_name = '/path/to/save/metadata/TID2013/metadata_train.json'
validate_file_name = '/path/to/save/metadata/TID2013/metadata_validate.json'

# Exclusion lists, empty by default. deprecated_img is matched against a
# record's first number; deprecated_class against its second number
# (presumably the distortion type -- confirm against the dataset docs).
deprecated_img = list()
deprecated_class = list()
def listt(x):
    """Transpose a list of rows into a list of columns.

    The number of columns is taken from the first row, so
    ``listt([[1, 2], [3, 4]]) == [[1, 3], [2, 4]]``. A later row shorter
    than the first raises IndexError, and extra trailing items in longer
    rows are ignored.

    Returns:
        A new list of lists, one per column. An empty input yields ``[]``
        (the previous lambda raised IndexError on ``x[0]``).
    """
    if not x:
        return []
    width = len(x[0])
    return [[row[i] for row in x] for i in range(width)]
def _parse_label_rows(rows):
    """Convert space-delimited label rows into flat metadata records.

    Each non-empty row is expected to look like ``[score, file_name, ...]``,
    e.g. ``['5.51429', 'i01_01_1.bmp']``. This layout is assumed from
    TID2013's mos_with_names.txt; confirm against the real file. The file
    stem without its first character is split on ``_`` and parsed as
    integers, so ``'i01_01_1.bmp'`` becomes ``[1, 1, 1]``.

    Returns:
        A list of records ``[*name_ints, float(score), file_name, *extra]``,
        where ``extra`` is any further column csv produced (such as ``''``
        from a trailing space). Blank rows are skipped.

    Raises:
        ValueError: if the score or a file-name field is not numeric.
        IndexError: if a non-empty row has no file-name column.
    """
    records = []
    for row in rows:
        if not row:
            # csv.reader yields [] for a blank line (e.g. a trailing newline);
            # indexing row[1] on it used to crash the whole script.
            continue
        stem = row[1].split(sep='.')[0]
        name_ints = [int(part) for part in stem[1:].split(sep='_')]
        records.append(name_ints + [float(row[0])] + row[1:])
    return records


def _split_by_reference(records, ratio, excluded_imgs=(), excluded_classes=()):
    """Randomly split records into train/validate groups by their first field.

    A record is dropped if its first field is in ``excluded_imgs`` or its
    second field is in ``excluded_classes``. The distinct first-field values
    of the remaining records (presumably TID2013 reference-image numbers) are
    sorted and shuffled with NumPy's global random state. The split is
    therefore unseeded and differs between runs unless the caller seeds
    NumPy. The first ``int(len(values) * ratio)`` values go to training and
    the rest to validation. Every record follows its first-field value, so no
    value appears in both splits.

    Returns:
        ``(train_records, validate_records)``, each in input order. Both are
        empty if no record survives the exclusions.

    Raises:
        ValueError: if ``ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError('ratio must be in [0, 1], got %r' % (ratio,))
    kept = [r for r in records
            if r[0] not in excluded_imgs and r[1] not in excluded_classes]
    if not kept:
        return [], []
    refs = np.array(sorted({r[0] for r in kept}))
    n_refs = len(refs)
    n_train = int(n_refs * ratio)
    order = np.random.choice(n_refs, n_refs, replace=False)
    # Plain Python sets give O(1) membership tests below, instead of an
    # O(n) scan of a NumPy array for every record.
    train_refs = set(refs[order[:n_train]].tolist())
    validate_refs = set(refs[order[n_train:]].tolist())
    train = [r for r in kept if r[0] in train_refs]
    validate = [r for r in kept if r[0] in validate_refs]
    return train, validate


if __name__ == '__main__':
    # newline='' is what the csv module documentation requires for files
    # passed to csv.reader.
    with open(labels_file_name, newline='') as f:
        label_rows = list(csv.reader(f, delimiter=' '))
    item_list = _parse_label_rows(label_rows)
    item_list_train, item_list_validate = _split_by_reference(
        item_list, train_ratio, deprecated_img, deprecated_class)
    for out_name, out_items in ((train_file_name, item_list_train),
                                (validate_file_name, item_list_validate)):
        out_dir = os.path.dirname(out_name)
        if out_dir:
            # Create the metadata directory on first run instead of failing.
            os.makedirs(out_dir, exist_ok=True)
        with open(out_name, 'w') as f:
            json.dump(out_items, f)