-
Notifications
You must be signed in to change notification settings - Fork 77
/
predict.py
executable file
·98 lines (88 loc) · 3.68 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from PIL import Image
import tensorflow as tf
import tflib
import tflib.ops
import tflib.network
from tqdm import tqdm
import numpy as np
import data_loaders
import time
import os
import json
import sys
import pyperclip
sys.path.append('./im2markup/scripts/utils')
from image_utils import *
import subprocess
import glob
# ---------------------------------------------------------------------------
# Hyper-parameters of the im2latex model assembled below.  The variable
# shapes they produce presumably have to agree with the saved weights in
# 'weights_best_numpy.npy' -- TODO confirm before changing any of them.
# ---------------------------------------------------------------------------
BATCH_SIZE = 2            # not referenced elsewhere in this script
EMB_DIM = 80              # passed to FreeRunIm2LatexAttention
ENC_DIM = 256             # passed to FreeRunIm2LatexAttention
DEC_DIM = 2 * ENC_DIM     # decoder width, kept at twice ENC_DIM
NUM_FEATS_START = 64      # passed to im2latex_cnn
D = 8 * NUM_FEATS_START   # passed to FreeRunIm2LatexAttention as D
V = 502                   # predictions argmax over the last V output channels
NB_EPOCHS = 50            # not referenced elsewhere in this script
H, W = 20, 50             # passed to FreeRunIm2LatexAttention as H and W
# --- Inference graph -------------------------------------------------------
# Statement order here is deliberate: TensorFlow derives variable names from
# the scopes active while these calls run, and those names are looked up as
# keys of the saved weight dictionary loaded below.
# Rank-4 float input image.  show() feeds a single luma image of shape
# (1, 1, height, width) -- presumably NCHW; TODO confirm against im2latex_cnn.
X = tf.placeholder(shape=(None,None,None,None),dtype=tf.float32)
# mask, seqs and learn_rate are never fed or read in this script; presumably
# carried over from the training code.
mask = tf.placeholder(shape=(None,None),dtype=tf.int32)
seqs = tf.placeholder(shape=(None,None),dtype=tf.int32)
learn_rate = tf.placeholder(tf.float32)
# CNN encoder over X.  The meaning of the trailing positional True is defined
# in tflib.network (not visible here) -- TODO document.
ctx = tflib.network.im2latex_cnn(X,NUM_FEATS_START,True)
# Attention decoder over the CNN features; `state` is not used below.
out,state = tflib.ops.FreeRunIm2LatexAttention('AttLSTM',ctx,EMB_DIM,V,ENC_DIM,DEC_DIM,D,H,W)
# Token ids: argmax over the last V entries of `out` along axis 2.  `out` is
# indexed as (at least) rank-3 -- presumably (batch, step, features).
predictions = tf.argmax(out[:,:,-V:],axis=2)
# Create the session and give every variable an initial value; the trained
# weights are then copied over them below.
sess = tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=8))
init = tf.global_variables_initializer()
sess.run(init)


def _load_pickled_npy(path):
    """Return the Python object (e.g. a dict) stored in the .npy file *path*.

    A dict saved with ``np.save`` is stored as a pickled object array, and
    since NumPy 1.16.3 ``np.load`` refuses to unpickle it unless
    ``allow_pickle=True`` is passed (it raises ``ValueError`` otherwise).
    Unpickling can execute arbitrary code: only load files you trust.
    """
    return np.load(path, allow_pickle=True).tolist()


# Graph variable name -> key under which its value is stored in the weights
# file.  In this graph these variables sit under the 'RNN/while/' scope
# (presumably because the free-running decoder creates them inside its loop
# -- TODO confirm), but the weights file stores them without that prefix.
_RENAMED_WEIGHTS = [
    ('RNN/while/Embedding/Embedding:0', 'Embedding/Embedding:0'),
    ('RNN/while/MLP.1/MLP.1.W:0', 'MLP.1/MLP.1.W:0'),
    ('RNN/while/MLP.1/MLP.1.b:0', 'MLP.1/MLP.1.b:0'),
]
_RENAMED_KEYS = set(name for name, _ in _RENAMED_WEIGHTS)

weights = _load_pickled_npy('weights_best_numpy.npy')
var_dict = {x.name: x for x in tf.get_collection('variables')}
for key, var in var_dict.items():
    if key in _RENAMED_KEYS:
        # Loaded from its saved name below; reporting it as missing here
        # would be a false warning.
        continue
    # Test membership on the dict itself: weights.keys() is a list on Python 2.
    if key not in weights:
        print("%s  not found!!" % key)
    else:
        sess.run(var.assign(weights[key]))
        print("Initialized  %s" % key)

# These three are required: a KeyError here means the weights file does not
# match the graph, which should stop the script rather than run untrained.
for key, saved_key in _RENAMED_WEIGHTS:
    sess.run(var_dict[key].assign(weights[saved_key]))
    print("Initialized  %s" % key)

