-
Notifications
You must be signed in to change notification settings - Fork 38
/
sample.py
29 lines (26 loc) · 998 Bytes
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import tensorflow as tf
import numpy as np
import model as m
from utils import *
from config import *
# Sampling script: restore the latest checkpoint saved for `args.mode`,
# generate strokes with the model, and render them to 'sample.normal.svg'.

# Only these two modes have a sampling path below. Reject anything else up
# front instead of failing later with a NameError on `strokes`.
if args.mode not in ('predict', 'synthesis'):
    raise ValueError(
        "unsupported args.mode %r; expected 'predict' or 'synthesis'" % (args.mode,))

# The loader is built with the configured batch size and T *before* they are
# overridden below, so the statement order here matters.
data_loader = DataLoader(args.batch_size, args.T, args.data_scale,
                         chars=args.chars, points_per_char=args.points_per_char)

# Text consumed by the 'synthesis' branch. Named `text` (not `str`) so the
# builtin `str` is not shadowed for the rest of the module.
# Alternative test input: 'aaaaabbbbbccccc'
text = 'a quick brown fox jumps over the lazy dog'

# Reconfigure `args` for sampling before the model is constructed from it.
args.U = len(text)
args.c_dimension = len(data_loader.chars) + 1
args.T = 1
args.batch_size = 1
args.action = 'sample'

model = m.Model(args)
saver = tf.train.Saver()

checkpoint_dir = 'save_%s' % args.mode
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
# get_checkpoint_state returns None when the directory holds no checkpoint
# state; fail with a clear message rather than an AttributeError on None.
if ckpt is None or not ckpt.model_checkpoint_path:
    raise IOError('no checkpoint found in %r' % (checkpoint_dir,))

with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)
    if args.mode == 'predict':
        strokes = model.sample(sess, 800)
    else:
        # 'synthesis' (guaranteed by the mode check at the top): condition
        # the sampler on the vectorized text.
        str_vec = vectorization(text, data_loader.char_to_indices)
        # `str=` is the keyword name expected by model.sample; it is not the
        # builtin being shadowed here.
        strokes = model.sample(sess, len(text) * args.points_per_char, str=str_vec)
    draw_strokes_random_color(strokes, factor=0.1, svg_filename='sample.normal.svg')