-
Notifications
You must be signed in to change notification settings - Fork 11
/
dataset.py
154 lines (119 loc) · 5.33 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2017 Hiroaki Santo
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import os
import cv2
import numpy as np
import tqdm
class DpsnDataset(object):
    """Dataset loader for deep photometric stereo (DPSN) training data.

    Expected on-disk layout (inferred from the globs below — confirm):
        dataset_path/<obj_name>/<brdf_name>/<light_index>.png   # one image per light
        dataset_path/<obj_name>/<obj_name>.png                  # ground-truth normal map

    All measurement directories are assumed to share the same number of
    lights and the same image size as the first one found.
    """

    def __init__(self, dataset_path="./dataset/", name="blobs_merl"):
        # Root directory of the dataset and a label used in the tfrecord filename.
        self.dataset_path = dataset_path
        # Each entry is one "<obj>/<brdf>/" measurement directory.
        self.data_list = glob.glob(os.path.join(self.dataset_path, "*/", "*/"))
        self.name = name
        # Probe the first measurement directory to infer the light count and
        # image size; "[0-9]*.png" matches the per-light images only.
        tmp_list = glob.glob(os.path.join(self.data_list[0], "[0-9]*.png"))
        self.light_num = len(tmp_list)
        # Shape of one color channel: (rows, cols).
        self.img_size = cv2.imread(tmp_list[0])[:, :, 0].shape
        print("light_num: {}".format(self.light_num))
        print("image_size: {}".format(self.img_size))
        print("data_num: {}".format(len(self)))
        # Fixed seed so the shuffled traversal order is reproducible across runs.
        np.random.seed(1)
        self.random_indices = np.random.permutation(len(self))

    def data_path2name(self, path):
        """Split a measurement path into ``(obj_name, brdf_name)``.

        ``path`` is expected to end ".../<obj_name>/<brdf_name>/<last>";
        the trailing component is discarded by the first split.
        """
        dir_path, _ = os.path.split(path)
        dir_path, brdf_name = os.path.split(dir_path)
        _, obj_name = os.path.split(dir_path)
        return obj_name, brdf_name

    def __load_normal_png(self, n_path):
        """Load a normal-map PNG and return ``(N, m, n, mask)``.

        Returns:
            N:    (3, m*n) float32 unit normals; columns for masked-out
                  pixels are zeroed.
            m, n: image height and width.
            mask: (m*n,) array, 0 where the source pixel is pure black
                  (treated as background), 1 elsewhere.
        """
        # OpenCV loads BGR; [:, :, ::-1] flips to RGB so channels map to (x, y, z).
        n_img = cv2.imread(n_path)[:, :, ::-1]
        m, n, _ = n_img.shape
        N = n_img.reshape(-1, 3).T
        # Map 8-bit values [0, 255] to normal components in [-1, 1].
        N = N.astype(np.float32) / 255. * 2. - 1.
        # Re-normalize each pixel's normal to unit length (quantization drift).
        for i in range(m * n):
            norm = np.linalg.norm(N[:, i])
            if norm != 0:
                N[:, i] /= norm
        # Build the foreground mask from the RAW pixel values (before the
        # [-1, 1] mapping): pure-black pixels encode "no normal here".
        mask = np.ones(shape=(m * n))
        n_img = n_img.reshape(-1, 3).T
        for i in range(m * n):
            if np.linalg.norm(n_img[:, i]) == 0:
                mask[i] = 0
        N[:, mask == 0] = 0
        return N, m, n, mask

    def __len__(self):
        # Number of measurement directories (one per object/BRDF pair).
        return len(self.data_list)

    def load_data(self, root_path):
        """Load one measurement directory.

        Returns:
            M:    (m*n, light_num, 3) float32 raw RGB observations.
            N:    (3, m*n) unit surface normals for the object.
            mask: (m*n,) foreground mask.
        """
        def_png_path = os.path.join(root_path, "{light_index}.png")
        m, n = self.img_size
        M = np.zeros(shape=(m * n, self.light_num, 3), dtype=np.float32)
        for l in range(self.light_num):
            # IMREAD_UNCHANGED keeps the file's native bit depth; the slice
            # flips BGR -> RGB to match the normal-map channel order.
            m_img = cv2.imread(def_png_path.format(light_index=l), cv2.IMREAD_UNCHANGED)[:, :, ::-1]
            # m_img = cv2.imread(def_png_path.format(light_index=l))[:, :, ::-1]
            M[:, l, :] = m_img.reshape(-1, 3)
        # Trailing "/" so data_path2name's first split strips an empty component.
        obj_name, brdf_name = self.data_path2name(root_path + "/")
        # The ground-truth normal map is shared per object, stored beside the BRDF dirs.
        N, m, n, mask = self.__load_normal_png(os.path.join(self.dataset_path, obj_name, "{}.png".format(obj_name)))
        return M, N, mask

    def get_batch(self, index, image_num):
        """Assemble per-pixel training pairs from ``image_num`` images.

        :param index: start offset into the (shuffled) dataset order.
        :param image_num: This does not mean the number of batch data. Number of pixels in image_num images becomes number of data.
        :return: ``(normals, observations)`` as float32 arrays of shape
                 (P, 3) and (P, light_num); each foreground pixel
                 contributes one row per color channel.
        """
        # Wrap around the end of the dataset, then map through the fixed shuffle.
        indices = np.arange(index, index + image_num)
        indices %= len(self)
        indices = self.random_indices[indices]
        batch_normal = []
        batch_mess = []
        for i in indices:
            M, N, mask = self.load_data(self.data_list[i])
            # Normalize observations to [0, 1] by the per-directory maximum.
            M = M.astype(float) / np.max(M)
            for p in range(N.shape[1]):
                if mask[p] == 0:
                    continue
                # Skip pixels where some channel is zero under every light.
                # NOTE(review): axis=0 norms over the light axis per color
                # channel — confirm axis=1 (per-light norm) was not intended.
                if np.min(np.linalg.norm(M[p, :, :], axis=0)) == 0:
                    continue
                # One training sample per color channel, all sharing the normal.
                for color in range(3):
                    batch_normal.append(N[:, p])
                    batch_mess.append(M[p, :, color])
        return np.array(batch_normal, dtype=np.float32), np.array(batch_mess, dtype=np.float32)

    def save_as_tfrecord(self):
        """Serialize the whole dataset to one TFRecord of (normal, mess) pairs.

        Uses the TF1-era ``tf.python_io.TFRecordWriter`` API (removed in TF2
        — would need ``tf.io`` there). Samples within each 30-image chunk are
        written in random order.
        """
        print("[*] save_as_tfrecord()")
        import tensorflow as tf
        tfwriter = tf.python_io.TFRecordWriter(
            os.path.join(self.dataset_path, "{}_{}.tfrecord".format(type(self).__name__, self.name)))
        try:
            # Process 30 images at a time to bound memory use.
            for i in tqdm.tqdm(range(0, len(self), 30)):
                normal, mess = self.get_batch(index=i, image_num=30)
                print("serialize data: {}, {}".format(normal.shape, mess.shape))
                for j in np.random.permutation(len(normal)):
                    n_ = normal[j, :].astype(np.float32)
                    m_ = mess[j, :].astype(np.float32)
                    record = tf.train.Example(features=tf.train.Features(feature={
                        'normal': tf.train.Feature(float_list=tf.train.FloatList(value=n_.reshape(-1).tolist())),
                        'mess': tf.train.Feature(
                            float_list=tf.train.FloatList(value=m_.reshape(-1).tolist())),
                    }))
                    tfwriter.write(record.SerializeToString())
        finally:
            tfwriter.close()

    def load_from_tfrecord(self):
        """Build a TF1 input pipeline reading the tfrecord written above.

        Returns ``(normal, mess)`` tensors of shape (3,) and (light_num,).
        Uses queue-based input (``tf.train.string_input_producer``), which is
        deprecated in favor of ``tf.data`` — kept as-is for TF1 compatibility.
        """
        import tensorflow as tf
        data_path = os.path.join(self.dataset_path, "{}_{}.tfrecord".format(type(self).__name__, self.name))
        assert os.path.exists(data_path), data_path
        filename_queue = tf.train.string_input_producer([data_path], num_epochs=None)
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            features={
                'normal': tf.FixedLenFeature([3], tf.float32),
                'mess': tf.FixedLenFeature([self.light_num], tf.float32),
            })
        return features["normal"], features["mess"]
if __name__ == '__main__':
    # Build the training dataset and serialize it to a TFRecord file.
    import params

    train_dir = os.path.join(params.DATASET_PATH, "train")
    DpsnDataset(dataset_path=train_dir).save_as_tfrecord()