Commit 099abbb
test
baberabb committed May 28, 2024
1 parent 5815c0b commit 099abbb
Showing 2 changed files with 53 additions and 29 deletions.
44 changes: 31 additions & 13 deletions uspto/download_preprocess.py
@@ -1,27 +1,45 @@
 from copy import copy
 from typing import Literal
+from unicodedata import normalize
 
+import lxml
 import lxml.html as html
+import pypandoc
 import requests
 
+from licensed_pile.logs import get_logger
+
+logger = get_logger("uspto")
+
 
 def convert_mathml_to_latex(url: str, mathml_string: str) -> str:
     """Function to convert MathML to LaTeX using a REST server running https://github.com/asnunes/mathml-to-latex."""
     if not mathml_string:
         return ""
-    response = requests.post(url, json={"mathml": mathml_string})
-    if response.status_code in [400, 500]:
-        return str(mathml_string)
-    else:
-        result = response.json()
-        return result.get("latex", mathml_string)
+    # response = requests.post(url, json={"mathml": mathml_string})
+    try:
+        response = pypandoc.convert_text(mathml_string, "latex", format="html")
+        return response
+    except RuntimeError as e:
+        logger.info(f"Error converting MathML to LaTeX: {e}")
+        return mathml_string
+    # if response.status_code in [400, 500]:
+    #     return str(mathml_string)
+    # else:
+    #     result = response.json()
+    #     return result.get("latex", mathml_string)
 
 
 def parse_html(url, html_string: str) -> str:
     html_string = html.fromstring(html_string)
-    equations: list[html.HtmlElement] = html_string.xpath("//maths")
-    if equations:
-        for i, eq in enumerate(equations):
-            new_equation = convert_mathml_to_latex(url, str(eq))
-            eq.clear()
-            eq.text = new_equation
-    return html_string.text_content()
+    # equations: list[html.HtmlElement] = html_string.xpath("//maths")
+    # if equations:
+    #     for i, eq in enumerate(equations):
+    #         new_equation = convert_mathml_to_latex(url, lxml.html.tostring(eq, "unicode"))
+    #         eq.clear()
+    #         eq.text = new_equation
+    return normalize(
+        pypandoc.convert_text(
+            lxml.html.tostring(html_string, encoding="unicode"), "markdown", "html"
+        )
+    )
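The commit swaps the REST round-trip for a local pypandoc call. Below is a hedged, standalone sketch of that conversion path, assuming the pandoc binary is installed and on PATH (the MathML snippet is invented for illustration). One caveat the diff leaves open: unicodedata.normalize requires a normalization form (e.g. "NFKC") as its first argument, so the single-argument call in parse_html would need one.

from unicodedata import normalize

import pypandoc

# MathML -> LaTeX, as in convert_mathml_to_latex; pandoc's HTML reader
# understands embedded <math> elements. The snippet is a made-up example.
mathml = "<math><mfrac><mi>a</mi><mi>b</mi></mfrac></math>"
latex = pypandoc.convert_text(mathml, "latex", format="html")
print(latex)

# HTML -> Markdown plus Unicode normalization, as in parse_html; note
# the explicit "NFKC" form argument that normalize() requires.
markdown = normalize("NFKC", pypandoc.convert_text("<p>3 &lt; 4</p>", "markdown", format="html"))
print(markdown)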
38 changes: 22 additions & 16 deletions uspto/uspto-to-dolma.py
@@ -9,9 +9,10 @@
 import polars as pl
 from download_preprocess import parse_html
 from polars import col
+from tqdm import tqdm
 
 from licensed_pile.licenses import PermissiveLicenses
-from licensed_pile.logs import configure_logging
+from licensed_pile.logs import configure_logging, get_logger
 from licensed_pile.write import to_dolma
 
 logger = configure_logging("uspto")
@@ -67,7 +68,8 @@ def process_datasets(
     args = [(x, url, limit) for x in file_names]
     # we'll let polars handle the row parallelism but the API calls are IO bound
     # so, we can increase the number of files to process concurrently
-    with multiprocessing.get_context("spawn").Pool() as pool:
+    # yield from scan_dataset(args[0]).iter_rows(named=True)
+    with multiprocessing.get_context("spawn").Pool(2) as pool:
         for batch in batched(args, max_concurrency):
             logger.debug("Processing files %s", [b[0] for b in batch])
             for res in pool.imap_unordered(scan_dataset, batch):
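The comment above carries the design rationale: each task scans a whole parquet file and spends most of its time waiting on I/O, so the commit caps the pool at two processes and lets batched bound how many files are in flight per wave. A rough, self-contained sketch of the same pattern with stand-in work (the batched helper is back-ported from itertools.batched, and the file names are placeholders):

import multiprocessing
from itertools import islice


def batched(iterable, n):
    # Yield successive n-sized tuples (itertools.batched in Python 3.12+).
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch


def scan_dataset(args):
    name, url, limit = args
    return f"processed {name}"  # stand-in for the real parquet scan


if __name__ == "__main__":
    files = [(f"part_{i}.parquet", "http://localhost:3000/convert", 0) for i in range(8)]
    with multiprocessing.get_context("spawn").Pool(2) as pool:
        for batch in batched(files, 4):
            for res in pool.imap_unordered(scan_dataset, batch):
                print(res)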
@@ -119,8 +121,12 @@ def scan_dataset(args: tuple) -> pl.DataFrame:
pl.lit("2024-03-22", dtype=pl.String).alias("added"),
col("created").cast(pl.String, strict=False),
col("publication_date").cast(pl.String, strict=False),
col("description_html").map_elements(html_fn, return_dtype=pl.String),
col("claims_html").map_elements(html_fn, return_dtype=pl.String),
col("description_html").map_elements(
html_fn, return_dtype=pl.String, strategy="threading"
),
col("claims_html").map_elements(
html_fn, return_dtype=pl.String, strategy="threading"
),
# if abstract returns `ABSTRACT\n\n<abstract>`. Null otherwise
pl.concat_str(
pl.lit(r"ABSTRACT", dtype=pl.String),
@@ -142,7 +148,7 @@ def scan_dataset(args: tuple) -> pl.DataFrame:
     ).select(["id", "text", "added", "created", "title_language", "publication_date"])
     if limit > 0:
         df = df.fetch(limit).lazy()
-    return df.collect()
+    return df.collect(streaming=True)
 
 
 def serialize_dolma(
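collect(streaming=True) asks Polars to execute the lazy query on its streaming engine, processing the scan in chunks so peak memory stays bounded, while fetch(limit) remains the quick eager sample for testing. A minimal sketch, with a placeholder file name:

import polars as pl

lf = pl.scan_parquet("part_0.parquet").select(["id", "text"])
preview = lf.fetch(10)             # eager peek at a few rows for testing
full = lf.collect(streaming=True)  # chunked execution of the whole scan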
@@ -168,7 +174,7 @@ def serialize_dolma(
     for document in serialize_dolma(data_dir="./data/uspto/", url="http://localhost:3000/convert", limit=10):
         print(document)
     """
-    for x in process_datasets(data_dir, url, limit, max_concurrency):
+    for x in tqdm(process_datasets(data_dir, url, limit, max_concurrency)):
         metadata = {
             "source": "Google Patents Public Data",
             "metadata": {
@@ -185,12 +191,12 @@ def create_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--output_dir", type=str, help="Output directory", default=r"raw"
     )
-    parser.add_argument(
-        "data_dir",
-        type=str,
-        default=r"./data/uspto/",
-        help="Dataset directory where all parquet files to process are located ",
-    )
+    # parser.add_argument(
+    #     "data_dir",
+    #     type=str,
+    #     default=r"/Users/baber/Downloads/untitled folder",
+    #     help="Dataset directory where all parquet files to process are located ",
+    # )
     parser.add_argument(
         "--url",
         type=str,
@@ -200,13 +206,13 @@
     parser.add_argument(
         "--limit",
         type=int,
-        default=0,
+        default=100,
         help="Limit the number of rows to read for testing",
     )
     parser.add_argument(
         "--max-concurrency",
         type=int,
-        default=4,
+        default=2,
         help="Maximum number of parquet files to process concurrently",
     )
     return parser
@@ -215,12 +221,12 @@ def create_args_parser() -> argparse.ArgumentParser:
 if __name__ == "__main__":
     args = create_args_parser().parse_args()
     logger.info(
-        f"""Processing USPTO with the following parameters: Output Dir: {args.output_dir}, Data Dir: {args.data_dir},
+        f"""Processing USPTO with the following parameters: Output Dir: {args.output_dir},
         REST API URL: {args.url}, Limit: {args.limit}, Max Concurrency: {args.max_concurrency}"""
     )
     to_dolma(
         serialize_dolma(
-            data_dir=args.data_dir,
+            data_dir="/Users/baber/Downloads/untitled folder",
             url=args.url,
             limit=args.limit,
             max_concurrency=args.max_concurrency,
