-
Notifications
You must be signed in to change notification settings - Fork 1
/
lsun.py
102 lines (73 loc) · 2.75 KB
/
lsun.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
"""
import io
import os
import pickle
import StringIO

import lmdb
import numpy
import scipy.misc
class Lsun(object):
    """
    A utility to read LSUN data set.
    http://lsun.cs.princeton.edu/2016/

    Each instance serves shuffled mini-batches of decoded images from one
    LMDB database directory. The key order is reshuffled whenever the next
    batch would run past the end of the shuffled key list.
    """

    @staticmethod
    def load_keys(path_lsun_dir):
        """
        Load keys from a pickle file. All keys of the database will be dumped
        into a pickle file if there is none.

        Args:
            path_lsun_dir: directory of the LMDB database; the key cache is
                stored in that directory as 'keys.pkl'.

        Returns:
            A list with every unique key of the database, in cursor order
            (or whatever list an existing 'keys.pkl' contains).
        """
        path_keys = os.path.join(path_lsun_dir, 'keys.pkl')
        if os.path.isfile(path_keys):
            # Pickle data is bytes: binary mode avoids newline translation
            # on Windows and is mandatory on Python 3.
            with open(path_keys, 'rb') as kf:
                return pickle.Unpickler(kf).load()
        print('generating keys of lmdb: ' + path_lsun_dir)
        keys = []
        # Read-only, lock-free open: the database is never written here, the
        # dataset directory may be read-only, and a writable open would
        # silently create an empty database if the path were wrong.
        with lmdb.open(path_lsun_dir, readonly=True, lock=False) as env:
            keys_count = env.stat()['entries']
            with env.begin(write=False) as txn:
                with txn.cursor() as cursor:
                    for key in cursor.iternext_nodup(keys=True, values=False):
                        keys.append(key)
                        if len(keys) % 1000 == 0:
                            print('found keys: {} / {}'.format(
                                len(keys), keys_count))
        # Dump through a temporary file and rename, so an interrupted dump
        # never leaves a truncated keys.pkl that later calls would trust.
        path_tmp = path_keys + '.tmp'
        try:
            with open(path_tmp, 'wb') as kf:
                pickle.Pickler(kf).dump(keys)
            os.rename(path_tmp, path_keys)
        except (IOError, OSError) as err:
            # The cache is only an optimisation; e.g. a read-only dataset
            # directory must not prevent using the keys collected above.
            print('failed to cache keys in {}: {}'.format(path_keys, err))
        return keys

    def __init__(self, path_lsun_dir):
        """
        Args:
            path_lsun_dir: directory of the LMDB database to read from.
        """
        self._lmdb_path = path_lsun_dir
        self._lmdb_keys = Lsun.load_keys(path_lsun_dir)
        # Shuffled permutation of indices into self._lmdb_keys; batches are
        # consecutive slices of it starting at self._key_position.
        self._key_indice = numpy.arange(len(self._lmdb_keys))
        self._key_position = 0
        numpy.random.shuffle(self._key_indice)

    def next_batch(self, batch_size):
        """
        Get next batch_size images from the database.
        All images are resized to 25% (either 64x? or ?x64).
        All pixels are remapped to range between -1.0 ~ +1.0.

        When fewer than batch_size unread keys remain, the key order is
        reshuffled and the batch starts from the beginning again; the
        leftover tail keys are skipped for that pass.

        Args:
            batch_size: number of images to return, between 0 and the
                number of keys in the database (inclusive).

        Returns:
            A list of float32 numpy arrays, one per image. Image sizes may
            differ, so the arrays are not stacked.

        Raises:
            ValueError: if batch_size is negative or exceeds the number of
                keys. The reader state is left unchanged in that case.
        """
        num_keys = len(self._key_indice)
        # Validate before touching any state; an assert would be stripped
        # under `python -O`.
        if not 0 <= batch_size <= num_keys:
            raise ValueError('batch_size must be in [0, {}], got {}'.format(
                num_keys, batch_size))
        if batch_size == 0:
            return []
        begin = self._key_position
        end = begin + batch_size
        if end > num_keys:
            numpy.random.shuffle(self._key_indice)
            begin, end = 0, batch_size
        self._key_position = end
        images = []
        with lmdb.open(self._lmdb_path, readonly=True, lock=False) as env:
            with env.begin(write=False) as txn:
                for index in self._key_indice[begin:end]:
                    val = txn.get(self._lmdb_keys[index])
                    img = scipy.misc.imread(io.BytesIO(val))
                    # An int size means a percentage for imresize: 25 -> 25%.
                    img = scipy.misc.imresize(img, 25)
                    img = img.astype(numpy.float32)
                    # Assumes 8-bit pixels in [0, 255] -> [-1.0, +1.0].
                    img /= 127.5
                    img -= 1.0
                    images.append(img)
        return images