From 3ac1510cc940ff293629c9de8c99d422f60f4c1d Mon Sep 17 00:00:00 2001 From: Yukinari Tani Date: Tue, 30 Aug 2022 22:26:41 +0900 Subject: [PATCH] fix: Mypy type errors --- serde/compat.py | 95 +++++++++++++++++++++++++----------------------- serde/core.py | 3 +- serde/de.py | 3 +- serde/inspect.py | 3 +- serde/json.py | 4 +- serde/se.py | 21 ++++++----- setup.cfg | 1 + 7 files changed, 70 insertions(+), 60 deletions(-) diff --git a/serde/compat.py b/serde/compat.py index 199947d6..9c092827 100644 --- a/serde/compat.py +++ b/serde/compat.py @@ -14,9 +14,10 @@ import typing import uuid from dataclasses import is_dataclass -from typing import Any, ClassVar, Dict, Generic, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union +from typing import Any, Dict, Generic, Iterator, List, Optional, Set, Tuple, TypeVar, Union import typing_inspect +from typing_extensions import Type if sys.version_info[:2] == (3, 7): import typing_extensions @@ -52,22 +53,22 @@ def get_np_args(tp): else: - def get_np_origin(tp): + def get_np_origin(tp: Type[Any]) -> Optional[Any]: return None - def get_np_args(tp): + def get_np_args(tp: Any) -> Tuple[Any, ...]: return () except ImportError: - def get_np_origin(tp): + def get_np_origin(tp: Type[Any]) -> Optional[Any]: return None - def get_np_args(tp): + def get_np_args(tp: Any) -> Tuple[Any, ...]: return () -__all__: List = [] +__all__: List[str] = [] T = TypeVar('T') @@ -115,7 +116,7 @@ class SerdeSkip(Exception): """ -def get_origin(typ): +def get_origin(typ: Type[Any]) -> Optional[Any]: """ Provide `get_origin` that works in all python versions. """ @@ -125,7 +126,7 @@ def get_origin(typ): return typing_inspect.get_origin(typ) or get_np_origin(typ) -def get_args(typ): +def get_args(typ: Any) -> Tuple[Any, ...]: """ Provide `get_args` that works in all python versions. """ @@ -135,7 +136,7 @@ def get_args(typ): return typing_inspect.get_args(typ) or get_np_args(typ) -def typename(typ, with_typing_module: bool = False) -> str: +def typename(typ: Type[Any], with_typing_module: bool = False) -> str: """ >>> from typing import List, Dict, Set, Any >>> typename(int) @@ -216,7 +217,9 @@ def typename(typ, with_typing_module: bool = False) -> str: else: return f'{mod}Tuple' elif is_generic(typ): - return get_origin(typ).__name__ + origin = get_origin(typ) + assert origin is not None + return origin.__name__ elif is_literal(typ): args = type_args(typ) if not args: @@ -230,33 +233,33 @@ def typename(typ, with_typing_module: bool = False) -> str: if inner: return typename(typ.__supertype__) - name = getattr(typ, '_name', None) + name: Optional[str] = getattr(typ, '_name', None) if name: return name else: return typ.__name__ -def type_args(typ): +def type_args(typ: Any) -> Tuple[Type[Any], ...]: """ Wrapper to suppress type error for accessing private members. """ try: - args = typ.__args__ # type: ignore + args: Tuple[Type[Any, ...]] = typ.__args__ # type: ignore if args is None: - return [] + return () else: return args except AttributeError: return get_args(typ) -def union_args(typ: Union) -> Tuple: +def union_args(typ: Union) -> Tuple[Type[Any], ...]: if not is_union(typ): raise TypeError(f'{typ} is not Union') args = type_args(typ) if len(args) == 1: - return args[0] + return (args[0],) it = iter(args) types = [] for (i1, i2) in itertools.zip_longest(it, it): @@ -269,7 +272,7 @@ def union_args(typ: Union) -> Tuple: return tuple(types) -def dataclass_fields(cls: Type) -> Iterator: +def dataclass_fields(cls: Type[Any]) -> Iterator[dataclasses.Field]: raw_fields = dataclasses.fields(cls) try: @@ -454,7 +457,7 @@ def is_opt(typ) -> bool: return typ is Optional -def is_bare_opt(typ) -> bool: +def is_bare_opt(typ: Any) -> bool: """ Test if the type is `typing.Optional` without type args. >>> is_bare_opt(Optional[int]) @@ -467,7 +470,7 @@ def is_bare_opt(typ) -> bool: return not type_args(typ) and typ is Optional -def is_list(typ) -> bool: +def is_list(typ: Type[Any]) -> bool: """ Test if the type is `typing.List`. @@ -478,12 +481,12 @@ def is_list(typ) -> bool: True """ try: - return issubclass(get_origin(typ), list) + return issubclass(get_origin(typ), list) # type: ignore except TypeError: return typ in (List, list) -def is_bare_list(typ) -> bool: +def is_bare_list(typ: Type[Any]) -> bool: """ Test if the type is `typing.List` without type args. @@ -496,17 +499,17 @@ def is_bare_list(typ) -> bool: return typ in (List, list) -def is_tuple(typ) -> bool: +def is_tuple(typ: Type[Any]) -> bool: """ Test if the type is `typing.Tuple`. """ try: - return issubclass(get_origin(typ), tuple) + return issubclass(get_origin(typ), tuple) # type: ignore except TypeError: return typ in (Tuple, tuple) -def is_bare_tuple(typ) -> bool: +def is_bare_tuple(typ: Type[Any]) -> bool: """ Test if the type is `typing.Tuple` without type args. @@ -519,7 +522,7 @@ def is_bare_tuple(typ) -> bool: return typ in (Tuple, tuple) -def is_set(typ) -> bool: +def is_set(typ: Type[Any]) -> bool: """ Test if the type is `typing.Set`. @@ -530,12 +533,12 @@ def is_set(typ) -> bool: True """ try: - return issubclass(get_origin(typ), set) + return issubclass(get_origin(typ), set) # type: ignore except TypeError: return typ in (Set, set) -def is_bare_set(typ) -> bool: +def is_bare_set(typ: Type[Any]) -> bool: """ Test if the type is `typing.Set` without type args. @@ -548,7 +551,7 @@ def is_bare_set(typ) -> bool: return typ in (Set, set) -def is_dict(typ) -> bool: +def is_dict(typ: Type[Any]) -> bool: """ Test if the type is `typing.Dict`. @@ -559,12 +562,12 @@ def is_dict(typ) -> bool: True """ try: - return issubclass(get_origin(typ), dict) + return issubclass(get_origin(typ), dict) # type: ignore except TypeError: return typ in (Dict, dict) -def is_bare_dict(typ) -> bool: +def is_bare_dict(typ: Type[Any]) -> bool: """ Test if the type is `typing.Dict` without type args. @@ -577,7 +580,7 @@ def is_bare_dict(typ) -> bool: return typ in (Dict, dict) -def is_none(typ) -> bool: +def is_none(typ: Type[Any]) -> bool: """ >>> is_none(int) False @@ -592,7 +595,7 @@ def is_none(typ) -> bool: PRIMITIVES = [int, float, bool, str] -def is_enum(typ) -> bool: +def is_enum(typ: Type[Any]) -> bool: """ Test if the type is `enum.Enum`. """ @@ -602,7 +605,7 @@ def is_enum(typ) -> bool: return isinstance(typ, enum.Enum) -def is_primitive(typ) -> bool: +def is_primitive(typ: Type[Any]) -> bool: """ Test if the type is primitive. @@ -619,7 +622,7 @@ def is_primitive(typ) -> bool: return is_new_type_primitive(typ) -def is_new_type_primitive(typ) -> bool: +def is_new_type_primitive(typ: Type[Any]) -> bool: """ Test if the type is a NewType of primitives. """ @@ -630,7 +633,7 @@ def is_new_type_primitive(typ) -> bool: return any(isinstance(typ, ty) for ty in PRIMITIVES) -def is_generic(typ) -> bool: +def is_generic(typ: Type[Any]) -> bool: """ Test if the type is derived from `typing.Generic`. @@ -646,7 +649,7 @@ def is_generic(typ) -> bool: return origin is not None and Generic in getattr(origin, "__bases__", ()) -def is_literal(typ) -> bool: +def is_literal(typ: Type[Any]) -> bool: """ Test if the type is derived from `typing.Literal`. @@ -661,39 +664,39 @@ def is_literal(typ) -> bool: origin = get_origin(typ) if sys.version_info[:2] == (3, 7): return origin is typing_extensions.Literal - return origin is typing.Literal + return origin is not None and origin is typing.Literal -def is_any(typ) -> bool: +def is_any(typ: Type[Any]) -> bool: """ Test if the type is `typing.Any`. """ return typ is Any -def is_str_serializable(typ) -> bool: +def is_str_serializable(typ: Type[Any]) -> bool: """ Test if the type is serializable to `str`. """ return typ in StrSerializableTypes -def is_datetime(typ) -> bool: +def is_datetime(typ: Type[Any]) -> bool: """ Test if the type is any of the datetime types.. """ return typ in DateTimeTypes -def is_str_serializable_instance(obj) -> bool: +def is_str_serializable_instance(obj: Any) -> bool: return isinstance(obj, StrSerializableTypes) -def is_datetime_instance(obj) -> bool: +def is_datetime_instance(obj: Any) -> bool: return isinstance(obj, DateTimeTypes) -def find_generic_arg(cls, field) -> int: +def find_generic_arg(cls: Type[Any], field: TypeVar) -> int: """ Find a type in generic parameters. @@ -724,7 +727,7 @@ def find_generic_arg(cls, field) -> int: return -1 -def get_generic_arg(typ, index): +def get_generic_arg(typ: Any, index: int) -> Any: """ Get generic type argument by index. @@ -746,7 +749,7 @@ def get_generic_arg(typ, index): return args[index] -def has_default(field) -> bool: +def has_default(field: dataclasses.Field) -> bool: """ Test if the field has default value. @@ -762,7 +765,7 @@ def has_default(field) -> bool: return not isinstance(field.default, dataclasses._MISSING_TYPE) -def has_default_factory(field) -> bool: +def has_default_factory(field: dataclasses.Field) -> bool: """ Test if the field has default factory. diff --git a/serde/core.py b/serde/core.py index 53eb0cd7..6253f8ef 100644 --- a/serde/core.py +++ b/serde/core.py @@ -7,10 +7,11 @@ import logging import re from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, TypeVar, Union import casefy import jinja2 +from typing_extensions import Type from .compat import ( SerdeError, diff --git a/serde/de.py b/serde/de.py index a88a2e70..fada9c76 100644 --- a/serde/de.py +++ b/serde/de.py @@ -8,9 +8,10 @@ import functools import typing from dataclasses import dataclass, is_dataclass -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar +from typing import Any, Callable, Dict, List, Optional, TypeVar import jinja2 +from typing_extensions import Type from .compat import ( Literal, diff --git a/serde/inspect.py b/serde/inspect.py index 5d6137a1..83f114b9 100644 --- a/serde/inspect.py +++ b/serde/inspect.py @@ -16,7 +16,8 @@ import logging import os import sys -from typing import Type + +from typing_extensions import Type from .core import SERDE_SCOPE, SerdeScope, init, logger diff --git a/serde/json.py b/serde/json.py index 348734fb..ceddeb72 100644 --- a/serde/json.py +++ b/serde/json.py @@ -1,7 +1,9 @@ """ Serialize and Deserialize in JSON format. """ -from typing import Any, Type, Union +from typing import Any, Union + +from typing_extensions import Type from .compat import T from .de import Deserializer, from_dict diff --git a/serde/se.py b/serde/se.py index 11aad7f3..a800f5e2 100644 --- a/serde/se.py +++ b/serde/se.py @@ -12,6 +12,7 @@ from typing import Any, Callable, Dict, Iterator, List, Optional, Type, TypeVar import jinja2 +from typing_extensions import Type from .compat import ( Literal, @@ -78,17 +79,17 @@ __all__ = ["serialize", "is_serializable", "to_dict", "to_tuple"] # Interface of Custom serialize function. -SerializeFunc = Callable[[Type, Any], Any] +SerializeFunc = Callable[[Type[Any], Any], Any] -def default_serializer(_cls: Type, obj): +def default_serializer(_cls: Type[Any], obj): """ Marker function to tell serde to use the default serializer. It's used when custom serializer is specified at the class but you want to override a field with the default serializer. """ -def serde_custom_class_serializer(cls: Type, obj, custom: SerializeFunc, default: Callable): +def serde_custom_class_serializer(cls: Type[Any], obj: Any, custom: SerializeFunc, default: Callable): try: return custom(cls, obj) except SerdeSkip: @@ -198,7 +199,7 @@ def serialize( '{"i":10,"dt":"01/01/21"}' """ - def wrap(cls: Type): + def wrap(cls: Type[Any]): tagging.check() # If no `dataclass` found in the class, dataclassify it automatically. @@ -290,7 +291,7 @@ def is_serializable(instance_or_class: Any) -> bool: return hasattr(instance_or_class, SERDE_SCOPE) -def to_obj(o, named: bool, reuse_instances: bool, convert_sets: bool, c: Type = None): +def to_obj(o, named: bool, reuse_instances: bool, convert_sets: bool, c: Optional[Type[Any]] = None): try: thisfunc = functools.partial( to_obj, @@ -358,7 +359,7 @@ def to_tuple(o, reuse_instances: bool = ..., convert_sets: bool = ...) -> Any: return to_obj(o, named=False, reuse_instances=reuse_instances, convert_sets=convert_sets) -def asdict(v): +def asdict(v: Any) -> Dict[str, Any]: """ Serialize object into dictionary. """ @@ -414,7 +415,7 @@ def __getitem__(self, n) -> "SeField": return SeField(typ, name=None) -def sefields(cls: Type) -> Iterator[SeField]: +def sefields(cls: Type[Any]) -> Iterator[SeField]: """ Iterate fields for serialization. """ @@ -424,7 +425,7 @@ def sefields(cls: Type) -> Iterator[SeField]: yield f -def render_to_tuple(cls: Type, custom: Optional[SerializeFunc] = None, type_check: TypeCheck = NoCheck) -> str: +def render_to_tuple(cls: Type[Any], custom: Optional[SerializeFunc] = None, type_check: TypeCheck = NoCheck) -> str: template = """ def {{func}}(obj, reuse_instances = {{serde_scope.reuse_instances_default}}, convert_sets = {{serde_scope.convert_sets_default}}): @@ -458,7 +459,7 @@ def {{func}}(obj, reuse_instances = {{serde_scope.reuse_instances_default}}, def render_to_dict( - cls: Type, case: Optional[str] = None, custom: Optional[SerializeFunc] = None, type_check: TypeCheck = NoCheck + cls: Type[Any], case: Optional[str] = None, custom: Optional[SerializeFunc] = None, type_check: TypeCheck = NoCheck ) -> str: template = """ def {{func}}(obj, reuse_instances = {{serde_scope.reuse_instances_default}}, @@ -500,7 +501,7 @@ def {{func}}(obj, reuse_instances = {{serde_scope.reuse_instances_default}}, ) -def render_union_func(cls: Type, union_args: List[Type], tagging: Tagging = DefaultTagging) -> str: +def render_union_func(cls: Type[Any], union_args: List[Type[Any]], tagging: Tagging = DefaultTagging) -> str: """ Render function that serializes a field with union type. """ diff --git a/setup.cfg b/setup.cfg index 3e68956c..42c3def7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,3 +10,4 @@ max-complexity = 30 line_length = 120 [mypy] ignore_missing_imports = True +strict = True