-
Notifications
You must be signed in to change notification settings - Fork 41
/
create_train_val_label.py
43 lines (30 loc) · 1.3 KB
/
create_train_val_label.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import json,os
import numpy as np
import tifffile as tiff
from PIL import Image
def get_label_from_palette(label_img, palette_file='Palette.json'):
    """Convert a colour-coded label image into a single-channel class-index map.

    The palette file is a JSON object whose keys are class indices (strings
    parseable by ``int``) and whose values are pixel colours, presumably like
    ``{"0": [0, 0, 0], "1": [255, 0, 0]}`` -- TODO confirm against the real
    Palette.json.

    Args:
        label_img: array indexed as ``label_img[row, col, channel]``.
        palette_file: path to the JSON palette.

    Returns:
        ``np.uint8`` array of shape ``label_img.shape[:2]`` holding the class
        index of every pixel. If several palette entries share one colour, the
        entry that appears first in the file wins.

    Raises:
        ValueError: if some pixel's colour does not appear in the palette.
    """
    with open(palette_file, 'r', encoding='utf-8') as fp:
        palette = json.load(fp)

    label_img = np.asarray(label_img)
    height, width, n_channels = label_img.shape[0], label_img.shape[1], label_img.shape[2]
    label = np.zeros((height, width), dtype=np.uint8)
    # Pixels that already received a class; keeps the first matching palette
    # entry (in file order) when colours are duplicated.
    assigned = np.zeros((height, width), dtype=bool)

    for key, color in palette.items():
        color = np.asarray(color)
        # A colour whose length differs from the pixel depth (e.g. an RGB
        # palette entry vs. an RGBA image) can never equal a pixel.
        if color.shape != (n_channels,):
            continue
        match = np.all(label_img == color, axis=-1) & ~assigned
        if match.any():
            label[match] = int(key)
            assigned |= match

    if not assigned.all():
        row, col = np.argwhere(~assigned)[0]
        raise ValueError(
            f'pixel ({row}, {col}) has colour {label_img[row, col].tolist()} '
            f'which is not in palette {palette_file!r}'
        )
    return label
def main(path, palette_file='Palette.json'):
    """Convert every colour label image in *path* into a class-index TIFF.

    Each regular file whose name contains ``'label'`` is read with PIL, mapped
    to class indices with ``get_label_from_palette`` and written next to the
    input as ``<name without its last 9 characters>new-L.tif``. The 9-character
    slice presumably strips a suffix such as ``label.tif`` / ``label.png`` --
    TODO confirm against the dataset naming scheme.

    Args:
        path: directory holding the label images.
        palette_file: JSON palette forwarded to ``get_label_from_palette``.
    """
    # tifffile deprecated and later removed ``imsave`` in favour of
    # ``imwrite``; prefer the new name but fall back for old installations.
    write_tiff = getattr(tiff, 'imwrite', None) or tiff.imsave
    for pic in sorted(os.listdir(path)):
        src = os.path.join(path, pic)
        if 'label' not in pic or not os.path.isfile(src):
            continue
        print(pic)
        # Read the RGB label. The context manager closes the file handle;
        # np.array copies the pixels so they stay valid after closing.
        # NOTE: tiff.imread(src) is an alternative reader -- the original
        # author reported it worked on Windows but raised an error on Ubuntu.
        with Image.open(src) as img:
            label = np.array(img)
        label = get_label_from_palette(label, palette_file)
        write_tiff(os.path.join(path, pic[:-9] + 'new-L.tif'), label)
if __name__ == '__main__':
    # Convert the label images of both dataset splits: training first, then
    # validation.
    for split_dir in ('dataset/train', 'dataset/val'):
        main(split_dir)