Skip to content

Commit

Permalink
Merge branch 'dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
nnnyt authored Mar 19, 2024
2 parents 6d5191c + 7abc7d1 commit e05f640
Show file tree
Hide file tree
Showing 10 changed files with 373 additions and 29 deletions.
6 changes: 4 additions & 2 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

[Shangzi Xue](https://github.com/ShangziXue)

[Chaokun Wang](https://github.com/Bone-Fish)
[Heng Yu](https://github.com/GNEHUY)

[Tianyun Ji](https://github.com/KINGNEWBLUSH)

The stared contributors are the corresponding authors.
[Chaokun Wang](https://github.com/Bone-Fish)
156 changes: 156 additions & 0 deletions EduNLP/SIF/parser/ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# coding: utf-8
# 2024/3/5 @ yuheng
import json
import requests
from EduNLP.utils import image2base64


class FormulaRecognitionError(Exception):
"""Exception raised when formula recognition fails."""
def __init__(self, message="Formula recognition failed"):
self.message = message
super().__init__(self.message)


def ocr_formula_figure(image_PIL_or_base64, is_base64=False):
"""
Recognizes mathematical formulas in an image and returns their LaTeX representation.
Parameters
----------
image_PIL_or_base64 : PngImageFile or str
The PngImageFile if is_base64 is False, or the base64 encoded string of the image if is_base64 is True.
is_base64 : bool, optional
Indicates whether the image_PIL_or_base64 parameter is an PngImageFile or a base64 encoded string.
Returns
-------
latex : str
The LaTeX representation of the mathematical formula recognized in the image.
Raises an exception if the image is not recognized as containing a mathematical formula.
Raises
------
FormulaRecognitionError
If the HTTP request does not return a 200 status code,
if there is an error processing the response,
if the image is not recognized as a mathematical formula.
Examples
--------
>>> import os
>>> from PIL import Image
>>> from EduNLP.utils import abs_current_dir, path_append
>>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
>>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
>>> print(ocr_formula_figure(image_PIL))
f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}
>>> import os
>>> from PIL import Image
>>> from EduNLP.utils import abs_current_dir, path_append, image2base64
>>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
>>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
>>> image_base64 = image2base64(image_PIL)
>>> print(ocr_formula_figure(image_base64, is_base64=True))
f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}
Notes
-----
This function relies on an external service "https://formula-recognition-service-47-production.env.iai.bdaa.pro/v1",
and the `requests` library to make HTTP requests. Make sure the required libraries are installed before use.
"""
url = "https://formula-recognition-service-47-production.env.iai.bdaa.pro/v1"

if is_base64:
image = image_PIL_or_base64
else:
image = image2base64(image_PIL_or_base64)

data = [{
'qid': 0,
'image': image
}]

resp = requests.post(url, data=json.dumps(data))

if resp.status_code != 200:
raise FormulaRecognitionError(f"HTTP error {resp.status_code}: {resp.text}")

try:
res = json.loads(resp.content)
except Exception as e:
raise FormulaRecognitionError(f"Error processing response: {e}")

res = json.loads(resp.content)
data = res['data']
if data['success'] == 1 and data['is_formula'] == 1 and data['detect_formula'] == 1:
latex = data['latex']
else:
latex = None
raise FormulaRecognitionError("Image is not recognized as a formula")

return latex


def ocr(src, is_base64=False, figure_instances: dict = None):
"""
Recognizes mathematical formulas within figures from a given source,
which can be either a base64 string or an identifier for a figure within a provided dictionary.
Parameters
----------
src : str
The source from which the figure is to be recognized.
It can be a base64 encoded string of the image if is_base64 is True,
or an identifier for the figure if is_base64 is False.
is_base64 : bool, optional
Indicates whether the src parameter is a base64 encoded string or an identifier, by default False.
figure_instances : dict, optional
A dictionary mapping figure identifiers to their corresponding PngImageFile, by default None.
This is only required and used if is_base64 is False.
Returns
-------
forumla_figure_latex : str or None
The LaTeX representation of the mathematical formula recognized within the figure.
Returns None if no formula is recognized or
if the figure_instances dictionary does not contain the specified figure identifier when is_base64 is False.
Examples
--------
>>> import os
>>> from PIL import Image
>>> from EduNLP.utils import abs_current_dir, path_append
>>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
>>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
>>> figure_instances = {"1": image_PIL}
>>> src_id = r"$\\FormFigureID{1}$"
>>> print(ocr(src_id[1:-1], figure_instances=figure_instances))
f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}
>>> import os
>>> from PIL import Image
>>> from EduNLP.utils import abs_current_dir, path_append, image2base64
>>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
>>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
>>> image_base64 = image2base64(image_PIL)
>>> src_base64 = r"$\\FormFigureBase64{%s}$" % (image_base64)
>>> print(ocr(src_base64[1:-1], is_base64=True, figure_instances=True))
f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}
Notes
-----
This function relies on `ocr_formula_figure` for the actual OCR (Optical Character Recognition) process.
Ensure that `ocr_formula_figure` is correctly implemented and can handle base64 encoded strings and PngImageFile.
"""
forumla_figure_latex = None
if is_base64:
figure = src[len(r"\FormFigureBase64") + 1: -1]
if figure_instances is not None:
forumla_figure_latex = ocr_formula_figure(figure, is_base64)
else:
figure = src[len(r"\FormFigureID") + 1: -1]
if figure_instances is not None:
figure = figure_instances[figure]
forumla_figure_latex = ocr_formula_figure(figure, is_base64)

return forumla_figure_latex
17 changes: 12 additions & 5 deletions EduNLP/SIF/segment/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
from contextlib import contextmanager
from ..constants import Symbol, TEXT_SYMBOL, FORMULA_SYMBOL, FIGURE_SYMBOL, QUES_MARK_SYMBOL, TAG_SYMBOL, SEP_SYMBOL
from ..parser.ocr import ocr


class TextSegment(str):
Expand Down Expand Up @@ -93,7 +94,7 @@ class SegmentList(object):
>>> SegmentList(test_item)
['如图所示,则三角形', 'ABC', '的面积是', '\\\\SIFBlank', '。', \\FigureID{1}]
"""
def __init__(self, item, figures: dict = None):
def __init__(self, item, figures: dict = None, convert_image_to_latex=False):
self._segments = []
self._text_segments = []
self._formula_segments = []
Expand All @@ -112,9 +113,15 @@ def __init__(self, item, figures: dict = None):
if not re.match(r"\$.+?\$", segment):
self.append(TextSegment(segment))
elif re.match(r"\$\\FormFigureID\{.+?}\$", segment):
self.append(FigureFormulaSegment(segment[1:-1], is_base64=False, figure_instances=figures))
if convert_image_to_latex:
self.append(LatexFormulaSegment(ocr(segment[1:-1], is_base64=False, figure_instances=figures)))
else:
self.append(FigureFormulaSegment(segment[1:-1], is_base64=False, figure_instances=figures))
elif re.match(r"\$\\FormFigureBase64\{.+?}\$", segment):
self.append(FigureFormulaSegment(segment[1:-1], is_base64=True, figure_instances=figures))
if convert_image_to_latex:
self.append(LatexFormulaSegment(ocr(segment[1:-1], is_base64=True, figure_instances=figures)))
else:
self.append(FigureFormulaSegment(segment[1:-1], is_base64=True, figure_instances=figures))
elif re.match(r"\$\\FigureID\{.+?}\$", segment):
self.append(FigureSegment(segment[1:-1], is_base64=False, figure_instances=figures))
elif re.match(r"\$\\FigureBase64\{.+?}\$", segment):
Expand Down Expand Up @@ -271,7 +278,7 @@ def describe(self):
}


def seg(item, figures=None, symbol=None):
def seg(item, figures=None, symbol=None, convert_image_to_latex=False):
r"""
It is a interface for SegmentList. And show it in an appropriate way.
Expand Down Expand Up @@ -346,7 +353,7 @@ def seg(item, figures=None, symbol=None):
>>> s2.text_segments
['已知', ',则以下说法中正确的是']
"""
segments = SegmentList(item, figures)
segments = SegmentList(item, figures, convert_image_to_latex)
if symbol is not None:
segments.symbolize(symbol)
return segments
4 changes: 2 additions & 2 deletions EduNLP/SIF/sif.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def to_sif(item, check_formula=True, parser: Parser = None):


def sif4sci(item: str, figures: (dict, bool) = None, mode: int = 2, symbol: str = None, tokenization=True,
tokenization_params=None, errors="raise"):
tokenization_params=None, convert_image_to_latex=False, errors="raise"):
r"""
Default to use linear Tokenizer, change the tokenizer by specifying tokenization_params
Expand Down Expand Up @@ -260,7 +260,7 @@ def sif4sci(item: str, figures: (dict, bool) = None, mode: int = 2, symbol: str
"Unknown mode %s, use only 0 or 1 or 2." % mode
)

ret = seg(item, figures, symbol)
ret = seg(item, figures, symbol, convert_image_to_latex)

if tokenization is True:
ret = tokenize(ret, **(tokenization_params if tokenization_params is not None else {}))
Expand Down
92 changes: 78 additions & 14 deletions EduNLP/SIF/tokenization/text/tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
# 2021/5/18 @ tongshiwei
import logging
import jieba
from nltk.tokenize import word_tokenize
import nltk
import spacy
import tokenizers as huggingface_tokenizer
from tokenizers.trainers import BpeTrainer
from .stopwords import DEFAULT_STOPWORDS
from tokenizers import Tokenizer as HGTokenizer


jieba.setLogLevel(logging.INFO)

Expand All @@ -15,7 +22,13 @@ def is_chinese(word):
return True


def tokenize(text, granularity="word", stopwords="default"):
def tokenize(text,
granularity="word",
stopwords="default",
tokenizer="jieba",
tok_model="en_core_web_sm",
bpe_json='bpe.tokenizer.json',
bpe_trainfile=None):
"""
Using jieba library to tokenize item by word or char.
Expand All @@ -37,17 +50,68 @@ def tokenize(text, granularity="word", stopwords="default"):
"""
stopwords = DEFAULT_STOPWORDS if stopwords == "default" else stopwords
stopwords = stopwords if stopwords is not None else {}
if granularity == "word":
return [token for token in jieba.cut(text) if token not in stopwords and token.strip()]
elif granularity == "char":
jieba_tokens = [token for token in jieba.cut(text) if token not in stopwords and token.strip()]
# Use jieba_tokens to hangle sentence with mixed chinese and english.
split_tokens = []
for token in jieba_tokens:
if is_chinese(token):
split_tokens.extend(list(token))
else:
split_tokens.append(token)
return split_tokens

if (tokenizer == 'jieba'):
if granularity == "word":
return [
token for token in jieba.cut(text)
if token not in stopwords and token.strip()
]
elif granularity == "char":
jieba_tokens = [
token for token in jieba.cut(text)
if token not in stopwords and token.strip()
]
# Use jieba_tokens to hangle sentence with mixed chinese and english.
split_tokens = []
for token in jieba_tokens:
if is_chinese(token):
split_tokens.extend(list(token))
else:
split_tokens.append(token)
return split_tokens
else:
raise TypeError("Unknown granularity %s" % granularity)

elif (tokenizer == 'nltk'):
try:
return [
token for token in word_tokenize(text)
if token not in stopwords and token.strip()
]
except LookupError:
nltk.download('punkt')
return [
token for token in word_tokenize(text)
if token not in stopwords and token.strip()
]

elif (tokenizer == 'spacy'):
try:
spacy_tokenizer = spacy.load(tok_model)
except OSError:
spacy.cli.download(tok_model)
spacy_tokenizer = spacy.load(tok_model)
output = spacy_tokenizer(str(text))
return [
token.text for token in output
if token.text not in stopwords
]

elif (tokenizer == 'bpe'):
try:
tokenizer = HGTokenizer.from_file(bpe_json)
except Exception:
tokenizer = huggingface_tokenizer.Tokenizer(
huggingface_tokenizer.models.BPE())
if (bpe_trainfile is None):
raise LookupError("bpe train file not found, using %s." % bpe_trainfile)
trainer = BpeTrainer(
special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])
tokenizer.train(files=[bpe_trainfile], trainer=trainer)
tokenizer.save(bpe_json, pretty=True)
output = tokenizer.encode(text)
output = output.tokens
return output[0]
else:
raise TypeError("Unknown granularity %s" % granularity)
raise TypeError("Invalid Spliter: %s" % tokenizer)
Loading

0 comments on commit e05f640

Please sign in to comment.