diff --git a/uspto/download_preprocess.py b/uspto/download_preprocess.py
index f4abce9..bb8ae55 100644
--- a/uspto/download_preprocess.py
+++ b/uspto/download_preprocess.py
@@ -1,27 +1,45 @@
+from copy import copy
 from typing import Literal
+from unicodedata import normalize
 
+import lxml
 import lxml.html as html
+import pypandoc
 import requests
 
+from licensed_pile.logs import get_logger
+
+logger = get_logger("uspto")
+
 
 def convert_mathml_to_latex(url: str, mathml_string: str) -> str:
     """Function to convert MathML to LaTeX using a REST server running https://github.com/asnunes/mathml-to-latex."""
     if not mathml_string:
         return ""
-    response = requests.post(url, json={"mathml": mathml_string})
-    if response.status_code in [400, 500]:
-        return str(mathml_string)
-    else:
-        result = response.json()
-        return result.get("latex", mathml_string)
+    # response = requests.post(url, json={"mathml": mathml_string})
+    try:
+        response = pypandoc.convert_text(mathml_string, "latex", format="html")
+        return response
+    except RuntimeError as e:
+        logger.info(f"Error converting MathML to LaTeX: {e}")
+        return mathml_string
+    # if response.status_code in [400, 500]:
+    #     return str(mathml_string)
+    # else:
+    #     result = response.json()
+    #     return result.get("latex", mathml_string)
 
 
 def parse_html(url, html_string: str) -> str:
     html_string = html.fromstring(html_string)
-    equations: list[html.HtmlElement] = html_string.xpath("//maths")
-    if equations:
-        for i, eq in enumerate(equations):
-            new_equation = convert_mathml_to_latex(url, str(eq))
-            eq.clear()
-            eq.text = new_equation
-    return html_string.text_content()
+    # equations: list[html.HtmlElement] = html_string.xpath("//maths")
+    # if equations:
+    #     for i, eq in enumerate(equations):
+    #         new_equation = convert_mathml_to_latex(url, lxml.html.tostring(eq, "unicode"))
+    #         eq.clear()
+    #         eq.text = new_equation
+    return normalize(
+        pypandoc.convert_text(
+            lxml.html.tostring(html_string, encoding="unicode"), "markdown", "html"
+        )
+    )
diff --git a/uspto/uspto-to-dolma.py b/uspto/uspto-to-dolma.py
index 53ce537..2fb2edd 100644
--- a/uspto/uspto-to-dolma.py
+++ b/uspto/uspto-to-dolma.py
@@ -9,9 +9,10 @@ import polars as pl
 
 from download_preprocess import parse_html
 from polars import col
+from tqdm import tqdm
 
 from licensed_pile.licenses import PermissiveLicenses
-from licensed_pile.logs import configure_logging
+from licensed_pile.logs import configure_logging, get_logger
 from licensed_pile.write import to_dolma
 
 logger = configure_logging("uspto")
@@ -67,7 +68,8 @@ def process_datasets(
     args = [(x, url, limit) for x in file_names]
     # we'll let polars handle the row parallelism but the API calls are IO bound
     # so, we can increase the number of files to process concurrently
-    with multiprocessing.get_context("spawn").Pool() as pool:
+    # yield from scan_dataset(args[0]).iter_rows(named=True)
+    with multiprocessing.get_context("spawn").Pool(2) as pool:
         for batch in batched(args, max_concurrency):
             logger.debug("Processing files %s", [b[0] for b in batch])
             for res in pool.imap_unordered(scan_dataset, batch):
@@ -119,8 +121,12 @@ def scan_dataset(args: tuple) -> pl.DataFrame:
             pl.lit("2024-03-22", dtype=pl.String).alias("added"),
             col("created").cast(pl.String, strict=False),
             col("publication_date").cast(pl.String, strict=False),
-            col("description_html").map_elements(html_fn, return_dtype=pl.String),
-            col("claims_html").map_elements(html_fn, return_dtype=pl.String),
+            col("description_html").map_elements(
+                html_fn, return_dtype=pl.String, strategy="threading"
+            ),
+            col("claims_html").map_elements(
+                html_fn, return_dtype=pl.String, strategy="threading"
+            ),
             # if abstract returns `ABSTRACT\n\n`. Null otherwise
             pl.concat_str(
                 pl.lit(r"ABSTRACT", dtype=pl.String),
@@ -142,7 +148,7 @@ def scan_dataset(args: tuple) -> pl.DataFrame:
     ).select(["id", "text", "added", "created", "title_language", "publication_date"])
     if limit > 0:
         df = df.fetch(limit).lazy()
-    return df.collect()
+    return df.collect(streaming=True)
 
 
 def serialize_dolma(
@@ -168,7 +174,7 @@ def serialize_dolma(
     for document in serialize_dolma(data_dir="./data/uspto/", url="http://localhost:3000/convert", limit=10):
         print(document)
     """
-    for x in process_datasets(data_dir, url, limit, max_concurrency):
+    for x in tqdm(process_datasets(data_dir, url, limit, max_concurrency)):
         metadata = {
             "source": "Google Patents Public Data",
             "metadata": {
@@ -185,12 +191,12 @@ def create_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--output_dir", type=str, help="Output directory", default=r"raw"
     )
-    parser.add_argument(
-        "data_dir",
-        type=str,
-        default=r"./data/uspto/",
-        help="Dataset directory where all parquet files to process are located ",
-    )
+    # parser.add_argument(
+    #     "data_dir",
+    #     type=str,
+    #     default=r"/Users/baber/Downloads/untitled folder",
+    #     help="Dataset directory where all parquet files to process are located ",
+    # )
     parser.add_argument(
         "--url",
         type=str,
@@ -200,13 +206,13 @@ def create_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--limit",
         type=int,
-        default=0,
+        default=100,
         help="Limit the number of rows to read for testing",
     )
     parser.add_argument(
         "--max-concurrency",
        type=int,
-        default=4,
+        default=2,
         help="Maximum number of parquet files to process concurrently",
     )
     return parser
@@ -215,12 +221,12 @@ def create_args_parser() -> argparse.ArgumentParser:
 if __name__ == "__main__":
     args = create_args_parser().parse_args()
     logger.info(
-        f"""Processing USPTO with the following parameters: Output Dir: {args.output_dir}, Data Dir: {args.data_dir},
+        f"""Processing USPTO with the following parameters: Output Dir: {args.output_dir},
                 REST API URL: {args.url}, Limit: {args.limit}, Max Concurrency: {args.max_concurrency}"""
     )
     to_dolma(
         serialize_dolma(
-            data_dir=args.data_dir,
+            data_dir="/Users/baber/Downloads/untitled folder",
             url=args.url,
             limit=args.limit,
             max_concurrency=args.max_concurrency,