forked from sugyan/face-generator
-
Notifications
You must be signed in to change notification settings - Fork 0
/
web.py
130 lines (106 loc) · 3.92 KB
/
web.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import base64
import io
import os
import urllib.request
from flask import Flask, render_template, jsonify, request
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.dialects.postgresql import JSON
import numpy as np
import tensorflow as tf
from dcgan import DCGAN
# Command-line flags for the web app.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('port', 5000,
                            """Application port.""")
# NOTE: this is a checkpoint *file* path (it is tested with os.path.isfile and
# passed to Saver.restore), not a directory; the help text says so.
tf.app.flags.DEFINE_string('checkpoint_path', '/tmp/g.ckpt',
                           """Path of the model checkpoint file to restore.""")
# Download the checkpoint when it is not already present on disk, using the
# URL in the CHECKPOINT_DOWNLOAD_URL environment variable (KeyError if unset).
# The data is fetched into a sibling temporary file and only moved to
# FLAGS.checkpoint_path once urlretrieve has returned successfully. An
# interrupted or short download (urlretrieve leaves the partial data on disk)
# therefore never leaves a truncated checkpoint that a later start would
# accept because os.path.isfile() is true.
if not os.path.isfile(FLAGS.checkpoint_path):
    print('No checkpoint file found')
    _partial_checkpoint_path = FLAGS.checkpoint_path + '.download'
    try:
        urllib.request.urlretrieve(os.environ['CHECKPOINT_DOWNLOAD_URL'],
                                   _partial_checkpoint_path)
        # Atomic rename on the same filesystem.
        os.replace(_partial_checkpoint_path, FLAGS.checkpoint_path)
    finally:
        # Only still present if the download or the rename failed.
        if os.path.exists(_partial_checkpoint_path):
            os.remove(_partial_checkpoint_path)
def get_dcgan(batch_size):
    """Construct the app's DCGAN model for the given batch size.

    Both graphs in this module (moment collection and single-image serving)
    are built through this function and restored from the same checkpoint,
    so the fixed hyper-parameters below must match that checkpoint.

    Args:
        batch_size: number of samples per batch in the built graph.

    Returns:
        A new ``dcgan.DCGAN`` instance.
    """
    architecture = dict(
        f_size=6,
        z_dim=16,
        # generator depths
        gdepth1=216, gdepth2=144, gdepth3=96, gdepth4=64,
        # All d* depths are 0 -- presumably because only the generator is
        # used in this module (dcgan.g / sample_images); confirm in dcgan.py.
        ddepth1=0, ddepth2=0, ddepth3=0, ddepth4=0,
    )
    return DCGAN(batch_size=batch_size, **architecture)
def get_moments():
    """Evaluate the generator's batch-normalization moments once.

    Builds the generator in a fresh, private graph with a batch size of 256
    and restores its trainable variables (scope ``'g'``) from
    ``FLAGS.checkpoint_path``. It then evaluates the outputs of every op
    whose name ends with ``normalize/mean`` or ``normalize/variance``.

    Returns:
        dict mapping each evaluated tensor's name (``<op name>:<index>``) to
        the value ``sess.run`` produced for it.
    """
    moment_suffixes = ('normalize/mean', 'normalize/variance')
    # Separate graph with a 256-sample batch (the old comment said 128, which
    # did not match the code).
    with tf.Graph().as_default() as graph:
        with tf.Session() as session:
            model = get_dcgan(256)
            model.g(model.z)
            variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='g')
            saver = tf.train.Saver(variables)
            saver.restore(session, FLAGS.checkpoint_path)
            # Every output of each mean/variance op, in graph-op order.
            outputs = [
                tensor
                for op in graph.get_operations()
                if op.name.endswith(moment_suffixes)
                for tensor in op.outputs
            ]
            values = session.run(outputs)
            return {tensor.name: value for tensor, value in zip(outputs, values)}
