Skip to content

Commit

Permalink
fix: enable type-checking and fix things
Browse files: browse the repository at this point in the history.
includes these squashed commits:
 - fix: 3.7 compatibility
 - fix: type-ignore bad type annotation in click
 - fix: properly test normalize function
  • Loading branch information
dhdaines authored and joanise committed Sep 12, 2024
1 parent 5631210 commit 16668b2
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 31 deletions.
8 changes: 4 additions & 4 deletions g2p/app.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -237,7 +237,7 @@ async def convert(sid, message):


@SIO.on("table event", namespace="/table") # type: ignore
async def change_table(sid, message):
async def change_table(sid, message) -> None:
"""Change the lookup table"""
LOGGER.debug("/table: %s", message)
if "in_lang" not in message or "out_lang" not in message:
Expand All @@ -250,7 +250,7 @@ async def change_table(sid, message):
elif message["in_lang"] == "custom" or message["out_lang"] == "custom":
# These are only used to generate JSON to send to the client,
# so it's safe to create a list of references to the same thing.
mappings = [
mapping_dicts = [
{"in": "", "out": "", "context_before": "", "context_after": ""}
] * DEFAULT_N
abbs = [[""] * 6] * DEFAULT_N
Expand All @@ -272,7 +272,7 @@ async def change_table(sid, message):
"table response",
[
{
"mappings": mappings,
"mappings": mapping_dicts,
"abbs": abbs,
"kwargs": kwargs,
}
Expand All @@ -292,7 +292,7 @@ async def change_table(sid, message):
{
"mappings": x.plain_mapping(),
"abbs": expand_abbreviations_format(x.abbreviations),
"kwargs": x.model_dump(exclude=["alignments"]),
"kwargs": x.model_dump(exclude={"alignments"}),
}
for x in mappings
],
Expand Down
6 changes: 3 additions & 3 deletions g2p/cli.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -167,7 +167,7 @@ def generate_mapping( # noqa: C901
from_langs,
to_langs,
distance,
):
) -> None:
"""Generate a new mapping from existing mappings in the g2p system.
This command has different modes of operation.
Expand Down Expand Up @@ -354,7 +354,7 @@ def generate_mapping( # noqa: C901
except MappingMissing as e:
raise click.BadParameter(
f'Cannot find IPA mapping from "{in_lang}" to "{out_lang}": {e}',
param_hint=["IN_LANG", "OUT_LANG"],
param_hint=("IN_LANG", "OUT_LANG"), # type: ignore
)
source_mappings.append(source_mapping)

Expand Down Expand Up @@ -769,7 +769,7 @@ def update_schema(out_dir):
context_settings=CONTEXT_SETTINGS,
short_help="Scan a document for unknown characters.",
)
def scan(lang, path):
def scan(lang, path) -> None:
"""Scan a document for non target language characters.
Displays the set of un-mapped characters in a document.
Expand Down
4 changes: 3 additions & 1 deletion g2p/mappings/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -44,6 +44,8 @@
class Mapping(_MappingModelDefinition):
"""Class for lookup tables"""

rules: list

def model_post_init(self, *_args, **_kwargs) -> None:
"""After the model is constructed, we process the model specs by
applying all the configuration to the rules (ie prevent feeding,
Expand Down Expand Up @@ -146,7 +148,7 @@ def plain_mapping(self):
"""
return [rule.export_to_dict() for rule in self.rules]

def process_model_specs(self): # noqa: C901
def process_model_specs(self) -> List[Rule]: # noqa: C901
"""Process all model specifications"""
if self.as_is is not None:
appropriate_setting = (
Expand Down
38 changes: 22 additions & 16 deletions g2p/mappings/utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,7 +13,7 @@
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Pattern, Tuple, TypeVar, Union
from typing import Any, Dict, List, Optional, Pattern, Tuple, TypeVar, Union, cast

import regex as re
import yaml
Expand All @@ -27,6 +27,7 @@
field_validator,
model_validator,
)
from typing_extensions import Literal

from g2p import exceptions
from g2p.log import LOGGER
Expand Down Expand Up @@ -126,26 +127,27 @@ def expand_abbreviations_format(data):
return lines


def normalize(inp: str, norm_form: str):
def normalize(inp: str, norm_form: Union[str, None]):
"""Normalize to NFC(omposed) or NFD(ecomposed).
Also, find any Unicode Escapes & decode 'em!
"""
if norm_form not in ["none", "NFC", "NFD", "NFKC", "NFKD"]:
raise exceptions.InvalidNormalization(normalize)
elif norm_form is None or norm_form == "none":
if norm_form is None or norm_form == "none":
return unicode_escape(inp)
else:
normalized = ud.normalize(norm_form, unicode_escape(inp))
if normalized != inp:
LOGGER.debug(
"The string %s was normalized to %s using the %s standard and by decoding any Unicode escapes. "
"Note that this is not necessarily the final stage of normalization.",
inp,
normalized,
norm_form,
)
return normalized
if norm_form not in ["NFC", "NFD", "NFKC", "NFKD"]:
raise exceptions.InvalidNormalization(normalize)
# Sadly mypy doesn't do narrowing to literals properly
norm_form = cast(Literal["NFC", "NFD", "NFKC", "NFKD"], norm_form)
normalized = ud.normalize(norm_form, unicode_escape(inp))
if normalized != inp:
LOGGER.debug(
"The string %s was normalized to %s using the %s standard and by decoding any Unicode escapes. "
"Note that this is not necessarily the final stage of normalization.",
inp,
normalized,
norm_form,
)
return normalized


