-
Notifications
You must be signed in to change notification settings - Fork 3
/
TestAll.py
143 lines (114 loc) · 4.35 KB
/
TestAll.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import numpy as np
import sys, os
import time
import matplotlib.pyplot as plt
import textdistance
sys.path.append(os.getcwd())
# crnn packages
import torch
from torch.autograd import Variable
import utils
import dataset
from PIL import Image
import models.crnn as crnn
import alphabets
str1 = alphabets.chinese_3000
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--images_path', type=str, default='test_images/test5.jpg', help='the path to your images')
opt = parser.parse_args()
validation_pth = 'to_lmdb/test_width'
# crnn params
# 3p6m_third_ac97p8.pth
crnn_model_path = 'trained_models/crnn_Rec_done_69_5000.pth'
alphabet = str1
nclass = len(alphabet)+1
Batch_Test_Flag = True
path_prefix = 'to_lmdb/test_width/'
f = open('to_lmdb/test_width.txt', "r", encoding='utf-8')
lines = f.readlines()
f.close()
def get_image_path_and_labels(dir, path_prefix):
label_dict = {}
for i in range(len(lines)):
line_text_1 = lines[i]
imgPth = os.path.join(path_prefix,line_text_1.split()[0])
label = line_text_1.split()[-1]
label_dict.update({imgPth:label})
return label_dict
tesing_dataset = get_image_path_and_labels(validation_pth, path_prefix)
# crnn文本信息识别
def crnn_recognition(imgpth, model, tesing_dataset, total_correct_num, total_string_length):
cropped_image = Image.open(imgpth)
converter = utils.strLabelConverter(alphabet)
image = cropped_image.convert('L')
##
w = int(image.size[0] / (280 * 1.0 / 180))
# w = image.size[0]
# w = int(image.size[0] / (32 * 1.0 / image.size[1]))
transformer = dataset.resizeNormalize((w, 32))
image = transformer(image)
if torch.cuda.is_available():
image = image.cuda()
image = image.view(1, *image.size())
image = Variable(image)
model.eval()
preds = model(image)
_, preds = preds.max(2)
preds = preds.transpose(1, 0).contiguous().view(-1)
preds_size = Variable(torch.IntTensor([preds.size(0)]))
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
ground_truth = tesing_dataset.get(imgpth)
correct_num = int(len(ground_truth) * textdistance.levenshtein.normalized_similarity(ground_truth, sim_pred))
string_length = len(ground_truth)
#check = ground_truth == sim_pred
print('results: {0}, gt: {1}'.format(sim_pred, ground_truth))
return correct_num, string_length
def crnn_single_test(cropped_image, model):
converter = utils.strLabelConverter(alphabet)
image = cropped_image.convert('L')
##
w = int(image.size[0] / (280 * 1.0 / 180))
# w = image.size[0]
# w = int(image.size[0] / (32 * 1.0 / image.size[1]))
transformer = dataset.resizeNormalize((w, 32))
image = transformer(image)
if torch.cuda.is_available():
image = image.cuda()
image = image.view(1, *image.size())
image = Variable(image)
model.eval()
preds = model(image)
_, preds = preds.max(2)
preds = preds.transpose(1, 0).contiguous().view(-1)
preds_size = Variable(torch.IntTensor([preds.size(0)]))
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
print('results: {0}'.format(sim_pred))
if __name__ == '__main__':
# crnn network
model = crnn.CRNN(32, 1, nclass, 256)
if torch.cuda.is_available():
model = model.cuda()
print('loading pretrained model from {0}'.format(crnn_model_path))
# 导入已经训练好的crnn模型
model.load_state_dict(torch.load(crnn_model_path))
started = time.time()
## read an image
# image = Image.open(opt.images_path)
if Batch_Test_Flag:
total_correct_num = 0
total_string_length = 0
for stp in range(len(tesing_dataset)):
imgpth = list(tesing_dataset.keys())[stp]
# image = Image.open(imgpth)
correct_num, string_length = crnn_recognition(imgpth, model, tesing_dataset, total_correct_num, total_string_length)
total_correct_num += correct_num
total_string_length += string_length
finished = time.time()
print('elapsed time: {0}'.format(finished-started))
print('Correct Rate: {0}'.format(total_correct_num / total_string_length))
else:
image = Image.open(opt.images_path)
crnn_single_test(image, model)
finished = time.time()
print('elapsed time: {0}'.format(finished-started))