use pandoc
baberabb committed May 28, 2024
1 parent 099abbb commit 5dd26bf
Showing 2 changed files with 53 additions and 93 deletions.
45 changes: 0 additions & 45 deletions uspto/download_preprocess.py

This file was deleted.
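
Context: the deleted uspto/download_preprocess.py previously supplied parse_html, which converted documents through a local Node.js endpoint (the removed --url default, http://localhost:3000/convert); this commit re-implements parse_html in uspto-to-dolma.py on top of pypandoc instead. A minimal sketch of the call the new code relies on, assuming the pypandoc package and a pandoc binary are installed (the sample HTML string is illustrative, not from the dataset):

```python
import pypandoc

html = "<h2>CLAIMS</h2><p>1. A method comprising<br>a first step.</p>"
# "plain" is the output format, "html" the input format; --quiet suppresses pandoc warnings
text = pypandoc.convert_text(html, "plain", format="html", extra_args=["--quiet"])
print(text)
```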

101 changes: 53 additions & 48 deletions uspto/uspto-to-dolma.py
@@ -1,18 +1,19 @@
import argparse
import multiprocessing
import re
import sys
from functools import partial
from itertools import islice
from pathlib import Path
from typing import Iterable, Iterator

import polars as pl
from download_preprocess import parse_html
import pypandoc
from polars import col
from tqdm import tqdm

from licensed_pile.licenses import PermissiveLicenses
from licensed_pile.logs import configure_logging, get_logger
from licensed_pile.logs import configure_logging
from licensed_pile.write import to_dolma

logger = configure_logging("uspto")
@@ -27,19 +28,32 @@ def batched(iterable, n):
yield batch


def parse_html(html_string: str) -> str:
if not html_string:
return ""
text = pypandoc.convert_text(html_string, "plain", "html", extra_args=["--quiet"])
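    # unwrap pandoc's hard-wrapped lines while keeping blank-line paragraph breaks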
    return re.sub(r"(?<!\n)\n(?!\n)", " ", text)


# from: https://stackoverflow.com/a/74749075/19355181
def parallel_apply(max_concurrency: int, column: pl.Series) -> pl.Series:
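    # max_concurrency == 0 means "use every core": Pool(None) defaults to os.cpu_count()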
if max_concurrency == 0:
max_concurrency = None
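    # the "spawn" start method re-imports this module in each worker, so parse_html must stay importable at module level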
    with multiprocessing.get_context("spawn").Pool(max_concurrency) as pool:
return pl.Series(pool.imap(parse_html, column))


def process_datasets(
data_dir: str = r"./data/uspto/",
url: str = r"http://localhost:3000/convert",
limit: int = 0,
max_concurrency: int = 4,
) -> Iterable[dict]:
"""
    This function `process_datasets` scans a dataset located in a directory, converts each file in the dataset to a desired
    format using an API endpoint, and returns an iterable of dictionaries containing the converted data.
    format using pandoc, and returns an iterable of dictionaries containing the converted data.
Parameters:
- `data_dir` (str): The directory where the dataset is located. Default value is "./data/uspto/".
- `url` (str): The API endpoint URL for converting the dataset files. Default value is "http://localhost:3000/convert".
- `limit` (int): The maximum number of rows to convert. Default value is 0, which means convert all rows from all files in the dataset.
    - `max_concurrency` (int): The number of worker processes used to convert HTML in parallel. Default value is 4.
@@ -48,52 +62,44 @@ def process_datasets(
Note:
- The `data_dir` parameter should be a valid directory path ending with a forward slash '/'.
- The `url` parameter should be a valid API endpoint URL.
    - The `limit` parameter determines how many rows to read. Set it to 0 to convert all rows from all files.
    - The `max_concurrency` parameter determines how many worker processes are used for HTML conversion.
Example usage:
```python
for data in run_dataset(data_dir=r"./data/uspto/", url="http://localhost:3000/convert", limit=10, max_concurrency=2):
for data in run_dataset(data_dir=r"./data/uspto/", limit=10, max_concurrency=2):
# Process each converted data entry
print(data)
```
"""
data_path = Path(data_dir)
logger.info(f"Processing files in {data_path}")
file_names = list(data_path.glob("*.parquet"))
if limit > 0:
limit //= len(file_names)
logger.info(f"Processing {limit} entries each from {len(file_names)} files.")
args = [(x, url, limit) for x in file_names]
# we'll let polars handle the row parallelism but the API calls are IO bound
# so, we can increase the number of files to process concurrently
# yield from scan_dataset(args[0]).iter_rows(named=True)
with multiprocessing.get_context("spawn").Pool(2) as pool:
for batch in batched(args, max_concurrency):
logger.debug("Processing files %s", [b[0] for b in batch])
for res in pool.imap_unordered(scan_dataset, batch):
yield from res.iter_rows(named=True)
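    # files are read one at a time; parallelism now happens per column inside
    # scan_dataset, where parallel_apply fans the HTML conversion out to a process pool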
for file_name in file_names:
yield from scan_dataset((file_name, limit, max_concurrency)).iter_rows(
named=True
)


def scan_dataset(args: tuple) -> pl.DataFrame:
"""
Scans an individual parquet file and returns a processed DataFrame.
Parameters:
args (tuple): A tuple containing the file name, URL, and limit.
    args (tuple): A tuple containing the file name, limit, and max_concurrency.
Returns:
DataFrame: A processed DataFrame containing the selected columns from the dataset.
Example Usage:
file_name = "dataset.parquet"
url = "https://www.example.com/dataset"
limit = 100
max_concurrency = 4
result = scan_dataset((file_name, url, limit))
result = scan_dataset((file_name, limit, max_concurrency))
"""
file_name, url, limit = args
file_name, limit, max_concurrency = args
parallel_apply_ = partial(parallel_apply, max_concurrency)
columns = (
"title_text",
"title_language",
@@ -105,7 +111,6 @@ def scan_dataset(args: tuple) -> pl.DataFrame:
"filing_date",
)

html_fn = partial(parse_html, url)
df: pl.LazyFrame = (
pl.scan_parquet(file_name)
.select(columns)
@@ -121,13 +126,20 @@ def scan_dataset(args: tuple) -> pl.DataFrame:
pl.lit("2024-03-22", dtype=pl.String).alias("added"),
col("created").cast(pl.String, strict=False),
col("publication_date").cast(pl.String, strict=False),
col("description_html").map_elements(
html_fn, return_dtype=pl.String, strategy="threading"
),
col("claims_html").map_elements(
html_fn, return_dtype=pl.String, strategy="threading"
),
            # if an abstract is present, the concat_str below yields `ABSTRACT\n\n<abstract>`; null otherwise
col("description_html")
.map_batches(
parallel_apply_,
return_dtype=pl.String,
is_elementwise=True,
)
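            # strip stray LaTeX \left / \right tokens (with or without a trailing '.')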
.str.replace_all(r"\\left(\.|)|\\right(\.|)", ""),
col("claims_html")
.map_batches(
parallel_apply_,
return_dtype=pl.String,
is_elementwise=True,
)
.str.replace_all(r"\\left(\.|)|\\right(\.|)", ""),
pl.concat_str(
pl.lit(r"ABSTRACT", dtype=pl.String),
pl.lit("\n\n", dtype=pl.String),
@@ -140,6 +152,7 @@ def scan_dataset(args: tuple) -> pl.DataFrame:
col("title_text"),
pl.lit("\n\n", dtype=pl.String),
col("abstract_text"),
pl.lit("\n\n", dtype=pl.String),
col("description_html"),
col("claims_html"),
ignore_nulls=True,
@@ -153,7 +166,6 @@ def scan_dataset(args: tuple) -> pl.DataFrame:

def serialize_dolma(
data_dir: str = r"./data/uspto/",
url=r"http://localhost:3000/convert",
limit: int = 0,
max_concurrency: int = 4,
) -> Iterator[dict[str, str]]:
@@ -162,7 +174,6 @@ def serialize_dolma(
Args:
data_dir: The directory path where the dataset files are located. Default is `./data/uspto/`.
url: The URL of the server to which the serialized documents will be sent. Default is `http://localhost:3000/convert`.
limit: The maximum number of documents to be serialized. Default is 0, which represents no limit.
        max_concurrency: number of worker processes used for HTML conversion. Default is `4`.
@@ -171,10 +182,10 @@ def serialize_dolma(
content and metadata in a standardized format.
Example Usage:
for document in serialize_dolma(data_dir="./data/uspto/", url="http://localhost:3000/convert", limit=10):
for document in serialize_dolma(data_dir="./data/uspto/", limit=10):
print(document)
"""
for x in tqdm(process_datasets(data_dir, url, limit, max_concurrency)):
for x in tqdm(process_datasets(data_dir, limit, max_concurrency)):
metadata = {
"source": "Google Patents Public Data",
"metadata": {
@@ -191,28 +202,23 @@ def create_args_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--output_dir", type=str, help="Output directory", default=r"raw"
)
# parser.add_argument(
# "data_dir",
# type=str,
# default=r"/Users/baber/Downloads/untitled folder",
# help="Dataset directory where all parquet files to process are located ",
# )
parser.add_argument(
"--url",
"data_dir",
type=str,
help="REST API URL for the Node.js MathML to LaTeX converter",
default=r"http://localhost:3000/convert",
default=r"/uspto/data/",
help="Dataset directory where all parquet files to process are located ",
)

parser.add_argument(
"--limit",
type=int,
default=100,
default=0,
help="Limit the number of rows to read for testing",
)
parser.add_argument(
"--max-concurrency",
type=int,
default=2,
default=8,
help="Maximum number of parquet files to process concurrently",
)
return parser
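
# Example invocation (paths are illustrative):
#   python uspto/uspto-to-dolma.py data/uspto/ --output_dir raw --limit 0 --max-concurrency 8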
@@ -222,12 +228,11 @@ def create_args_parser() -> argparse.ArgumentParser:
args = create_args_parser().parse_args()
logger.info(
f"""Processing USPTO with the following parameters: Output Dir: {args.output_dir},
REST API URL: {args.url}, Limit: {args.limit}, Max Concurrency: {args.max_concurrency}"""
Limit: {args.limit}, Max Concurrency: {args.max_concurrency}"""
)
to_dolma(
serialize_dolma(
data_dir="/Users/baber/Downloads/untitled folder",
url=args.url,
limit=args.limit,
max_concurrency=args.max_concurrency,
),
