From 46df29b97bdac0708f72a28b8e6c76680d0fa5a9 Mon Sep 17 00:00:00 2001 From: Erez Shinan Date: Fri, 20 Oct 2023 21:17:46 +0700 Subject: [PATCH] Typing fixes. Mypy now produces 0 type errors Also adding some typing info --- lark/load_grammar.py | 43 +++++++++++++++++++++++-------------------- lark/utils.py | 5 +++-- lark/visitors.py | 4 ++-- 3 files changed, 28 insertions(+), 24 deletions(-) diff --git a/lark/load_grammar.py b/lark/load_grammar.py index 8e41775f..362a845d 100644 --- a/lark/load_grammar.py +++ b/lark/load_grammar.py @@ -9,10 +9,10 @@ import pkgutil from ast import literal_eval from contextlib import suppress -from typing import List, Tuple, Union, Callable, Dict, Optional, Sequence +from typing import List, Tuple, Union, Callable, Dict, Optional, Sequence, Generator from .utils import bfs, logger, classify_bool, is_id_continue, is_id_start, bfs_all_unique, small_factors, OrderedSet -from .lexer import Token, TerminalDef, PatternStr, PatternRE +from .lexer import Token, TerminalDef, PatternStr, PatternRE, Pattern from .parse_tree_builder import ParseTreeBuilder from .parser_frontends import ParsingFrontend @@ -195,10 +195,10 @@ class FindRuleSize(Transformer): - def __init__(self, keep_all_tokens): + def __init__(self, keep_all_tokens: bool): self.keep_all_tokens = keep_all_tokens - def _will_not_get_removed(self, sym): + def _will_not_get_removed(self, sym: Symbol) -> bool: if isinstance(sym, NonTerminal): return not sym.name.startswith('_') if isinstance(sym, Terminal): @@ -207,7 +207,7 @@ def _will_not_get_removed(self, sym): return False assert False, sym - def _args_as_int(self, args): + def _args_as_int(self, args: List[Union[int, Symbol]]) -> Generator[int, None, None]: for a in args: if isinstance(a, int): yield a @@ -216,10 +216,10 @@ def _args_as_int(self, args): else: assert False - def expansion(self, args): + def expansion(self, args) -> int: return sum(self._args_as_int(args)) - def expansions(self, args): + def expansions(self, args) -> int: return max(self._args_as_int(args)) @@ -232,7 +232,7 @@ def __init__(self): self.i = 0 self.rule_options = None - def _name_rule(self, inner): + def _name_rule(self, inner: str): new_name = '__%s_%s_%d' % (self.prefix, inner, self.i) self.i += 1 return new_name @@ -243,7 +243,7 @@ def _add_rule(self, key, name, expansions): self.rules_cache[key] = t return t - def _add_recurse_rule(self, type_, expr): + def _add_recurse_rule(self, type_: str, expr: Tree): try: return self.rules_cache[expr] except KeyError: @@ -312,7 +312,7 @@ def _add_repeat_opt_rule(self, a, b, target, target_opt, atom): ]) return self._add_rule(key, new_name, tree) - def _generate_repeats(self, rule, mn, mx): + def _generate_repeats(self, rule: Tree, mn: int, mx: int): """Generates a rule tree that repeats ``rule`` exactly between ``mn`` to ``mx`` times. """ # For a small number of repeats, we can take the naive approach @@ -343,7 +343,7 @@ def _generate_repeats(self, rule, mn, mx): return ST('expansions', [ST('expansion', [mn_target] + [diff_opt_target])]) - def expr(self, rule, op, *args): + def expr(self, rule: Tree, op: Token, *args): if op.value == '?': empty = ST('expansion', []) return ST('expansions', [rule, empty]) @@ -372,7 +372,7 @@ def expr(self, rule, op, *args): assert False, op - def maybe(self, rule): + def maybe(self, rule: Tree): keep_all_tokens = self.rule_options and self.rule_options.keep_all_tokens rule_size = FindRuleSize(keep_all_tokens).transform(rule) empty = ST('expansion', [_EMPTY] * rule_size) @@ -382,11 +382,11 @@ def maybe(self, rule): class SimplifyRule_Visitor(Visitor): @staticmethod - def _flatten(tree): + def _flatten(tree: Tree): while tree.expand_kids_by_data(tree.data): pass - def expansion(self, tree): + def expansion(self, tree: Tree): # rules_list unpacking # a : b (c|d) e # --> @@ -417,7 +417,7 @@ def alias(self, tree): tree.data = 'expansions' tree.children = aliases - def expansions(self, tree): + def expansions(self, tree: Tree): self._flatten(tree) # Ensure all children are unique if len(set(tree.children)) != len(tree.children): @@ -610,7 +610,7 @@ def range(self, start, end): return ST('pattern', [PatternRE(regexp)]) -def _make_joined_pattern(regexp, flags_set): +def _make_joined_pattern(regexp, flags_set) -> PatternRE: return PatternRE(regexp, ()) class TerminalTreeToPattern(Transformer_NonRecursive): @@ -618,15 +618,17 @@ def pattern(self, ps): p ,= ps return p - def expansion(self, items): - assert items + def expansion(self, items: List[Pattern]) -> Pattern: + if not items: + return PatternStr('') + if len(items) == 1: return items[0] pattern = ''.join(i.to_regexp() for i in items) return _make_joined_pattern(pattern, {i.flags for i in items}) - def expansions(self, exps): + def expansions(self, exps: List[Pattern]) -> Pattern: if len(exps) == 1: return exps[0] @@ -637,7 +639,8 @@ def expansions(self, exps): pattern = '(?:%s)' % ('|'.join(i.to_regexp() for i in exps)) return _make_joined_pattern(pattern, {i.flags for i in exps}) - def expr(self, args): + def expr(self, args) -> Pattern: + inner: Pattern inner, op = args[:2] if op == '~': if len(args) == 3: diff --git a/lark/utils.py b/lark/utils.py index 97db7119..cc6183b2 100644 --- a/lark/utils.py +++ b/lark/utils.py @@ -181,7 +181,7 @@ def is_id_start(s: str) -> bool: return _test_unicode_category(s, _ID_START) -def dedup_list(l: List[T]) -> List[T]: +def dedup_list(l: Sequence[T]) -> List[T]: """Given a list (l) will removing duplicates from the list, preserving the original order of the list. Assumes that the list entries are hashable.""" @@ -231,7 +231,8 @@ def combine_alternatives(lists): return list(product(*lists)) try: - import atomicwrites + # atomicwrites doesn't have type bindings + import atomicwrites # type: ignore[import] _has_atomicwrites = True except ImportError: _has_atomicwrites = False diff --git a/lark/visitors.py b/lark/visitors.py index f7324180..ae9d128c 100644 --- a/lark/visitors.py +++ b/lark/visitors.py @@ -267,7 +267,7 @@ def __mul__( return TransformerChain(*self.transformers + (other,)) -class Transformer_InPlace(Transformer): +class Transformer_InPlace(Transformer[_Leaf_T, _Return_T]): """Same as Transformer, but non-recursive, and changes the tree in-place instead of returning new instances Useful for huge trees. Conservative in memory. @@ -282,7 +282,7 @@ def transform(self, tree: Tree[_Leaf_T]) -> _Return_T: return self._transform_tree(tree) -class Transformer_NonRecursive(Transformer): +class Transformer_NonRecursive(Transformer[_Leaf_T, _Return_T]): """Same as Transformer but non-recursive. Like Transformer, it doesn't change the original tree.