properties = _load_pickled_npy('properties.npy')
def _tokens_to_latex(token_ids, idx_to_char):
    """Decode *token_ids* with *idx_to_char* into a LaTeX string.

    Tokens are joined with single spaces.  Everything from the first '#END'
    token on is dropped, and every '\\left', '\\right' and '&' is removed.
    """
    text = ' '.join(idx_to_char[i] for i in token_ids)
    text = text.split('#END')[0]
    return text.replace('\\left', '').replace('\\right', '').replace('&', '')


def show(image_path='tmp3.png'):
    """Transcribe the formula image at *image_path* to LaTeX.

    Runs the module-level ``predictions`` op on the image's luma channel and
    decodes the predicted token ids with ``properties['idx_to_char']``.  It
    then prints the LaTeX and copies it to the clipboard wrapped in ``$...$``.

    Args:
        image_path: image to transcribe.  The default is the file that
            run_demo() asks downsample_image() to write.

    Returns:
        The decoded LaTeX string.
    """
    # Keep only the Y (luma) plane and add two leading singleton axes, so the
    # array fed to X has shape (1, 1, height, width).
    luma = np.asarray(Image.open(image_path).convert('YCbCr'))[:, :, 0]
    imgs = luma[None, None, :]
    # Reset tflib's module-level ctx_vector list before each run (it is
    # defined and consumed in tflib.ops, which is not visible here).
    tflib.ops.ctx_vector = []
    token_ids = sess.run(predictions, feed_dict={X: imgs})
    latex = _tokens_to_latex(token_ids.flatten().tolist(),
                             properties['idx_to_char'])
    print("Latex sequence:  %s" % latex)
    pyperclip.copy('$' + latex + '$')
    return latex
def _show_crop_selection():
    """Show the rendered page, capture a screen region, then show the capture.

    Displays tmp.png and waits 3 s for the viewer to appear.  It then runs
    ImageMagick ``import`` to write screenshot.png and closes running
    ``display`` viewers (presumably the ones Image.show() opened).  Finally it
    opens screenshot.png so the user can check the crop.
    """
    Image.open('tmp.png').show()
    time.sleep(3)
    subprocess.call(['import', 'screenshot.png'])
    subprocess.call(['pkill', 'display'])
    Image.open('screenshot.png').show()


def run_demo(filename=None, scale=2, pdf_dir='/home/rithesh/Downloads'):
    """Interactively crop a formula from a PDF page and transcribe it.

    Opens the PDF and asks for a 1-based page number.  It rasterises that page
    to tmp.png with ImageMagick ``convert``, then repeats the screen capture
    until the user confirms the crop.  The capture is cropped, padded and
    downsampled (tmp1.png -> tmp2.png -> tmp3.png) and passed to show().

    External commands are run from argument lists (no shell), so paths
    containing spaces or shell metacharacters are passed through intact.

    Args:
        filename: base name (without '.pdf') of a PDF in *pdf_dir*.  When
            falsy, a random file from 'Papers_PDF/' is used instead.
        scale: factor handed to downsample_image.
        pdf_dir: directory that holds *filename*.

    Returns:
        The LaTeX string produced by show().

    Raises:
        ValueError: if no PDF is available or the page number is not a
            positive integer.
    """
    if filename:
        pdf_path = os.path.join(pdf_dir, '%s.pdf' % filename)
    else:
        candidates = glob.glob('Papers_PDF/*')
        if not candidates:
            raise ValueError("no PDF files found in 'Papers_PDF/'")
        pdf_path = np.random.choice(candidates)
    subprocess.call(['xdg-open', pdf_path])

    # Parse the page number as an int instead of eval'ing it, as Python 2's
    # input() would.  convert's "file[N]" page index is zero-based.
    page = int(raw_input("Enter page number: ")) - 1
    if page < 0:
        raise ValueError("page numbers start at 1")
    subprocess.call(['convert', '-density', '200', '-quality', '100',
                     '%s[%d]' % (pdf_path, page), 'tmp.png'])

    _show_crop_selection()
    while raw_input("Is the crop correct? (y/n) : ").strip().lower() not in ('y', 'yes'):
        subprocess.call(['pkill', 'display'])
        _show_crop_selection()
    subprocess.call(['pkill', 'display'])

    buckets = [[240, 100], [320, 80], [400, 80], [400, 100], [480, 80],
               [480, 100], [560, 80], [560, 100], [640, 80], [640, 100],
               [720, 80], [720, 100], [720, 120], [720, 200], [800, 100],
               [800, 320], [1000, 200]]
    # NOTE(review): these image_utils helpers return a status that was never
    # checked; its meaning is not visible here -- TODO consider checking it.
    crop_image('screenshot.png', './tmp1.png', (600, 60))
    pad_group_image('./tmp1.png', './tmp2.png', [8, 8, 8, 8], buckets)
    downsample_image('./tmp2.png', './tmp3.png', scale)
    return show()
if __name__ == '__main__':
    # Start the interactive demo only when run as a script, so importing this
    # module does not block on terminal prompts.
    run_demo()