# Batch-norm moments {tensor name: value}, computed once at import time in a
# separate 256-batch graph by get_moments().
moments = get_moments()
# Long-lived session on the default graph; the request handlers below run
# generate_image through it.
sess = tf.Session()
# Single-image generator: batch size 1, fed through a placeholder of shape
# (batch_size, z_dim).
dcgan = get_dcgan(1)
inputs = tf.placeholder(tf.float32, (dcgan.batch_size, dcgan.z_dim))
# NOTE(review): presumably yields PNG-encoded bytes for a 1x1 grid (the
# /api/generate handler base64-encodes the result as image/png) -- confirm in
# dcgan.py.
generate_image = dcgan.sample_images(1, 1, inputs)
# Restore the generator's trainable variables (scope 'g') from the checkpoint.
# This must run after the graph above is built so the collection is populated.
variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='g')
saver = tf.train.Saver(variables)
saver.restore(sess, FLAGS.checkpoint_path)
# Flask setup
app = Flask(__name__)
# Enable debug mode (interactive debugger and reloader) only when DEBUG is set
# in the environment. processor() uses the same switch for its dev-server
# URLs. Unconditional debug mode combined with app.run(host='0.0.0.0') would
# expose the Werkzeug debugger, which can execute arbitrary code, to the
# network.
app.debug = 'DEBUG' in os.environ
app.config['SQLALCHEMY_DATABASE_URI'] = os.environ['DATABASE_URL']
# Nothing in this module consumes Flask-SQLAlchemy's modification signals, so
# skip the per-object tracking overhead.
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
# Tensor -> value feeds merged into every /api/generate request; populated
# from the precomputed moments after the model is built.
app.config['DEFAULT_FEED_DICT'] = {}
db = SQLAlchemy(app)
class Offsets(db.Model):
    """A named row whose ``values`` column holds arbitrary JSON.

    Rows are listed, ordered by ``id``, by the ``/api/offsets`` endpoint.
    NOTE(review): the JSON layout of ``values`` is not defined in this file --
    presumably per-dimension latent offsets; confirm with the front end.
    """
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String())
    values = db.Column(JSON)

    def __init__(self, name, values):
        self.name, self.values = name, values

    def __repr__(self):
        return '<Offsets %r>' % self.name

    def serialize(self):
        """Return this row as a plain dict with ``id``, ``name`` and ``values``."""
        return dict(id=self.id, name=self.name, values=self.values)
# Key each serving-graph tensor whose name matches a precomputed moment to
# that moment's value, so /api/generate feeds these values on every request.
_default_feed = app.config['DEFAULT_FEED_DICT']
_default_feed.update(
    (tensor, moments[tensor.name])
    for graph_op in sess.graph.get_operations()
    for tensor in graph_op.outputs
    if tensor.name in moments
)
@app.route('/api/offsets')
def offsets():
    """Return all Offsets rows, ordered by id, as ``{"offsets": [...]}``."""
    rows = Offsets.query.order_by(Offsets.id).all()
    serialized = [row.serialize() for row in rows]
    return jsonify(offsets=serialized)
@app.route('/api/generate', methods=['POST'])
def image():
    """Generate one image from a latent vector posted as JSON.

    Request body: a JSON array of ``dcgan.z_dim`` numbers.

    Returns:
        ``{"result": "data:image/png;base64,..."}`` on success, or HTTP 400
        with ``{"error": ...}`` when the body is missing, is not JSON, or is
        not an array of ``dcgan.z_dim`` numbers.
    """
    payload = request.get_json(silent=True)
    try:
        # Wrap the single vector into the placeholder's (batch_size, z_dim) layout.
        z = np.asarray([payload], dtype=np.float32)
    except (TypeError, ValueError, OverflowError):
        z = None
    # Reject malformed input here instead of letting sess.run fail with a 500
    # (which shows the debugger page when debug mode is on).
    if z is None or z.shape != (dcgan.batch_size, dcgan.z_dim):
        return jsonify(error='expected a JSON array of %d numbers' % dcgan.z_dim), 400
    feed_dict = {inputs: z}
    # Feed the precomputed batch-norm moments in place of the values the
    # graph would otherwise compute.
    feed_dict.update(app.config['DEFAULT_FEED_DICT'])
    result = sess.run(generate_image, feed_dict=feed_dict)
    return jsonify(result='data:image/png;base64,' + base64.b64encode(result).decode())
@app.route('/')
def root():
    """Render the ``index.html`` template."""
    template_name = 'index.html'
    return render_template(template_name)
@app.route('/<foo>')
def index(foo):
    """Render the same ``index.html`` page as ``/`` for any single-segment path.

    ``foo`` (the path segment) is ignored. NOTE(review): presumably the front
    end does client-side routing on it -- confirm.
    """
    return render_template('index.html')
@app.context_processor
def processor():
    """Expose a ``js(filename)`` URL helper to templates.

    When ``DEBUG`` is present in the environment (checked on every call),
    scripts are loaded from ``http://localhost:8080/static/js/``. Otherwise
    they are loaded from ``/static/js/``.
    """
    def javascript(filename):
        base_url = ('http://localhost:8080/static/js/' if 'DEBUG' in os.environ
                    else '/static/js/')
        return base_url + filename
    return {'js': javascript}
if __name__ == '__main__':
    # Bind to all interfaces on the port given by the --port flag.
    app.run(port=FLAGS.port, host='0.0.0.0')