-
Notifications
You must be signed in to change notification settings - Fork 2
/
mnist_export.py
114 lines (98 loc) · 4.39 KB
/
mnist_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Export given TensorFlow model.
The model is a pretrained MNIST classifier that was saved as a TensorFlow
checkpoint. This program uses the TensorFlow SavedModel API to export the
trained model with proper signatures so that it can be loaded by the standard
tensorflow_model_server.
Usage: mnist_export.py [--model_version=y] [--checkpoint_path=checkpoint_path_local/oss] export_dir
"""
import os
import sys
import tensorflow as tf
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
from tensorflow.python.util import compat
from tensorflow.examples.tutorials.mnist import input_data as mnist_input_data
# Command-line flags. Arguments are passed positionally as
# (flag name, default value, help text).

# Version number of the export; main() joins it onto export_dir as a
# sub-directory name.
tf.app.flags.DEFINE_integer(
    'model_version',
    1,
    'version number of the exported model.',
)

# Checkpoint prefix to restore; main() appends '.meta' to it to locate the
# meta graph file. No default, so main() rejects a missing value.
tf.app.flags.DEFINE_string(
    'checkpoint_path',
    None,
    'Checkpoints path.',
)

# Handle through which the flag values defined above are read.
FLAGS = tf.app.flags.FLAGS
def main(_):
    """Restore a checkpointed MNIST graph and export it as a SavedModel.

    The export directory is taken from the last command-line argument
    (``sys.argv[-1]``); the model version and checkpoint location come from
    ``FLAGS.model_version`` and ``FLAGS.checkpoint_path``.

    The restored graph is expected to contain tensors named ``x:0``, ``y:0``,
    ``tf_example:0`` and ``Variable:0``, plus non-empty ``values`` and
    ``prediction_classes`` collections; these are looked up by name below.

    Two signatures are exported under the SERVING tag:
      * ``predict_images``: ``images`` (x) -> ``scores`` (y)
      * the default serving signature: classification from the serialized
        tf.Example input to predicted classes and scores.

    Args:
        _: Unused; present because ``tf.app.run`` calls ``main`` with one
            positional argument.

    Raises:
        SystemExit: with code -1 when the export directory argument is
            missing, ``model_version`` is not positive, or
            ``checkpoint_path`` is not set.
    """
    # All messages use the single-argument print(...) form so the module
    # parses on Python 3 while printing the same text on Python 2.
    if len(sys.argv) < 2 or sys.argv[-1].startswith('-'):
        print('Usage: mnist_export.py '
              '[--model_version=y] [--checkpoint_path=checkpoint_path_local/oss] export_dir')
        sys.exit(-1)
    if FLAGS.model_version <= 0:
        print('Please specify a positive value for exported serveable version number.')
        sys.exit(-1)
    if not FLAGS.checkpoint_path:
        print('Please specify the correct path where checkpoints stored locally or in OSS.')
        sys.exit(-1)

    # The meta graph is read from '<checkpoint_path>.meta',
    # e.g. '/test/mnistoutput/ckpt' -> '/test/mnistoutput/ckpt.meta'.
    default_meta_graph_suffix = '.meta'
    ckpt_path = FLAGS.checkpoint_path
    meta_graph_file = ckpt_path + default_meta_graph_suffix

    with tf.Session() as new_sess:
        # Rebuild the graph from the meta file, then load the variable values.
        new_saver = tf.train.import_meta_graph(meta_graph_file, clear_devices=True)
        new_saver.restore(new_sess, ckpt_path)

        # Look up the tensors needed for the signatures; each is printed as
        # a quick sanity check that the lookup succeeded.
        new_values = tf.get_collection('values')[0]
        print(new_values)
        new_graph = tf.get_default_graph()
        new_x = new_graph.get_tensor_by_name('x:0')
        print(new_x)
        new_y = new_graph.get_tensor_by_name('y:0')
        print(new_y)
        new_serialized_tf_example = new_graph.get_tensor_by_name('tf_example:0')
        print(new_serialized_tf_example)
        new_prediction_classes = tf.get_collection('prediction_classes')[0]
        print(new_prediction_classes)
        var = new_graph.get_tensor_by_name("Variable:0")
        print(var)

        # Export model
        # WARNING(break-tutorial-inline-code): The following code snippet is
        # in-lined in tutorials, please update tutorial documents accordingly
        # whenever code changes.
        export_path_base = sys.argv[-1]
        export_path = os.path.join(
            compat.as_bytes(export_path_base),
            compat.as_bytes(str(FLAGS.model_version)))
        # export_path is bytes; convert to text so the message does not show
        # a b'...' repr on Python 3.
        print('Exporting trained model to %s' % compat.as_text(export_path))
        builder = saved_model_builder.SavedModelBuilder(export_path)

        # Build the signature_def_map.
        classification_inputs = utils.build_tensor_info(new_serialized_tf_example)
        classification_outputs_classes = utils.build_tensor_info(new_prediction_classes)
        classification_outputs_scores = utils.build_tensor_info(new_values)

        classification_signature = signature_def_utils.build_signature_def(
            inputs={signature_constants.CLASSIFY_INPUTS: classification_inputs},
            outputs={
                signature_constants.CLASSIFY_OUTPUT_CLASSES:
                    classification_outputs_classes,
                signature_constants.CLASSIFY_OUTPUT_SCORES:
                    classification_outputs_scores
            },
            method_name=signature_constants.CLASSIFY_METHOD_NAME)

        tensor_info_x = utils.build_tensor_info(new_x)
        tensor_info_y = utils.build_tensor_info(new_y)

        prediction_signature = signature_def_utils.build_signature_def(
            inputs={'images': tensor_info_x},
            outputs={'scores': tensor_info_y},
            method_name=signature_constants.PREDICT_METHOD_NAME)

        # NOTE(review): tf.initialize_all_tables appears to be deprecated in
        # later TF 1.x releases in favour of tf.tables_initializer -- confirm
        # against the TensorFlow version this script targets before changing.
        legacy_init_op = tf.group(tf.initialize_all_tables(), name='legacy_init_op')
        builder.add_meta_graph_and_variables(
            new_sess, [tag_constants.SERVING],
            signature_def_map={
                'predict_images':
                    prediction_signature,
                signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    classification_signature,
            },
            legacy_init_op=legacy_init_op,
            clear_devices=True
        )

        builder.save()

    print('Done exporting!')
if __name__ == "__main__":
    # Script entry point: hand control to tf.app.run, naming main()
    # explicitly as the function to call.
    tf.app.run(main=main)