-
Notifications
You must be signed in to change notification settings - Fork 231
/
setup_mnist.py
94 lines (77 loc) · 3.13 KB
/
setup_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
## setup_mnist.py -- mnist data and model loading code
##
## Copyright (C) 2016, Nicholas Carlini <nicholas@carlini.com>.
##
## This program is licenced under the BSD 2-Clause licence,
## contained in the LICENCE file in this directory.
import tensorflow as tf
import numpy as np
import os
import pickle
import gzip
import urllib.request
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.utils import np_utils
from keras.models import load_model
def extract_data(filename, num_images, image_size=28):
    """Read square grayscale images from a gzipped IDX image file.

    The 16-byte file header is skipped (it is not parsed or validated); the
    next ``num_images * image_size * image_size`` bytes are read as uint8
    pixel values.

    Args:
        filename: Path to the gzipped IDX image file.
        num_images: Number of images to read from the start of the file.
        image_size: Side length, in pixels, of each square image (28 for MNIST).

    Returns:
        float32 array of shape ``(num_images, image_size, image_size, 1)`` with
        pixel values rescaled from [0, 255] to [-0.5, 0.5].

    Raises:
        ValueError: if the file contains fewer pixel bytes than requested.
    """
    n_bytes = num_images * image_size * image_size
    with gzip.open(filename) as bytestream:
        bytestream.read(16)  # skip the fixed-size header
        buf = bytestream.read(n_bytes)
    # Fail with a clear message instead of a cryptic reshape error when the
    # file is truncated or num_images is larger than what the file holds.
    if len(buf) < n_bytes:
        raise ValueError(
            "%s: expected %d bytes of pixel data for %d images, got %d"
            % (filename, n_bytes, num_images, len(buf))
        )
    data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
    data = (data / 255) - 0.5  # center pixel values around zero
    return data.reshape(num_images, image_size, image_size, 1)
def extract_labels(filename, num_images, num_classes=10):
    """Read integer class labels from a gzipped IDX label file as one-hot rows.

    The 8-byte file header is skipped (it is not parsed or validated); the next
    ``num_images`` bytes are read as uint8 class indices.

    Args:
        filename: Path to the gzipped IDX label file.
        num_images: Number of labels to read from the start of the file.
        num_classes: Width of the one-hot encoding (10 for MNIST digits).

    Returns:
        float32 array of shape ``(num_images, num_classes)`` where row ``i`` has
        a single 1.0 at column ``label[i]`` and 0.0 elsewhere.

    Raises:
        ValueError: if the file holds fewer labels than requested, or a label
            is not a valid class index in ``[0, num_classes)``.
    """
    with gzip.open(filename) as bytestream:
        bytestream.read(8)  # skip the fixed-size header
        buf = bytestream.read(num_images)
    # Without this check a truncated file silently yields fewer label rows
    # than images, misaligning labels with their images downstream.
    if len(buf) < num_images:
        raise ValueError(
            "%s: expected %d labels, got %d" % (filename, num_images, len(buf))
        )
    labels = np.frombuffer(buf, dtype=np.uint8)
    # An out-of-range label would otherwise produce an all-zero one-hot row.
    if labels.size and int(labels.max()) >= num_classes:
        raise ValueError(
            "%s: label %d out of range for %d classes"
            % (filename, int(labels.max()), num_classes)
        )
    # Broadcast compare (num_classes,) against (num_images, 1) -> one-hot matrix.
    return (np.arange(num_classes) == labels[:, None]).astype(np.float32)
class MNIST:
    """Download (when missing) and load the MNIST dataset into memory.

    The constructor fetches the four gzipped MNIST archives into ``data_dir``
    (skipping any file already present) and populates:

    - ``train_data`` / ``train_labels``: training-file examples after the
      first 5000 (55000 examples).
    - ``validation_data`` / ``validation_labels``: the first 5000
      training-file examples, held out for validation.
    - ``test_data`` / ``test_labels``: the 10000 test-file examples.

    Image and label arrays are whatever ``extract_data`` / ``extract_labels``
    return for these files.
    """

    def __init__(self, data_dir="data", base_url="http://yann.lecun.com/exdb/mnist/"):
        """Fetch missing archives and split the data.

        Args:
            data_dir: Directory where the archives are stored/cached.
            base_url: URL prefix each archive name is appended to. Override it
                to download from a mirror.
        """
        os.makedirs(data_dir, exist_ok=True)

        files = ["train-images-idx3-ubyte.gz",
                 "t10k-images-idx3-ubyte.gz",
                 "train-labels-idx1-ubyte.gz",
                 "t10k-labels-idx1-ubyte.gz"]
        for name in files:
            path = os.path.join(data_dir, name)
            if os.path.exists(path):
                continue  # already cached; don't re-download on every run
            # Download to a temporary name first so an interrupted transfer
            # never leaves a partial file that would later be treated as cached.
            partial = path + ".part"
            urllib.request.urlretrieve(base_url + name, partial)
            os.replace(partial, path)

        train_data = extract_data(os.path.join(data_dir, "train-images-idx3-ubyte.gz"), 60000)
        train_labels = extract_labels(os.path.join(data_dir, "train-labels-idx1-ubyte.gz"), 60000)
        self.test_data = extract_data(os.path.join(data_dir, "t10k-images-idx3-ubyte.gz"), 10000)
        self.test_labels = extract_labels(os.path.join(data_dir, "t10k-labels-idx1-ubyte.gz"), 10000)

        # Hold out the first VALIDATION_SIZE training examples for validation.
        VALIDATION_SIZE = 5000

        self.validation_data = train_data[:VALIDATION_SIZE, :, :, :]
        self.validation_labels = train_labels[:VALIDATION_SIZE]
        self.train_data = train_data[VALIDATION_SIZE:, :, :, :]
        self.train_labels = train_labels[VALIDATION_SIZE:]
class MNISTModel:
    """Keras convolutional classifier for MNIST images, restored from disk.

    Architecture: two 32-filter 3x3 conv layers, 2x2 max-pool, two 64-filter
    3x3 conv layers, 2x2 max-pool, then 200-200-10 dense layers with ReLU
    between them. The final ``Dense`` layer has no activation, so the model's
    outputs are unnormalized scores (logits), not probabilities.

    Args:
        restore: Path passed to ``load_weights`` to restore trained weights.
        session: Accepted for interface compatibility; not used here.
    """

    def __init__(self, restore, session=None):
        self.num_channels = 1
        self.image_size = 28
        self.num_labels = 10

        in_shape = (self.image_size, self.image_size, self.num_channels)
        # Layers in the exact order they are stacked into the network.
        layer_stack = [
            Conv2D(32, (3, 3), input_shape=in_shape),
            Activation('relu'),
            Conv2D(32, (3, 3)),
            Activation('relu'),
            MaxPooling2D(pool_size=(2, 2)),

            Conv2D(64, (3, 3)),
            Activation('relu'),
            Conv2D(64, (3, 3)),
            Activation('relu'),
            MaxPooling2D(pool_size=(2, 2)),

            Flatten(),
            Dense(200),
            Activation('relu'),
            Dense(200),
            Activation('relu'),
            Dense(self.num_labels),  # logits: no softmax
        ]

        net = Sequential()
        for layer in layer_stack:
            net.add(layer)
        net.load_weights(restore)
        self.model = net

    def predict(self, data):
        """Run ``data`` through the network and return its raw (logit) outputs."""
        return self.model(data)