Commit 22711a5
replace bs4 with lxml. Also using polars for the main computation.
baberabb committed May 14, 2024
1 parent 7c2cdff commit 22711a5
Showing 4 changed files with 111 additions and 86 deletions.
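
For context, a minimal sketch of the parsing swap this commit makes (an illustrative snippet, not part of the diff): both forms strip markup from an HTML fragment, and the lxml call is the one the new parse_html() in download_preprocess.py builds on.

```python
# Illustrative only: the fragment below is a made-up input, not USPTO data.
import lxml.html as html

fragment = "<p>A <b>sample</b> claim.</p>"

# Before (bs4):
#   from bs4 import BeautifulSoup
#   text = BeautifulSoup(fragment, "html.parser").get_text()

# After (lxml), the pattern used by the new code in this commit:
text = html.fromstring(fragment).text_content()
print(text)  # "A sample claim."
```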
66 changes: 20 additions & 46 deletions uspto/download_preprocess.py
@@ -1,53 +1,27 @@
-import bs4
+from typing import Literal
+
+import lxml.html as html
 import requests
 
 
-def convert_mathml_to_latex(url: str, mathml_string: str) -> dict[str, str]:
+def convert_mathml_to_latex(url: str, mathml_string: str) -> str:
     """Function to convert MathML to LaTeX using a REST server running https://github.com/asnunes/mathml-to-latex."""
+    if not mathml_string:
+        return ""
     response = requests.post(url, json={"mathml": mathml_string})
-    result = response.json()
-    return result
 
+    if response.status_code in [400, 500]:
+        return str(mathml_string)
+    else:
+        result = response.json()
+        return result.get("latex", mathml_string)
 
-def format_text(url: str, example: dict) -> dict[str, str]:
-    """
-    Formats each row to:
-    Title\n\n
-    ABSTRACT\n\n
-    abstract_text\n\n
-    description
-    claims
-
-    Returned fields: text, date, app_number
-    """
-    output = ""
-    if title := example.get("title_text"):
-        output += title + "\n\n"
-    if abstract := example.get("abstract_text"):
-        output += (
-            "ABSTRACT"
-            + "\n\n"
-            + bs4.BeautifulSoup(abstract, "html.parser").get_text().strip()
-            + "\n\n"
-        )
-    if description := example.get("description_html"):
-        description = bs4.BeautifulSoup(description, "html.parser")
-        equations: list[bs4.element.Tag] = description.find_all("maths")
-        if equations:
-            for i, eq in enumerate(equations):
-                new_equation = convert_mathml_to_latex(url, str(eq))["latex"]
-                eq.string = new_equation
-        output += description.get_text()
-    if claims := example.get("claims_text"):
-        claims = bs4.BeautifulSoup(claims, "html.parser")
-        equations = claims.find_all("maths")
-        if equations:
-            for i, eq in enumerate(equations):
-                new_equation = convert_mathml_to_latex(url, str(eq))["latex"]
-                eq.string = new_equation
-        output += claims.get_text().strip()
-    return {
-        "text": output,
-        "date": str(example.get("publication_date")),
-        "app_number": example.get("application_number"),
-    }
+def parse_html(url, html_string: str) -> str:
+    html_string = html.fromstring(html_string)
+    equations: list[html.HtmlElement] = html_string.xpath("//maths")
+    if equations:
+        for i, eq in enumerate(equations):
+            new_equation = convert_mathml_to_latex(url, str(eq))
+            eq.clear()
+            eq.text = new_equation
+    return html_string.text_content()
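
A hypothetical usage sketch for the two helpers above, assuming the mathml-to-latex server linked in the docstring is running locally at http://localhost:3000/convert (the default endpoint used elsewhere in this commit); the `<maths>` fragment is invented for illustration.

```python
# Hypothetical example, not part of the commit.
from download_preprocess import parse_html

URL = "http://localhost:3000/convert"  # assumed converter endpoint

snippet = (
    "<p>The energy is "
    "<maths><mi>E</mi><mo>=</mo><mi>m</mi><msup><mi>c</mi><mn>2</mn></msup></maths>"
    " for a mass m.</p>"
)

# parse_html() walks every <maths> element, runs it through
# convert_mathml_to_latex(), and returns the plain text of the fragment.
print(parse_html(URL, snippet))
```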
4 changes: 4 additions & 0 deletions uspto/process_uspto.sh
@@ -1,5 +1,8 @@
 #!/bin/bash
 
+# Download HF Dataset repo
+git clone git@hf.co:datasets/baber/USPTO ./data/uspto/raw
+
 # Get the current directory
 CURRENT_DIR=$(basename "$PWD")
 
@@ -9,6 +12,7 @@ if [ "$CURRENT_DIR" != "uspto" ]; then
     cd uspto || exit
 fi
 
+
 # Clone MathML to LaTeX converter
 git clone https://github.com/baberabb/mathml-to-latex.git
 
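
Once the cloned converter is running, a quick smoke test can confirm the service is reachable before the main pipeline starts. This is a hypothetical check, assuming the server listens on port 3000 and accepts the same JSON payload that convert_mathml_to_latex() sends.

```python
# Hypothetical smoke test for the converter endpoint.
import requests

resp = requests.post(
    "http://localhost:3000/convert",
    json={"mathml": "<math><mi>x</mi><mo>=</mo><mn>1</mn></math>"},
)
print(resp.status_code)
print(resp.json().get("latex"))  # should hold the LaTeX translation on success
```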
4 changes: 2 additions & 2 deletions uspto/requirements.txt
@@ -1,2 +1,2 @@
-datasets
-bs4
+polars
+lxml
123 changes: 85 additions & 38 deletions uspto/uspto-to-dolma.py
@@ -1,53 +1,105 @@
 import argparse
+import glob
 from functools import partial
+from typing import Iterable
 
-import datasets
-from download_preprocess import format_text
+import polars as pl
+from download_preprocess import parse_html
+from polars import col
 
 from licensed_pile.licenses import PermissiveLicenses
 from licensed_pile.write import to_dolma
 
 
-def get_dataset(
-    hf_path: str, cache_dir: str | None = None, streaming: bool = False
-) -> datasets.Dataset:
-    kwargs = dict(split="train")
-    if cache_dir:
-        kwargs["cache_dir"] = cache_dir
-    if streaming:
-        kwargs["streaming"] = True
-    uspto_df = datasets.load_dataset(hf_path, **kwargs)
-    return uspto_df
+def scan_dataset(
+    data_dir: str = r"./data/uspto/",
+    url=r"http://localhost:3000/convert",
+    streaming: bool = True,
+) -> Iterable[dict]:
+    """
+    Scans the dataset in the specified directory and yields dictionaries representing each row of data.
+    The data is selected and transformed according to the specified columns.
+    HTML content in the "description_html" and "claims_html" columns is parsed using the provided local URL.
+
+    Parameters:
+    - data_dir (str): The path to the directory containing the dataset. Defaults to "./data/uspto/".
+    - url (str): The URL used for parsing HTML content. Defaults to "http://localhost:3000/convert".
+    - streaming (bool): Do not load the whole dataset in Memory. Defaults to True.
+
+    Returns:
+    Iterable[dict]: An iterable of dictionaries representing each row of data.
+
+    Example usage:
+    ```python
+    for row in scan_dataset(data_dir="./data/", url="http://example.com/"):
+        print(row)
+    ```
+    """
+    if not data_dir[-1] == r"/":
+        data_dir += r"/"
+    html_fn = partial(parse_html, url)
+
+    # columns to use
+    columns = [
+        "title_text",
+        "title_language",
+        "abstract_text",
+        "description_html",
+        "claims_html",
+        "publication_date",
+        "application_number",
+        "filing_date",
+    ]
+    for file_name in glob.glob(data_dir + r"*.parquet"):
+        df: pl.LazyFrame = (
+            pl.scan_parquet(file_name)
+            .select(columns)
+            .drop_nulls(["abstract_text", "description_html", "claims_html"])
+            # we use app no. for the id and filing date for the date added to database
+            .rename({"application_number": "id", "filing_date": "added"})
+            .with_columns(
+                col("added").cast(pl.String, strict=False),
+                col("publication_date").cast(pl.String, strict=False),
+                col("description_html").map_elements(html_fn, return_dtype=pl.String),
+                col("claims_html").map_elements(html_fn, return_dtype=pl.String),
+                # if abstract returns `ABSTRACT\n\n<abstract>`. Null otherwise
+                pl.concat_str(
+                    pl.lit(r"ABSTRACT", dtype=pl.String),
+                    pl.lit("\n\n", dtype=pl.String),
+                    col("abstract_text"),
+                    ignore_nulls=False,
+                ).alias("abstract_text"),
+            )
+            .with_columns(
+                pl.concat_str(
+                    col("title_text"),
+                    pl.lit("\n\n", dtype=pl.String),
+                    col("abstract_text"),
+                    col("description_html"),
+                    col("claims_html"),
+                    ignore_nulls=True,
+                ).alias("text")
+            )
+        ).select(["id", "text", "added", "title_language", "publication_date"])
+
+        yield from df.collect(streaming=streaming).iter_rows(named=True)
 
 
-def return_dolma(ds: datasets.Dataset) -> dict[str, str]:
+def serialize_dolma(ds: Iterable[dict[str, str]]) -> dict[str, str]:
     for x in ds:
-        output = {
-            "text": x.get("text"),
-            "id": x.get("application_number"),
+        metadata = {
             "source": "Google Patents Public Data",
             "metadata": {
                 "license": str(PermissiveLicenses.CC_BY),
-                "language": x.get("title_language"),
-                "publication_date": str(x.get("publication_date")),
+                "language": x.pop("title_language", "en"),
+                "publication_date": str(x.pop("publication_date", "9999")),
             },
         }
-        yield output
+        yield x | metadata
 
 
 parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--output_dir", type=str, help="Output directory", default=r"/data/uspto/raw"
-)
-parser.add_argument(
-    "--dataset", type=str, help="Path to raw HF dataset", default=r"baber/USPTO"
-)
-parser.add_argument(
-    "--cache_dir",
-    type=str,
-    help="Path to cache HF dataset",
-    default=r"./data/uspto/raw",
-)
+parser.add_argument("--output_dir", type=str, help="Output directory", default=r"raw")
 parser.add_argument("--streaming", action="store_true")
 parser.add_argument(
     "--url",
@@ -58,10 +110,5 @@ def return_dolma(ds: datasets.Dataset) -> dict[str, str]:
 
 if __name__ == "__main__":
     args = parser.parse_args()
-    URL = args.url
-    DATASET = args.dataset
-    OUTPUT_DIR = args.output_dir
-    uspto_df = get_dataset(DATASET, cache_dir=args.cache_dir, streaming=args.streaming)
-    format_text = partial(format_text, URL)
-    uspto_df = uspto_df.map(format_text, remove_columns=list(uspto_df.column_names))
-    to_dolma(return_dolma(uspto_df), OUTPUT_DIR, "uspto.jsonl.gz")
+    uspto_df = scan_dataset(args.dataset, url=args.url)
+    to_dolma(serialize_dolma(uspto_df), args.output_dir, "uspto.jsonl.gz")
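
For reference, the approximate shape of one record yielded by serialize_dolma() above; every value below is an invented placeholder, not real USPTO data.

```python
# Illustrative record shape only; all values are made up.
record = {
    "id": "US-0000000-A1",               # application_number, renamed to "id" in scan_dataset()
    "text": "Title\n\nABSTRACT\n\n...",  # title + abstract + description + claims
    "added": "2020-01-01",               # filing_date, renamed to "added"
    "source": "Google Patents Public Data",
    "metadata": {
        "license": "...",                # str(PermissiveLicenses.CC_BY)
        "language": "en",                # title_language, with "en" as the fallback
        "publication_date": "2020-06-01",
    },
}
```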
