Skip to content

Commit

Permalink
Merge pull request #8 from VikParuchuri/dev
Browse files Browse the repository at this point in the history
ONNX for inference, speed boost, enable flattening PDFs
  • Loading branch information
VikParuchuri authored Oct 7, 2024
2 parents 2557089 + ae08899 commit c4f0d34
Show file tree
Hide file tree
Showing 11 changed files with 756 additions and 623 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pdftext PDF_PATH --out_path output.txt
- `--keep_hyphens` will keep hyphens in the output (they will be stripped and words joined otherwise)
- `--pages` will specify pages (comma separated) to extract
- `--workers` specifies the number of parallel workers to use
- `--flatten_pdf` merges form fields into the PDF

## JSON

Expand All @@ -44,6 +45,7 @@ pdftext PDF_PATH --out_path output.txt --json
- `--pages` will specify pages (comma separated) to extract
- `--keep_chars` will keep individual characters in the json output
- `--workers` specifies the number of parallel workers to use
- `--flatten_pdf` merges form fields into the PDF

The output will be a json list, with each item in the list corresponding to a single page in the input pdf (in order). Each page will include the following keys:

Expand Down
7 changes: 3 additions & 4 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def pdfplumber_inference(pdf_path):
return pages


def pdftext_inference(pdf_path, model=None, workers=None):
return paginated_plain_text_output(pdf_path, model=model, workers=workers)
def pdftext_inference(pdf_path, workers=None):
return paginated_plain_text_output(pdf_path, workers=workers)


def compare_docs(doc1: str, doc2: str):
Expand All @@ -78,7 +78,6 @@ def main():
if args.pdftext_only:
times_tools = ["pymupdf", "pdftext"]
alignment_tools = ["pdftext"]
model = get_model()
for i in tqdm(range(len(dataset)), desc="Benchmarking"):
row = dataset[i]
pdf = row["pdf"]
Expand All @@ -88,7 +87,7 @@ def main():
f.seek(0)
pdf_path = f.name

pdftext_inference_model = partial(pdftext_inference, model=model, workers=args.pdftext_workers)
pdftext_inference_model = partial(pdftext_inference, workers=args.pdftext_workers)
inference_funcs = [pymupdf_inference, pdftext_inference_model, pdfplumber_inference]
for tool, inference_func in zip(times_tools, inference_funcs):
start = time.time()
Expand Down
5 changes: 3 additions & 2 deletions extract_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def main():
parser.add_argument("--sort", action="store_true", help="Attempt to sort the text by reading order", default=False)
parser.add_argument("--keep_hyphens", action="store_true", help="Keep hyphens in words", default=False)
parser.add_argument("--pages", type=str, help="Comma separated pages to extract, like 1,2,3", default=None)
parser.add_argument("--flatten_pdf", action="store_true", help="Flatten form fields and annotations into page contents", default=False)
parser.add_argument("--keep_chars", action="store_true", help="Keep character level information", default=False)
parser.add_argument("--workers", type=int, help="Number of workers to use for parallel processing", default=None)
args = parser.parse_args()
Expand All @@ -24,10 +25,10 @@ def main():
assert all(p <= len(pdf_doc) for p in pages), "Invalid page number(s) provided"

if args.json:
text = dictionary_output(args.pdf_path, sort=args.sort, page_range=pages, keep_chars=args.keep_chars, workers=args.workers)
text = dictionary_output(args.pdf_path, sort=args.sort, page_range=pages, flatten_pdf=args.flatten_pdf, keep_chars=args.keep_chars, workers=args.workers)
text = json.dumps(text)
else:
text = plain_text_output(args.pdf_path, sort=args.sort, hyphens=args.keep_hyphens, page_range=pages, workers=args.workers)
text = plain_text_output(args.pdf_path, sort=args.sort, hyphens=args.keep_hyphens, page_range=pages, flatten_pdf=args.flatten_pdf, workers=args.workers)

if args.out_path is None:
print(text)
Expand Down
Binary file added models/dt.onnx
Binary file not shown.
53 changes: 35 additions & 18 deletions pdftext/extraction.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from functools import partial
from typing import List
from concurrent.futures import ProcessPoolExecutor
import math
Expand All @@ -12,49 +11,67 @@
from pdftext.settings import settings


def _get_page_range(pdf_path, model, page_range):
pdf_doc = pdfium.PdfDocument(pdf_path)
text_chars = get_pdfium_chars(pdf_doc, page_range)
def _load_pdf(pdf, flatten_pdf):
if isinstance(pdf, str):
pdf = pdfium.PdfDocument(pdf)

if not isinstance(pdf, pdfium.PdfDocument):
raise TypeError("pdf must be a file path string or a PdfDocument object")

# Must be called on the parent pdf, before the page was retrieved
if flatten_pdf:
pdf.init_forms()

return pdf


