Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

valid det inference #36

Merged
merged 1 commit into from
May 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/det/det_db_mv3.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Global:
test_batch_size_per_card: 16
image_shape: [3, 640, 640]
reader_yml: ./configs/det/det_db_icdar15_reader.yml
pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/
pretrain_weights: ./pretrain_models/MobileNetV3_large_x0_5_pretrained/
checkpoints:
save_res_path: ./output/det_db/predicts_db.txt
save_inference_dir:
Expand Down
4 changes: 2 additions & 2 deletions ppocr/data/det/dataset_traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@ def __call__(self, mode):

def batch_iter_reader():
batch_outs = []
for img_path, img_name in img_list:
for img_path in img_list:
img = cv2.imread(img_path)
if img is None:
logger.info("load image error:" + img_path)
continue
outs = process_function(img)
outs.append(img_name)
outs.append(img_path)
batch_outs.append(outs)
if len(batch_outs) == batch_size:
yield batch_outs
Expand Down
16 changes: 14 additions & 2 deletions tools/infer/predict_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,14 @@
from ppocr.data.det.db_process import DBProcessTest
from ppocr.postprocess.db_postprocess import DBPostProcess
from ppocr.postprocess.east_postprocess import EASTPostPocess
from ppocr.utils.utility import get_image_file_list
from tools.infer.utility import draw_ocr
import copy
import numpy as np
import math
import time
import sys
import os


class TextDetector(object):
Expand Down Expand Up @@ -152,7 +155,7 @@ def __call__(self, img):

if __name__ == "__main__":
args = utility.parse_args()
image_file_list = utility.get_image_file_list(args.image_dir)
image_file_list = get_image_file_list(args.image_dir)
text_detector = TextDetector(args)
count = 0
total_time = 0
Expand All @@ -166,5 +169,14 @@ def __call__(self, img):
total_time += elapse
count += 1
print("Predict time of %s:" % image_file, elapse)
utility.draw_text_det_res(dt_boxes, image_file)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
draw_img = draw_ocr(img, dt_boxes, None, None, False)
draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
cv2.imwrite(
os.path.join(draw_img_save, os.path.basename(image_file)),
draw_img[:, :, ::-1])
print("The visualized image saved in {}".format(
os.path.join(draw_img_save, os.path.basename(image_file))))
print("Avg Time:", total_time / (count - 1))
4 changes: 2 additions & 2 deletions tools/infer/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ def resize_img(img, input_size=600):
def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5):
from PIL import Image, ImageDraw, ImageFont

w, h = image.size
img = image.copy()
draw = ImageDraw.Draw(img)

if scores is None:
scores = [1] * len(boxes)
for (box, score) in zip(boxes, scores):
if score < drop_score:
continue
Expand Down
27 changes: 18 additions & 9 deletions tools/infer_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def set_paddle_flags(**kwargs):
)

from paddle import fluid
from ppocr.utils.utility import create_module
from ppocr.utils.utility import create_module, get_image_file_list
import program
from ppocr.utils.save_load import init_model
from ppocr.data.reader_main import reader_main
Expand All @@ -50,20 +50,18 @@ def set_paddle_flags(**kwargs):
logger = initial_logger()


def draw_det_res(dt_boxes, config, img_name, ino):
def draw_det_res(dt_boxes, config, img, img_name):
if len(dt_boxes) > 0:
img_set_path = config['TestReader']['img_set_dir']
img_path = img_set_path + img_name
import cv2
src_im = cv2.imread(img_path)
src_im = img
for box in dt_boxes:
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
save_det_path = os.path.basename(config['Global'][
save_det_path = os.path.dirname(config['Global'][
'save_res_path']) + "/det_results/"
if not os.path.exists(save_det_path):
os.makedirs(save_det_path)
save_path = os.path.join(save_det_path, "det_{}.jpg".format(img_name))
save_path = os.path.join(save_det_path, os.path.basename(img_name))
cv2.imwrite(save_path, src_im)
logger.info("The detected Image saved in {}".format(save_path))

Expand Down Expand Up @@ -103,8 +101,12 @@ def main():
raise Exception("{} not exists!".format(checkpoints))

save_res_path = config['Global']['save_res_path']
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
with open(save_res_path, "wb") as fout:

test_reader = reader_main(config=config, mode='test')
# image_file_list = get_image_file_list(args.image_dir)
tackling_num = 0
for data in test_reader():
img_num = len(data)
Expand All @@ -128,7 +130,13 @@ def main():
postprocess_params.update(global_params)
postprocess = create_module(postprocess_params['function'])\
(params=postprocess_params)
dt_boxes_list = postprocess({"maps": outs[0]}, ratio_list)
if config['Global']['algorithm'] == 'EAST':
dic = {'f_score': outs[0], 'f_geo': outs[1]}
elif config['Global']['algorithm'] == 'DB':
dic = {'maps': outs[0]}
else:
raise Exception("only support algorithm: ['EAST', 'BD']")
dt_boxes_list = postprocess(dic, ratio_list)
for ino in range(img_num):
dt_boxes = dt_boxes_list[ino]
img_name = img_name_list[ino]
Expand All @@ -139,7 +147,8 @@ def main():
dt_boxes_json.append(tmp_json)
otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
draw_det_res(dt_boxes, config, img_name, ino)
src_img = cv2.imread(img_name)
draw_det_res(dt_boxes, config, src_img, img_name)

logger.info("success!")

Expand Down