Skip to content

Commit

Permalink
add save_analysis_res argument
Browse files Browse the repository at this point in the history
  • Loading branch information
breezedeus committed Apr 9, 2024
1 parent 35a75bb commit 5457439
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
19 changes: 19 additions & 0 deletions pix2text/table_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,22 @@ def recognize(
out_markdown=True,
**kwargs,
) -> Dict[str, Any]:
"""
Args:
img ():
tokens ():
out_objects ():
out_cells ():
out_html ():
out_csv ():
out_markdown ():
**kwargs ():
* save_analysis_res (str): Save the parsed result image in this file; default value is `None`, which means not to save
Returns:
"""
out_formats = {}
if self.str_model is None:
print("No structure model loaded.")
Expand Down Expand Up @@ -196,6 +212,9 @@ def recognize(
self._ocr_texts(img, cells)
if out_cells:
out_formats['cells'] = tables_cells
if kwargs.get('save_analysis_res'):
visualize_cells(img, tables_cells[0], kwargs['save_analysis_res'])

if not (out_html or out_csv):
return out_formats

Expand Down
25 changes: 19 additions & 6 deletions tests/test_table_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
import pytest
import os

from pix2text.utils import read_img
from pix2text.ocr_engine import prepare_ocr_engine
from pix2text.table_ocr import TableOCR, visualize_cells
from pix2text.table_ocr import TableOCR


def test_recognize():
Expand All @@ -13,9 +12,16 @@ def test_recognize():
languages = ('en', 'ch_sim')
text_ocr = prepare_ocr_engine(languages, {})
ocr = TableOCR(text_ocr=text_ocr)
result = ocr.recognize(image_path, out_csv=True, out_cells=True, out_objects=False, out_html=True, out_markdown=True)
result = ocr.recognize(
image_path,
out_csv=True,
out_cells=True,
out_objects=False,
out_html=True,
out_markdown=True,
save_analysis_res='out-table-rec.png',
)

visualize_cells(read_img(image_path, 'Image'), result['cells'][0], 'out-table-rec.png')
print(result)


Expand All @@ -25,7 +31,14 @@ def test_recognize2():
languages = ('en', 'ch_sim')
text_ocr = prepare_ocr_engine(languages, {})
ocr = TableOCR.from_config(text_ocr=text_ocr)
result = ocr.recognize(image_path, out_csv=True, out_cells=True, out_objects=False, out_html=True, out_markdown=True)
result = ocr.recognize(
image_path,
out_csv=True,
out_cells=True,
out_objects=False,
out_html=True,
out_markdown=True,
save_analysis_res='out-table-rec.png',
)

visualize_cells(read_img(image_path, 'Image'), result['cells'][0], 'out-table-rec.png')
print(result)

0 comments on commit 5457439

Please sign in to comment.