From 5dd26bf365a20b05b0290dacd01701c189541516 Mon Sep 17 00:00:00 2001
From: baberabb <92168766+baberabb@users.noreply.github.com>
Date: Wed, 29 May 2024 00:07:38 +0500
Subject: [PATCH] use pandoc

---
 uspto/download_preprocess.py |  45 ----------------
 uspto/uspto-to-dolma.py      | 101 ++++++++++++++++++-----------
 2 files changed, 53 insertions(+), 93 deletions(-)
 delete mode 100644 uspto/download_preprocess.py

diff --git a/uspto/download_preprocess.py b/uspto/download_preprocess.py
deleted file mode 100644
index bb8ae55..0000000
--- a/uspto/download_preprocess.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from copy import copy
-from typing import Literal
-from unicodedata import normalize
-
-import lxml
-import lxml.html as html
-import pypandoc
-import requests
-
-from licensed_pile.logs import get_logger
-
-logger = get_logger("uspto")
-
-
-def convert_mathml_to_latex(url: str, mathml_string: str) -> str:
-    """Function to convert MathML to LaTeX using a REST server running https://github.com/asnunes/mathml-to-latex."""
-    if not mathml_string:
-        return ""
-    # response = requests.post(url, json={"mathml": mathml_string})
-    try:
-        response = pypandoc.convert_text(mathml_string, "latex", format="html")
-        return response
-    except RuntimeError as e:
-        logger.info(f"Error converting MathML to LaTeX: {e}")
-        return mathml_string
-    # if response.status_code in [400, 500]:
-    #     return str(mathml_string)
-    # else:
-    #     result = response.json()
-    #     return result.get("latex", mathml_string)
-
-
-def parse_html(url, html_string: str) -> str:
-    html_string = html.fromstring(html_string)
-    # equations: list[html.HtmlElement] = html_string.xpath("//maths")
-    # if equations:
-    #     for i, eq in enumerate(equations):
-    #         new_equation = convert_mathml_to_latex(url, lxml.html.tostring(eq, "unicode"))
-    #         eq.clear()
-    #         eq.text = new_equation
-    return normalize(
-        pypandoc.convert_text(
-            lxml.html.tostring(html_string, encoding="unicode"), "markdown", "html"
-        )
-    )
diff --git a/uspto/uspto-to-dolma.py b/uspto/uspto-to-dolma.py
index 2fb2edd..d4edd14 100644
--- a/uspto/uspto-to-dolma.py
+++ b/uspto/uspto-to-dolma.py
@@ -1,5 +1,6 @@
 import argparse
 import multiprocessing
+import re
 import sys
 from functools import partial
 from itertools import islice
@@ -7,12 +8,12 @@
 from typing import Iterable, Iterator
 
 import polars as pl
-from download_preprocess import parse_html
+import pypandoc
 from polars import col
 from tqdm import tqdm
 
 from licensed_pile.licenses import PermissiveLicenses
-from licensed_pile.logs import configure_logging, get_logger
+from licensed_pile.logs import configure_logging
 from licensed_pile.write import to_dolma
 
 logger = configure_logging("uspto")
@@ -27,19 +28,32 @@ def batched(iterable, n):
         yield batch
 
 
+def parse_html(html_string: str) -> str:
+    if not html_string:
+        return ""
+    text = pypandoc.convert_text(html_string, "plain", "html", extra_args=["--quiet"])
+    return re.sub(r"(?<!\n)\n(?!\n)", " ", text)
+
+
+def parallel_apply(max_concurrency: int, column: pl.Series) -> pl.Series:
+    if max_concurrency == 0:
+        max_concurrency = None
+    with multiprocessing.get_context("spawn").Pool(4) as pool:
+        return pl.Series(pool.imap(parse_html, column))
+
+
 def process_datasets(
     data_dir: str = r"./data/uspto/",
-    url: str = r"http://localhost:3000/convert",
     limit: int = 0,
     max_concurrency: int = 4,
 ) -> Iterable[dict]:
     """
     This function `run_dataset` scans a dataset located in a directory, converts each file in the dataset to a desired
-    format using an API endpoint,and returns an iterable of dictionaries containing the converted data.
+    format using pandoc, and returns an iterable of dictionaries containing the converted data.
 
     Parameters:
     - `data_dir` (str): The directory where the dataset is located. Default value is "./data/uspto/".
-    - `url` (str): The API endpoint URL for converting the dataset files. Default value is "http://localhost:3000/convert".
     - `limit` (int): The maximum number of rows to convert. Default value is 0, which means convert all rows from all files in the dataset.
     - `max_concurrency` (int): The maximum number of concurrent conversions to perform. Default value is 2.
 
@@ -48,13 +62,12 @@ def process_datasets(
 
     Note:
     - The `data_dir` parameter should be a valid directory path ending with a forward slash '/'.
-    - The `url` parameter should be a valid API endpoint URL.
     - The `limit` parameter determines how many row to read. Set it to 0 to convert all files.
     - The `max_concurrency` parameter determines how many parquet files to process concurrently.
 
     Example usage:
     ```python
-    for data in run_dataset(data_dir=r"./data/uspto/", url="http://localhost:3000/convert", limit=10, max_concurrency=2):
+    for data in run_dataset(data_dir=r"./data/uspto/", limit=10, max_concurrency=2):
         # Process each converted data entry
         print(data)
     ```
     """
     data_path = Path(data_dir)
     logger.info(f"Processing files in {data_path}")
     file_names = list(data_path.glob("*.parquet"))
-    if limit > 0:
-        limit //= len(file_names)
-        logger.info(f"Processing {limit} entries each from {len(file_names)} files.")
-    args = [(x, url, limit) for x in file_names]
-    # we'll let polars handle the row parallelism but the API calls are IO bound
-    # so, we can increase the number of files to process concurrently
-    # yield from scan_dataset(args[0]).iter_rows(named=True)
-    with multiprocessing.get_context("spawn").Pool(2) as pool:
-        for batch in batched(args, max_concurrency):
-            logger.debug("Processing files %s", [b[0] for b in batch])
-            for res in pool.imap_unordered(scan_dataset, batch):
-                yield from res.iter_rows(named=True)
+    for file_name in file_names:
+        yield from scan_dataset((file_name, limit, max_concurrency)).iter_rows(
+            named=True
+        )
 
 
 def scan_dataset(args: tuple) -> pl.DataFrame:
@@ -81,19 +86,20 @@
     Scans an individual parquet file and returns a processed DataFrame.
 
     Parameters:
-    args (tuple): A tuple containing the file name, URL, and limit.
+    args (tuple): A tuple containing the file name, limit, and max_concurrency.
 
     Returns:
     DataFrame: A processed DataFrame containing the selected columns from the dataset.
Example Usage: file_name = "dataset.parquet" - url = "https://www.example.com/dataset" limit = 100 + max_concurrency = 4 - result = scan_dataset((file_name, url, limit)) + result = scan_dataset((file_name, limit, max_concurrency)) """ - file_name, url, limit = args + file_name, limit, max_concurrency = args + parallel_apply_ = partial(parallel_apply, max_concurrency) columns = ( "title_text", "title_language", @@ -105,7 +111,6 @@ def scan_dataset(args: tuple) -> pl.DataFrame: "filing_date", ) - html_fn = partial(parse_html, url) df: pl.LazyFrame = ( pl.scan_parquet(file_name) .select(columns) @@ -121,13 +126,20 @@ def scan_dataset(args: tuple) -> pl.DataFrame: pl.lit("2024-03-22", dtype=pl.String).alias("added"), col("created").cast(pl.String, strict=False), col("publication_date").cast(pl.String, strict=False), - col("description_html").map_elements( - html_fn, return_dtype=pl.String, strategy="threading" - ), - col("claims_html").map_elements( - html_fn, return_dtype=pl.String, strategy="threading" - ), - # if abstract returns `ABSTRACT\n\n`. Null otherwise + col("description_html") + .map_batches( + parallel_apply_, + return_dtype=pl.String, + is_elementwise=True, + ) + .str.replace_all(r"\\left(\.|)|\\right(\.|)", ""), + col("claims_html") + .map_batches( + parallel_apply_, + return_dtype=pl.String, + is_elementwise=True, + ) + .str.replace_all(r"\\left(\.|)|\\right(\.|)", ""), pl.concat_str( pl.lit(r"ABSTRACT", dtype=pl.String), pl.lit("\n\n", dtype=pl.String), @@ -140,6 +152,7 @@ def scan_dataset(args: tuple) -> pl.DataFrame: col("title_text"), pl.lit("\n\n", dtype=pl.String), col("abstract_text"), + pl.lit("\n\n", dtype=pl.String), col("description_html"), col("claims_html"), ignore_nulls=True, @@ -153,7 +166,6 @@ def scan_dataset(args: tuple) -> pl.DataFrame: def serialize_dolma( data_dir: str = r"./data/uspto/", - url=r"http://localhost:3000/convert", limit: int = 0, max_concurrency: int = 4, ) -> Iterator[dict[str, str]]: @@ -162,7 +174,6 @@ def serialize_dolma( Args: data_dir: The directory path where the dataset files are located. Default is `./data/uspto/`. - url: The URL of the server to which the serialized documents will be sent. Default is `http://localhost:3000/convert`. limit: The maximum number of documents to be serialized. Default is 0, which represents no limit. max_concurrency: max files to process in parallel. Default is `4`. @@ -171,10 +182,10 @@ def serialize_dolma( content and metadata in a standardized format. 
Example Usage: - for document in serialize_dolma(data_dir="./data/uspto/", url="http://localhost:3000/convert", limit=10): + for document in serialize_dolma(data_dir="./data/uspto/", limit=10): print(document) """ - for x in tqdm(process_datasets(data_dir, url, limit, max_concurrency)): + for x in tqdm(process_datasets(data_dir, limit, max_concurrency)): metadata = { "source": "Google Patents Public Data", "metadata": { @@ -191,28 +202,23 @@ def create_args_parser() -> argparse.ArgumentParser: parser.add_argument( "--output_dir", type=str, help="Output directory", default=r"raw" ) - # parser.add_argument( - # "data_dir", - # type=str, - # default=r"/Users/baber/Downloads/untitled folder", - # help="Dataset directory where all parquet files to process are located ", - # ) parser.add_argument( - "--url", + "data_dir", type=str, - help="REST API URL for the Node.js MathML to LaTeX converter", - default=r"http://localhost:3000/convert", + default=r"/uspto/data/", + help="Dataset directory where all parquet files to process are located ", ) + parser.add_argument( "--limit", type=int, - default=100, + default=0, help="Limit the number of rows to read for testing", ) parser.add_argument( "--max-concurrency", type=int, - default=2, + default=8, help="Maximum number of parquet files to process concurrently", ) return parser @@ -222,12 +228,11 @@ def create_args_parser() -> argparse.ArgumentParser: args = create_args_parser().parse_args() logger.info( f"""Processing USPTO with the following parameters: Output Dir: {args.output_dir}, - REST API URL: {args.url}, Limit: {args.limit}, Max Concurrency: {args.max_concurrency}""" + Limit: {args.limit}, Max Concurrency: {args.max_concurrency}""" ) to_dolma( serialize_dolma( data_dir="/Users/baber/Downloads/untitled folder", - url=args.url, limit=args.limit, max_concurrency=args.max_concurrency, ),
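
Below is a minimal, standalone sketch of the conversion path this patch introduces: pypandoc renders the patent HTML to plain text and polars' `map_batches` hands whole Series to a multiprocessing pool. The column name and sample rows are made up for illustration, the pool size and chunksize are arbitrary choices (the patch itself hardcodes `Pool(4)`), and a pandoc binary must be installed for pypandoc to work.

```python
import multiprocessing
import re
from functools import partial

import polars as pl
import pypandoc  # requires a pandoc binary on the PATH


def parse_html(html_string: str) -> str:
    """Render patent HTML to plain text with pandoc and unwrap hard-wrapped lines."""
    if not html_string:
        return ""
    text = pypandoc.convert_text(html_string, "plain", "html", extra_args=["--quiet"])
    # pandoc hard-wraps "plain" output; join single newlines, keep paragraph breaks
    return re.sub(r"(?<!\n)\n(?!\n)", " ", text)


def parallel_apply(max_concurrency: int, column: pl.Series) -> pl.Series:
    """Apply parse_html to every element of a Series in a process pool."""
    with multiprocessing.get_context("spawn").Pool(max_concurrency) as pool:
        return pl.Series(list(pool.imap(parse_html, column, chunksize=16)))


if __name__ == "__main__":
    # Illustrative data only; the real script reads these columns from parquet files.
    df = pl.DataFrame(
        {"description_html": ["<p>A <b>widget</b> with a return spring.</p>", None]}
    )
    out = df.with_columns(
        pl.col("description_html").map_batches(
            partial(parallel_apply, 2), return_dtype=pl.String
        )
    )
    print(out)
```

Null rows come back as empty strings here, as in the patch; the patch additionally passes `ignore_nulls=True` when concatenating the converted columns into the final `text` field.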