def _get_page_range(page_range, flatten_pdf=False):
text_chars = get_pdfium_chars(pdf_doc, page_range, flatten_pdf)
pages = inference(text_chars, model)
return pages


def _get_pages(pdf_path, model=None, page_range=None, workers=None):
if model is None:
model = get_model()
def worker_init(pdf_path, flatten_pdf):
global model
global pdf_doc

pdf_doc = _load_pdf(pdf_path, flatten_pdf)
model = get_model()

pdf_doc = pdfium.PdfDocument(pdf_path)

def _get_pages(pdf_path, page_range=None, flatten_pdf=False, workers=None):
pdf_doc = _load_pdf(pdf_path, flatten_pdf)
if page_range is None:
page_range = range(len(pdf_doc))

if workers is not None:
workers = min(workers, len(page_range) // settings.WORKER_PAGE_THRESHOLD) # It's inefficient to have too many workers, since we batch in inference

if workers is None or workers <= 1:
text_chars = get_pdfium_chars(pdf_doc, page_range)
model = get_model()
text_chars = get_pdfium_chars(pdf_doc, page_range, flatten_pdf)
return inference(text_chars, model)

func = partial(_get_page_range, pdf_path, model)
page_range = list(page_range)

pages_per_worker = math.ceil(len(page_range) / workers)
page_range_chunks = [page_range[i * pages_per_worker:(i + 1) * pages_per_worker] for i in range(workers)]

with ProcessPoolExecutor(max_workers=workers) as executor:
pages = list(executor.map(func, page_range_chunks))
with ProcessPoolExecutor(max_workers=workers, initializer=worker_init, initargs=(pdf_path, flatten_pdf)) as executor:
pages = list(executor.map(_get_page_range, page_range_chunks))

ordered_pages = [page for sublist in pages for page in sublist]

return ordered_pages


def plain_text_output(pdf_path, sort=False, model=None, hyphens=False, page_range=None, workers=None) -> str:
text = paginated_plain_text_output(pdf_path, sort=sort, model=model, hyphens=hyphens, page_range=page_range, workers=workers)
def plain_text_output(pdf_path, sort=False, hyphens=False, page_range=None, flatten_pdf=False, workers=None) -> str:
text = paginated_plain_text_output(pdf_path, sort=sort, hyphens=hyphens, page_range=page_range, workers=workers, flatten_pdf=flatten_pdf)
return "\n".join(text)


def paginated_plain_text_output(pdf_path, sort=False, model=None, hyphens=False, page_range=None, workers=None) -> List[str]:
pages = _get_pages(pdf_path, model, page_range, workers=workers)
def paginated_plain_text_output(pdf_path, sort=False, hyphens=False, page_range=None, flatten_pdf=False, workers=None) -> List[str]:
pages = _get_pages(pdf_path, page_range, workers=workers, flatten_pdf=flatten_pdf)
text = []
for page in pages:
text.append(merge_text(page, sort=sort, hyphens=hyphens).strip())
Expand All @@ -71,8 +88,8 @@ def _process_span(span, page_width, page_height, keep_chars):
char["bbox"] = unnormalize_bbox(char["bbox"], page_width, page_height)


def dictionary_output(pdf_path, sort=False, model=None, page_range=None, keep_chars=False, workers=None):
pages = _get_pages(pdf_path, model, page_range, workers=workers)
def dictionary_output(pdf_path, sort=False, page_range=None, keep_chars=False, flatten_pdf=False, workers=None):
pages = _get_pages(pdf_path, page_range, workers=workers, flatten_pdf=flatten_pdf)
for page in pages:
page_width, page_height = page["width"], page["height"]
for block in page["blocks"]:
Expand Down
56 changes: 27 additions & 29 deletions pdftext/inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from itertools import chain

import sklearn
import numpy as np

from pdftext.pdf.utils import LINE_BREAKS, TABS, SPACES
from pdftext.settings import settings
Expand Down Expand Up @@ -47,27 +45,27 @@ def create_training_row(char_info, prev_char, currblock, currline):

is_space = char in SPACES or char in TABS

training_row = {
"is_newline": char in LINE_BREAKS,
"is_space": is_space,
"x_gap": x_gap,
"y_gap": y_gap,
"font_match": font_match,
"x_outer_gap": char_x2 - prev_x1,
"y_outer_gap": char_y2 - prev_y1,
"line_x_center_gap": char_center_x - currline["center_x"],
"line_y_center_gap": char_center_y - currline["center_y"],
"line_x_gap": char_x1 - currline_bbox[2],
"line_y_gap": char_y1 - currline_bbox[3],
"line_x_start_gap": char_x1 - currline_bbox[0],
"line_y_start_gap": char_y1 - currline_bbox[1],
"block_x_center_gap": char_center_x - currblock["center_x"],
"block_y_center_gap": char_center_y - currblock["center_y"],
"block_x_gap": char_x1 - currblock_bbox[2],
"block_y_gap": char_y1 - currblock_bbox[3],
"block_x_start_gap": char_x1 - currblock_bbox[0],
"block_y_start_gap": char_y1 - currblock_bbox[1]
}
return np.array([
char_center_x - currblock["center_x"],
char_x1 - currblock_bbox[2],
char_x1 - currblock_bbox[0],
char_center_y - currblock["center_y"],
char_y1 - currblock_bbox[3],
char_y1 - currblock_bbox[1],
font_match,
char in LINE_BREAKS,
is_space,
char_center_x - currline["center_x"],
char_x1 - currline_bbox[2],
char_x1 - currline_bbox[0],
char_center_y - currline["center_y"],
char_y1 - currline_bbox[3],
char_y1 - currline_bbox[1],
x_gap,
char_x2 - prev_x1,
y_gap,
char_y2 - prev_y1
], dtype=np.float32)

return training_row

Expand Down Expand Up @@ -135,8 +133,6 @@ def infer_single_page(text_chars, block_threshold=settings.BLOCK_THRESHOLD):
font_info = f"{font['name']}_{font['size']}_{font['weight']}_{font['flags']}_{char_info['rotation']}"
if prev_char:
training_row = create_training_row(char_info, prev_char, block, line)
sorted_keys = sorted(training_row.keys())
training_row = [training_row[key] for key in sorted_keys]

prediction_probs = yield training_row
# First item is probability of same line/block, second is probability of new line, third is probability of new block
Expand Down Expand Up @@ -175,6 +171,8 @@ def inference(text_chars, model):
# Create generators and get first training row from each
generators = [infer_single_page(text_page) for text_page in text_chars]
next_prediction = {}
input_name = model.get_inputs()[0].name
output_name = model.get_outputs()[1].name

page_blocks = {}
while len(page_blocks) < len(generators):
Expand All @@ -199,10 +197,10 @@ def inference(text_chars, model):

training_idxs = sorted(training_data.keys())
training_rows = [training_data[idx] for idx in training_idxs]
training_rows = np.stack(training_rows, axis=0)

# Disable nan, etc, validation for a small speedup
with sklearn.config_context(assume_finite=True):
predictions = model.predict_proba(training_rows)
# Run inference
predictions = model.run([output_name], {input_name: training_rows})[0]
for pred, page_idx in zip(predictions, training_idxs):
next_prediction[page_idx] = pred
sorted_keys = sorted(page_blocks.keys())
Expand Down
6 changes: 3 additions & 3 deletions pdftext/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import joblib
from pdftext.settings import settings
import onnxruntime as rt


def get_model(model_path=settings.MODEL_PATH):
model = joblib.load(model_path)
return model
sess = rt.InferenceSession(model_path)
return sess
18 changes: 17 additions & 1 deletion pdftext/pdf/chars.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Dict, List

import pypdfium2.raw as pdfium_c
from pypdfium2 import PdfiumError

from pdftext.pdf.utils import get_fontname, pdfium_page_bbox_to_device_bbox, page_bbox_to_device_bbox
from pdftext.settings import settings
Expand All @@ -19,11 +20,26 @@ def update_previous_fonts(char_infos: List, i: int, prev_fontname: str, prev_fon
char_infos[j]["font"]["flags"] = fontflags


def get_pdfium_chars(pdf, page_range, fontname_sample_freq=settings.FONTNAME_SAMPLE_FREQ):
def flatten(page, flag=pdfium_c.FLAT_NORMALDISPLAY):
rc = pdfium_c.FPDFPage_Flatten(page, flag)
if rc == pdfium_c.FLATTEN_FAIL:
raise PdfiumError("Failed to flatten annotations / form fields.")


def get_pdfium_chars(pdf, page_range, flatten_pdf, fontname_sample_freq=settings.FONTNAME_SAMPLE_FREQ):
blocks = []

for page_idx in page_range:
page = pdf.get_page(page_idx)

if flatten_pdf:
# Flatten form fields and annotations into page contents.
flatten(pdf, page)

# Flattening invalidates existing handles to the page.
# It is necessary to re-initialize the page handle after flattening.
page = pdf.get_page(page_idx)

text_page = page.get_textpage()
mediabox = page.get_mediabox()
page_rotation = page.get_rotation()
Expand Down
2 changes: 1 addition & 1 deletion pdftext/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class Settings(BaseSettings):
BASE_PATH: str = os.path.dirname(os.path.dirname(__file__))
MODEL_PATH: str = os.path.join(BASE_PATH, "models", "dt.joblib")
MODEL_PATH: str = os.path.join(BASE_PATH, "models", "dt.onnx")

# Fonts
FONTNAME_SAMPLE_FREQ: int = 4
Expand Down
Loading

0 comments on commit c4f0d34

Please sign in to comment.