diff --git a/numba_typing/overload_list.py b/numba_typing/overload_list.py new file mode 100644 index 000000000..de29d2b6b --- /dev/null +++ b/numba_typing/overload_list.py @@ -0,0 +1,193 @@ +import numba +from numba import types +from numba.extending import overload +from type_annotations import product_annotations, get_func_annotations +import typing +from numba.typed import List, Dict +from inspect import getfullargspec + + +def overload_list(orig_func): + def overload_inner(ovld_list): + def wrapper(*args): + func_list = ovld_list() + sig_list = [] + for func in func_list: + sig_list.append((product_annotations( + get_func_annotations(func)), func)) + args_orig_func = getfullargspec(orig_func) + values_dict = {name: typ for name, typ in zip(args_orig_func.args, args)} + defaults_dict = {} + if args_orig_func.defaults: + defaults_dict = {name: value for name, value in zip( + args_orig_func.args[::-1], args_orig_func.defaults[::-1])} + if valid_signature(sig_list, values_dict, defaults_dict): + result = choose_func_by_sig(sig_list, values_dict) + + if result is None: + raise TypeError(f'Unsupported types {args}') + + return result + + return overload(orig_func, strict=False)(wrapper) + + return overload_inner + + +def valid_signature(list_signature, values_dict, defaults_dict): + def check_defaults(list_param, sig_def): + for name, val in defaults_dict.items(): + if sig_def.get(name) is None: + raise AttributeError(f'{name} does not match the signature of the function passed to overload_list') + if sig_def[name] != val: + raise ValueError(f'The default arguments are not equal: {name}: {val} != {sig_def[name]}') + if type(sig_def[name]) != list_param[name]: + raise TypeError(f'The default value does not match the type: {list_param[name]}') + + for sig, _ in list_signature: + for param in sig.parameters: + if len(param) != len(values_dict): + check_defaults(param, sig.defaults) + + return True + + +def check_int_type(n_type): + return isinstance(n_type, types.Integer) + + +def check_float_type(n_type): + return isinstance(n_type, types.Float) + + +def check_bool_type(n_type): + return isinstance(n_type, types.Boolean) + + +def check_str_type(n_type): + return isinstance(n_type, types.UnicodeType) + + +def check_list_type(self, p_type, n_type): + res = isinstance(n_type, (types.List, types.ListType)) + if p_type == list: + return res + else: + return res and self.match(p_type.__args__[0], n_type.dtype) + + +def check_tuple_type(self, p_type, n_type): + if not isinstance(n_type, (types.Tuple, types.UniTuple)): + return False + try: + if len(p_type.__args__) != len(n_type.types): + return False + except AttributeError: # if p_type == tuple + return True + + for p_val, n_val in zip(p_type.__args__, n_type.types): + if not self.match(p_val, n_val): + return False + + return True + + +def check_dict_type(self, p_type, n_type): + res = False + if isinstance(n_type, types.DictType): + res = True + if isinstance(p_type, type): + return res + for p_val, n_val in zip(p_type.__args__, n_type.keyvalue_type): + res = res and self.match(p_val, n_val) + return res + + +class TypeChecker: + + _types_dict: dict = {} + + def __init__(self): + self._typevars_dict = {} + + def clear_typevars_dict(self): + self._typevars_dict.clear() + + @classmethod + def add_type_check(cls, type_check, func): + cls._types_dict[type_check] = func + + @staticmethod + def _is_generic(p_obj): + if isinstance(p_obj, typing._GenericAlias): + return True + + if isinstance(p_obj, typing._SpecialForm): + return p_obj not in {typing.Any} + + return False + + @staticmethod + def _get_origin(p_obj): + return p_obj.__origin__ + + def match(self, p_type, n_type): + if p_type == typing.Any: + return True + try: + if self._is_generic(p_type): + origin_type = self._get_origin(p_type) + if origin_type == typing.Generic: + return self.match_generic(p_type, n_type) + + return self._types_dict[origin_type](self, p_type, n_type) + + if isinstance(p_type, typing.TypeVar): + return self.match_typevar(p_type, n_type) + + if p_type in (list, tuple, dict): + return self._types_dict[p_type](self, p_type, n_type) + + return self._types_dict[p_type](n_type) + + except KeyError: + raise TypeError(f'A check for the {p_type} was not found.') + + def match_typevar(self, p_type, n_type): + if isinstance(n_type, types.List): + n_type = types.ListType(n_type.dtype) + if not self._typevars_dict.get(p_type): + self._typevars_dict[p_type] = n_type + return True + return self._typevars_dict.get(p_type) == n_type + + def match_generic(self, p_type, n_type): + raise SystemError + + +TypeChecker.add_type_check(int, check_int_type) +TypeChecker.add_type_check(float, check_float_type) +TypeChecker.add_type_check(str, check_str_type) +TypeChecker.add_type_check(bool, check_bool_type) +TypeChecker.add_type_check(list, check_list_type) +TypeChecker.add_type_check(tuple, check_tuple_type) +TypeChecker.add_type_check(dict, check_dict_type) + + +def choose_func_by_sig(sig_list, values_dict): + def check_signature(sig_params, types_dict): + checker = TypeChecker() + for name, typ in types_dict.items(): # name,type = 'a',int64 + if isinstance(typ, types.Literal): + typ = typ.literal_type + if not checker.match(sig_params[name], typ): + return False + + return True + + for sig, func in sig_list: # sig = (Signature,func) + for param in sig.parameters: # param = {'a':int,'b':int} + if check_signature(param, values_dict): + return func + + return None diff --git a/numba_typing/test_overload_list.py b/numba_typing/test_overload_list.py new file mode 100644 index 000000000..6566ce327 --- /dev/null +++ b/numba_typing/test_overload_list.py @@ -0,0 +1,228 @@ +import overload_list +from overload_list import List, Dict +from overload_list import types +import unittest +import typing +from numba import njit, core +import re + + +T = typing.TypeVar('T') +K = typing.TypeVar('K') +S = typing.TypeVar('S', int, float) +UserType = typing.NewType('UserType', int) + + +def generator_test(param, values_dict, defaults_dict={}): + + def check_type(typ): + if isinstance(typ, type): + return typ.__name__ + return typ + + value_keys = ", ".join(f"{key}" if key not in defaults_dict.keys() + else f"{key} = {defaults_dict[key]}" for key in values_dict.keys()) + value_annotation = ", ".join(f"{key}: {check_type(val)}" if key not in defaults_dict.keys() + else f"{key}: {check_type(val)} = {defaults_dict[key]}" + for key, val in values_dict.items()) + value_type = ", ".join(f"{val}" for val in values_dict.values()) + return_value_keys = ", ".join("{}".format(key) for key in values_dict.keys()) + param_func = ", ".join(f"{val}" for val in param) + test = f""" +def test_myfunc(): + def foo({value_keys}): + ... + + @overload_list.overload_list(foo) + def foo_ovld_list(): + + def foo({value_annotation}): + return ("{value_type}") + + return (foo,) + + @njit + def jit_func({value_keys}): + return foo({return_value_keys}) + + return (jit_func({param_func}), ("{value_type}")) +""" + loc = {} + exec(test, globals(), loc) + return loc + + +list_numba = List([1, 2, 3]) +nested_list_numba = List([List([1, 2])]) +dict_numba = Dict.empty(key_type=types.unicode_type, value_type=types.int64) +dict_numba_1 = Dict.empty(key_type=types.int64, value_type=types.boolean) +dict_numba['qwe'] = 1 +dict_numba_1[1] = True +list_type = types.ListType(types.int64) +list_in_dict_numba = Dict.empty(key_type=types.unicode_type, value_type=list_type) +list_in_dict_numba['qwe'] = List([3, 4, 5]) +str_variable = 'qwe' +str_variable_1 = 'qaz' +user_type = UserType(1) + + +def run_test(case): + run_generator = generator_test(*case) + received, expected = run_generator['test_myfunc']() + return (received, expected) + + +def run_test_with_error(case): + run_generator = generator_test(*case) + try: + run_generator['test_myfunc']() + except core.errors.TypingError as err: + res = re.search(r'TypeError', err.msg) + return res.group(0) + + +class TestOverload(unittest.TestCase): + maxDiff = None + + def test_standart_types(self): + test_cases = [([1], {'a': int}), ([1.0], {'a': float}), ([True], {'a': bool}), (['str_variable'], {'a': str})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_container_types(self): + test_cases = [([[1, 2]], {'a': list}), (['list_numba'], {'a': list}), ([(1.0, 2.0)], {'a': tuple}), + ([(1, 2.0)], {'a': tuple}), (['dict_numba'], {'a': dict})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_typing_types(self): + test_cases = [([[1.0, 2.0]], {'a': typing.List[float]}), (['list_numba'], {'a': typing.List[int]}), + ([(1, 2.0)], {'a': typing.Tuple[int, float]}), (['dict_numba_1'], {'a': typing.Dict[int, bool]}), + ([False], {'a': typing.Union[bool, str]}), (['str_variable'], {'a': typing.Union[bool, str]}), + ([True, 'str_variable'], {'a': typing.Union[bool, str], 'b': typing.Union[bool, str]}), + (['str_variable', True], {'a': typing.Union[bool, str], 'b': typing.Union[bool, str]}), + ([1, False], {'a': typing.Any, 'b': typing.Any}), + ([1.0, 'str_variable'], {'a': typing.Any, 'b': typing.Any})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_nested_typing_types(self): + test_cases = [(['nested_list_numba'], {'a': typing.List[typing.List[int]]}), + ([((1.0,),)], {'a': typing.Tuple[typing.Tuple[float]]})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_typevar_types(self): + test_cases = [([1.0], {'a': 'T'}), ([False], {'a': 'T'}), ([1, 2], {'a': 'T', 'b': 'T'}), + ([1.0, 2.0], {'a': 'T', 'b': 'T'}), (['str_variable', 'str_variable'], {'a': 'T', 'b': 'T'}), + (['list_numba', [1, 2]], {'a': 'T', 'b': 'T'}), + ([1, 2.0], {'a': 'T', 'b': 'K'}), ([1], {'a': 'S'}), ([1.0], {'a': 'S'}), + ([[True, True]], {'a': 'typing.List[T]'}), (['list_numba'], {'a': 'typing.List[T]'}), + ([('str_variable', 2)], {'a': 'typing.Tuple[T,K]'}), + (['dict_numba_1'], {'a': 'typing.Dict[K, T]'}), (['dict_numba'], {'a': 'typing.Dict[K, T]'}), + (['list_in_dict_numba'], {'a': 'typing.Dict[K, typing.List[T]]'})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_only_default_types(self): + test_cases = [([], {'a': int}, {'a': 1}), ([], {'a': float}, {'a': 1.0}), ([], {'a': bool}, {'a': True}), + ([], {'a': str}, {'a': 'str_variable'})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_overriding_default_types(self): + test_cases = [([5], {'a': int}, {'a': 1}), ([5.0], {'a': float}, {'a': 1.0}), + ([False], {'a': bool}, {'a': True}), (['str_variable_1'], {'a': str}, {'a': 'str_variable'})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_two_types(self): + test_cases = [([5, 3.0], {'a': int, 'b': float}), ([5, 3.0], {'a': int, 'b': float}, {'b': 0.0}), + ([5], {'a': int, 'b': float}, {'b': 0.0}), ([], {'a': int, 'b': float}, {'a': 0, 'b': 0.0})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_three_types(self): + test_cases = [([5, 3.0, 'str_variable_1'], {'a': int, 'b': float, 'c': str}), + ([5, 3.0], {'a': int, 'b': float, 'c': str}, {'c': 'str_variable'}), + ([5], {'a': int, 'b': float, 'c': str}, {'b': 0.0, 'c': 'str_variable'}), + ([], {'a': int, 'b': float, 'c': str}, {'a': 0, 'b': 0.0, 'c': 'str_variable'})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(*run_test(case)) + + def test_type_error(self): + test_cases = [([1], {'a': float}), ([], {'a': float}, {'a': 1}), ([1], {'a': typing.Iterable[int]}), + ([(1, 2, 3)], {'a': typing.Tuple[int, int]}), ([(1.0, 2)], {'a': typing.Tuple[int, int]}), + ([(1, 2.0)], {'a': typing.Tuple[int, int]}), ([(1.0, 2.0)], {'a': typing.Tuple[int, int]}), + ([1, 2.0], {'a': 'T', 'b': 'T'}), ([(1, 2), (1, 2.0)], {'a': 'T', 'b': 'T'}), + ([True], {'a': 'S'}), (['str_variable'], {'a': 'S'})] + + for case in test_cases: + with self.subTest(case=case): + self.assertEqual(run_test_with_error(case), 'TypeError') + + def test_attribute_error(self): + def foo(a=0): + ... + + @overload_list.overload_list(foo) + def foo_ovld_list(): + + def foo(a: int): + return (a,) + + return (foo,) + + @njit + def jit_func(): + return foo() + + try: + jit_func() + except core.errors.TypingError as err: + res = re.search(r'AttributeError', err.msg) + self.assertEqual(res.group(0), 'AttributeError') + + def test_value_error(self): + def foo(a=0): + ... + + @overload_list.overload_list(foo) + def foo_ovld_list(): + + def foo(a: int = 1): + return (a,) + + return (foo,) + + @njit + def jit_func(): + return foo() + + try: + jit_func() + except core.errors.TypingError as err: + res = re.search(r'ValueError', err.msg) + self.assertEqual(res.group(0), 'ValueError') + + +if __name__ == "__main__": + unittest.main()