# compose_indices is generic because we would like to propagate the
Expand Down Expand Up @@ -177,6 +179,8 @@ def normalize_to_NFD_with_indices(
) -> Tuple[str, List[Tuple[int, int]]]:
"""Normalize to NFD and return the indices mapping input to output characters"""
assert norm_form in ("NFD", "NFKD")
# Sadly mypy doesn't do narrowing to literals properly
norm_form = cast(Literal["NFD", "NFKD"], norm_form)
result = ""
indices = []
for i, c in enumerate(inp):
Expand All @@ -192,6 +196,8 @@ def normalize_to_NFC_with_indices(
) -> Tuple[str, List[Tuple[int, int]]]:
"""Normalize to NFC and return the indices mapping input to output characters"""
assert norm_form in ("NFC", "NFKC")
# Sadly mypy doesn't do narrowing to literals properly
norm_form = cast(Literal["NFC", "NFKC"], norm_form)
inp_nfc = ud.normalize(norm_form, inp)
NFD_form = norm_form[:-1] + "D" # NFC->NFD or NFKC->NFKD
inp_nfd, indices_to_nfd = normalize_to_NFD_with_indices(inp, NFD_form)
Expand Down
11 changes: 10 additions & 1 deletion g2p/tests/test_mappings.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,9 +12,10 @@
from pydantic import ValidationError

from g2p import exceptions
from g2p.exceptions import InvalidNormalization
from g2p.log import LOGGER
from g2p.mappings import Mapping, Rule
from g2p.mappings.utils import NORM_FORM_ENUM, RULE_ORDERING_ENUM
from g2p.mappings.utils import NORM_FORM_ENUM, RULE_ORDERING_ENUM, normalize
from g2p.tests.public import __file__ as public_data
from g2p.transducer import Transducer

Expand Down Expand Up @@ -57,6 +58,14 @@ def test_normalization(self):
self.assertEqual(self.test_mapping_no_norm.rules[1].rule_input, "\u0061\u0301")
self.assertEqual(self.test_mapping_no_norm.rules[1].rule_output, "\u0061\u0301")

def test_utils_normalize(self):
"""Explicitly test our custom normalize function."""
self.assertEqual(normalize(r"\u0061", None), "a")
self.assertEqual(normalize("\u010d", "NFD"), "\u0063\u030c")
self.assertEqual(normalize("\u0063\u030c", "NFC"), "\u010d")
with self.assertRaises(InvalidNormalization):
normalize("FOOBIE", "BLETCH")

def test_json_map(self):
json_map = Mapping(
rules=self.json_map["map"],
Expand Down
12 changes: 6 additions & 6 deletions g2p/tests/time_panphon.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -34,16 +34,16 @@ def getPanphonDistanceSingleton1():

def getPanphonDistanceSingleton2():
if not hasattr(getPanphonDistanceSingleton2, "value"):
setattr(getPanphonDistanceSingleton2, "value", panphon.distance.Distance())
getPanphonDistanceSingleton2.value = panphon.distance.Distance()
return getPanphonDistanceSingleton2.value


for iters in (1, 1, 10, 100, 1000, 10000):
with CodeTimer(f"getPanphonDistanceSingleton1() {iters} times"):
for i in range(iters):
for _ in range(iters):
dst = getPanphonDistanceSingleton1()
with CodeTimer(f"getPanphonDistanceSingleton2() {iters} times"):
for i in range(iters):
for _ in range(iters):
dst = getPanphonDistanceSingleton2()

for words in (1, 10):
Expand All @@ -53,14 +53,14 @@ def getPanphonDistanceSingleton2():

with CodeTimer(f"is_panphon() on 1 word {words} times"):
string = "ei"
for i in range(words):
for _ in range(words):
is_panphon(string)

for iters in (1, 10):
with CodeTimer(f"dst init {iters} times"):
for i in range(iters):
for _ in range(iters):
dst = panphon.distance.Distance()

for iters in (1, 10, 100, 1000):
with CodeTimer(f"Transducer(Mapping(id='panphon_preprocessor')) {iters} times"):
panphon_preprocessor = Transducer(Mapping(id="panphon_preprocessor"))
panphon_preprocessor = Transducer(Mapping(id="panphon_preprocessor", rules=[]))

0 comments on commit 16668b2

Please sign in to comment.