-
Notifications
You must be signed in to change notification settings - Fork 0
/
readnumber.py
110 lines (95 loc) · 3.88 KB
/
readnumber.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import numpy as np
import tensorflow as tf
from models.multilayer import MLP
from models.convolution import ConvNet, MultiLogitDigitRecognizer
from inputs.datasource import *
from numberlocator import find_number_in_image, get_locator_model
from flags import FLAGS
from image import draw_bbox
from PIL import Image
from svhn_models import *
from synthetic_models import *
import sys
'''This module contains helper functions that train various classifier types'''
def get_best_svhn_model():
    '''
    Return the SVHN joint model currently regarded as the best one.

    The choice is hard-coded here and has to be updated by hand whenever a
    better model version is added.
    '''
    best_model = get_svhn_joint_v5()
    return best_model
def train_synthetic_joint_classifier():
    '''Train the synthetic joint model (v1) on SyntheticJointDataSource.'''
    joint_model = get_synthetic_joint_v1()
    training_data = SyntheticJointDataSource()
    joint_model.train(training_data)
def train_synthetic_digit_classifier(digit):
    '''
    Train the synthetic per-digit model selected by `digit`.

    `digit` is forwarded unchanged to both get_synthetic_digit_model and
    SyntheticPerDigitDataSource (presumably it is the digit position within
    the number - TODO confirm against those helpers).
    '''
    digit_model = get_synthetic_digit_model(digit)
    digit_data = SyntheticPerDigitDataSource(digit)
    digit_model.train(digit_data)
def train_synthetic_length_classifier():
    '''Train the synthetic length model on SyntheticLengthDataSource.'''
    length_model = get_synthetic_length_model()
    length_data = SyntheticLengthDataSource()
    length_model.train(length_data)
def train_svhn_digit_classifier(digit):
    '''
    Train the SVHN per-digit model selected by `digit`.

    `digit` is forwarded to get_svhn_digit_model and, as the `digit` keyword,
    to SVHNDataSource (presumably it is the digit position within the number
    - TODO confirm against those helpers).
    '''
    # Use the function's own argument here; the previous code referenced an
    # undefined name (`pos`), which raised NameError as soon as it was called.
    model = get_svhn_digit_model(digit)
    data = SVHNDataSource(digit=digit)
    model.train(data)
def train_svhn_length_classifier():
    '''Train the SVHN length model on an SVHNDataSource built with length_only=True.'''
    length_model = get_svhn_length_model()
    length_data = SVHNDataSource(length_only=True)
    length_model.train(length_data)
def train_svhn_joint_classifier():
    '''
    Train the current best SVHN joint model.

    The data source reads images from both FLAGS.svhn_extra_dir and
    FLAGS.svhn_train_dir, in that order.
    '''
    joint_model = get_best_svhn_model()
    image_dirs = [FLAGS.svhn_extra_dir, FLAGS.svhn_train_dir]
    # Resampling options that were tried and are currently disabled
    # (they adjust the dataset to 50k samples of each length):
    #   undersample = {2: 0.5, 3: 0.5}
    #   oversample = {1: 25000, 2: 12000, 3: 22000, 4: 29000, 5: 30000}
    joint_data = SVHNDataSource(img_dirs=image_dirs)
    joint_model.train(joint_data)
def locate_and_read_number(image, locator_model, classifier_model, bbox=False):
    '''
    Locate all potential numbers in an image, run each one through the given
    classifier, and return the most confident reading among the longest ones.

    Args:
        image: image to search; must provide crop(box) (the script entry
            point passes a PIL image).
        locator_model: model handed to find_number_in_image.
        classifier_model: object whose classify_image(patch) returns a
            (prediction, confidence) pair; len(prediction) is taken as the
            number of digits read.
        bbox: when true, the bounding box of the selected reading is passed
            to draw_bbox together with `image`.

    Returns:
        The selected prediction, or None when no candidate region is found.
    '''
    found_numbers = find_number_in_image(locator_model, image, draw_steps=False)
    if not found_numbers:
        return None

    # The loop variable is deliberately NOT named `bbox`: reusing that name
    # overwrote the caller's flag, so the box was drawn unconditionally.
    predictions = []
    for candidate_box in found_numbers:
        patch = image.crop(candidate_box)
        prediction, confidence = classifier_model.classify_image(patch)
        predictions.append((prediction, candidate_box, confidence))

    max_prediction_length = max(len(entry[0]) for entry in predictions)
    acceptable_lengths = (max_prediction_length, max_prediction_length - 1)

    # Select the result with the highest confidence, provided that its number
    # of digits is within 1 of the longest prediction. This weeds out cases
    # where the locator finds a good bounding box around only a subset of the
    # digits, which would otherwise win on confidence alone.
    candidates = [entry for entry in predictions
                  if len(entry[0]) in acceptable_lengths]
    # max() returns the first maximal element, which is the same element a
    # stable descending sort would place first.
    best_prediction, best_box, _ = max(candidates, key=lambda entry: entry[2])

    if bbox:
        draw_bbox(image, best_box)
    return best_prediction
# Script entry point: either train one classifier (FLAGS.train) or locate and
# read the number in the image at FLAGS.find (FLAGS.run).
if __name__ == "__main__":
    if FLAGS.train:
        svhn = FLAGS.svhn
        synthetic = FLAGS.synthetic
        # Exactly one dataset family has to be selected.
        if (svhn and synthetic) or (not svhn and not synthetic):
            print("must specify one of 'svhn' or 'synthetic'")
            # Invalid usage is an error, so report a non-zero exit status.
            sys.exit(1)
        if svhn:
            if FLAGS.joint:
                train_svhn_joint_classifier()
            elif FLAGS.digit:
                train_svhn_digit_classifier(int(FLAGS.digit))
            elif FLAGS.length:
                train_svhn_length_classifier()
        else:
            if FLAGS.joint:
                train_synthetic_joint_classifier()
            elif FLAGS.digit:
                train_synthetic_digit_classifier(int(FLAGS.digit))
            elif FLAGS.length:
                train_synthetic_length_classifier()
    elif FLAGS.run:
        locator_model = get_locator_model()
        number_model = get_best_svhn_model()
        image = Image.open(FLAGS.find)
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print(locate_and_read_number(image, locator_model, number_model, bbox=True))
        image.show()