-
Notifications
You must be signed in to change notification settings - Fork 18
/
dataset.py
113 lines (83 loc) · 3.95 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import tensorflow as tf
from constants import HEIGHT, WIDTH, DEPTH, NUM_CLASSES, NUM_DATA_FILES, NUM_IMAGES
def record_dataset(filenames):
"""Returns an input pipeline Dataset from `filenames`."""
record_bytes = HEIGHT * WIDTH * DEPTH + 1
return tf.data.FixedLengthRecordDataset(filenames, record_bytes)
def get_filenames(is_training, data_dir):
"""Returns a list of filenames."""
data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')
assert os.path.exists(data_dir), (
'Run cifar10_download_and_extract.py first to download and extract the '
'CIFAR-10 data.')
if is_training:
return [
os.path.join(data_dir, 'data_batch_%d.bin' % i)
for i in range(1, NUM_DATA_FILES + 1)
]
else:
return [os.path.join(data_dir, 'test_batch.bin')]
def parse_record(raw_record):
"""Parse CIFAR-10 image and label from a raw record."""
# Every record consists of a label followed by the image, with a fixed number
# of bytes for each.
label_bytes = 1
image_bytes = HEIGHT * WIDTH * DEPTH
record_bytes = label_bytes + image_bytes
# Convert bytes to a vector of uint8 that is record_bytes long.
record_vector = tf.decode_raw(raw_record, tf.uint8)
# The first byte represents the label, which we convert from uint8 to int32
# and then to one-hot.
label = tf.cast(record_vector[0], tf.int32)
label = tf.one_hot(label, NUM_CLASSES)
# The remaining bytes after the label represent the image, which we reshape
# from [depth * height * width] to [depth, height, width].
depth_major = tf.reshape(
record_vector[label_bytes:record_bytes], [DEPTH, HEIGHT, WIDTH])
# Convert from [depth, height, width] to [height, width, depth], and cast as
# float32.
image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)
return image, label
def preprocess_image(image, is_training):
"""Preprocess a single image of layout [height, width, depth]."""
if is_training:
# Resize the image to add four extra pixels on each side.
image = tf.image.resize_image_with_crop_or_pad(
image, HEIGHT + 8, WIDTH + 8)
# Randomly crop a [_HEIGHT, _WIDTH] section of the image.
image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])
# Randomly flip the image horizontally.
image = tf.image.random_flip_left_right(image)
# Subtract off the mean and divide by the variance of the pixels.
image = tf.image.per_image_standardization(image)
return image
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
"""Input_fn using the tf.data input pipeline for CIFAR-10 dataset.
Args:
is_training: A boolean denoting whether the input is for training.
data_dir: The directory containing the input data.
batch_size: The number of samples per batch.
num_epochs: The number of epochs to repeat the dataset.
Returns:
A tuple of images and labels.
"""
dataset = record_dataset(get_filenames(is_training, data_dir))
if is_training:
# When choosing shuffle buffer sizes, larger sizes result in better
# randomness, while smaller sizes have better performance. Because CIFAR-10
# is a relatively small dataset, we choose to shuffle the full epoch.
dataset = dataset.shuffle(buffer_size=NUM_IMAGES['train'])
dataset = dataset.map(parse_record)
dataset = dataset.map(
lambda image, label: (preprocess_image(image, is_training), label), num_parallel_calls=4)
dataset = dataset.prefetch(2 * batch_size)
# We call repeat after shuffling, rather than before, to prevent separate
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
# Batch results by up to batch_size, and then fetch the tuple from the
# iterator.
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()
return images, labels