Skip to content

Commit

Permalink
fix: Mypy type errors
Browse files Browse the repository at this point in the history
  • Loading branch information
yukinarit committed Sep 3, 2022
1 parent 88be343 commit 3ac1510
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 60 deletions.
95 changes: 49 additions & 46 deletions serde/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
import typing
import uuid
from dataclasses import is_dataclass
from typing import Any, ClassVar, Dict, Generic, Iterator, List, Optional, Set, Tuple, Type, TypeVar, Union
from typing import Any, Dict, Generic, Iterator, List, Optional, Set, Tuple, TypeVar, Union

import typing_inspect
from typing_extensions import Type

if sys.version_info[:2] == (3, 7):
import typing_extensions
Expand Down Expand Up @@ -52,22 +53,22 @@ def get_np_args(tp):

else:

def get_np_origin(tp):
def get_np_origin(tp: Type[Any]) -> Optional[Any]:
return None

def get_np_args(tp):
def get_np_args(tp: Any) -> Tuple[Any, ...]:
return ()

except ImportError:

def get_np_origin(tp):
def get_np_origin(tp: Type[Any]) -> Optional[Any]:
return None

def get_np_args(tp):
def get_np_args(tp: Any) -> Tuple[Any, ...]:
return ()


__all__: List = []
__all__: List[str] = []

T = TypeVar('T')

Expand Down Expand Up @@ -115,7 +116,7 @@ class SerdeSkip(Exception):
"""


def get_origin(typ):
def get_origin(typ: Type[Any]) -> Optional[Any]:
"""
Provide `get_origin` that works in all python versions.
"""
Expand All @@ -125,7 +126,7 @@ def get_origin(typ):
return typing_inspect.get_origin(typ) or get_np_origin(typ)


def get_args(typ):
def get_args(typ: Any) -> Tuple[Any, ...]:
"""
Provide `get_args` that works in all python versions.
"""
Expand All @@ -135,7 +136,7 @@ def get_args(typ):
return typing_inspect.get_args(typ) or get_np_args(typ)


def typename(typ, with_typing_module: bool = False) -> str:
def typename(typ: Type[Any], with_typing_module: bool = False) -> str:
"""
>>> from typing import List, Dict, Set, Any
>>> typename(int)
Expand Down Expand Up @@ -216,7 +217,9 @@ def typename(typ, with_typing_module: bool = False) -> str:
else:
return f'{mod}Tuple'
elif is_generic(typ):
return get_origin(typ).__name__
origin = get_origin(typ)
assert origin is not None
return origin.__name__
elif is_literal(typ):
args = type_args(typ)
if not args:
Expand All @@ -230,33 +233,33 @@ def typename(typ, with_typing_module: bool = False) -> str:
if inner:
return typename(typ.__supertype__)

name = getattr(typ, '_name', None)
name: Optional[str] = getattr(typ, '_name', None)
if name:
return name
else:
return typ.__name__


def type_args(typ):
def type_args(typ: Any) -> Tuple[Type[Any], ...]:
"""
Wrapper to suppress type error for accessing private members.
"""
try:
args = typ.__args__ # type: ignore
args: Tuple[Type[Any, ...]] = typ.__args__ # type: ignore
if args is None:
return []
return ()
else:
return args
except AttributeError:
return get_args(typ)


def union_args(typ: Union) -> Tuple:
def union_args(typ: Union) -> Tuple[Type[Any], ...]:
if not is_union(typ):
raise TypeError(f'{typ} is not Union')
args = type_args(typ)
if len(args) == 1:
return args[0]
return (args[0],)
it = iter(args)
types = []
for (i1, i2) in itertools.zip_longest(it, it):
Expand All @@ -269,7 +272,7 @@ def union_args(typ: Union) -> Tuple:
return tuple(types)


def dataclass_fields(cls: Type) -> Iterator:
def dataclass_fields(cls: Type[Any]) -> Iterator[dataclasses.Field]:
raw_fields = dataclasses.fields(cls)

try:
Expand Down Expand Up @@ -454,7 +457,7 @@ def is_opt(typ) -> bool:
return typ is Optional


def is_bare_opt(typ) -> bool:
def is_bare_opt(typ: Any) -> bool:
"""
Test if the type is `typing.Optional` without type args.
>>> is_bare_opt(Optional[int])
Expand All @@ -467,7 +470,7 @@ def is_bare_opt(typ) -> bool:
return not type_args(typ) and typ is Optional


def is_list(typ) -> bool:
def is_list(typ: Type[Any]) -> bool:
"""
Test if the type is `typing.List`.
Expand All @@ -478,12 +481,12 @@ def is_list(typ) -> bool:
True
"""
try:
return issubclass(get_origin(typ), list)
return issubclass(get_origin(typ), list) # type: ignore
except TypeError:
return typ in (List, list)


def is_bare_list(typ) -> bool:
def is_bare_list(typ: Type[Any]) -> bool:
"""
Test if the type is `typing.List` without type args.
Expand All @@ -496,17 +499,17 @@ def is_bare_list(typ) -> bool:
return typ in (List, list)


def is_tuple(typ) -> bool:
def is_tuple(typ: Type[Any]) -> bool:
"""
Test if the type is `typing.Tuple`.
"""
try:
return issubclass(get_origin(typ), tuple)
return issubclass(get_origin(typ), tuple) # type: ignore
except TypeError:
return typ in (Tuple, tuple)


def is_bare_tuple(typ) -> bool:
def is_bare_tuple(typ: Type[Any]) -> bool:
"""
Test if the type is `typing.Tuple` without type args.
Expand All @@ -519,7 +522,7 @@ def is_bare_tuple(typ) -> bool:
return typ in (Tuple, tuple)


def is_set(typ) -> bool:
def is_set(typ: Type[Any]) -> bool:
"""
Test if the type is `typing.Set`.
Expand All @@ -530,12 +533,12 @@ def is_set(typ) -> bool:
True
"""
try:
return issubclass(get_origin(typ), set)
return issubclass(get_origin(typ), set) # type: ignore
except TypeError:
return typ in (Set, set)


def is_bare_set(typ) -> bool:
def is_bare_set(typ: Type[Any]) -> bool:
"""
Test if the type is `typing.Set` without type args.
Expand All @@ -548,7 +551,7 @@ def is_bare_set(typ) -> bool:
return typ in (Set, set)


def is_dict(typ) -> bool:
def is_dict(typ: Type[Any]) -> bool:
"""
Test if the type is `typing.Dict`.
Expand All @@ -559,12 +562,12 @@ def is_dict(typ) -> bool:
True
"""
try:
return issubclass(get_origin(typ), dict)
return issubclass(get_origin(typ), dict) # type: ignore
except TypeError:
return typ in (Dict, dict)


def is_bare_dict(typ) -> bool:
def is_bare_dict(typ: Type[Any]) -> bool:
"""
Test if the type is `typing.Dict` without type args.
Expand All @@ -577,7 +580,7 @@ def is_bare_dict(typ) -> bool:
return typ in (Dict, dict)


def is_none(typ) -> bool:
def is_none(typ: Type[Any]) -> bool:
"""
>>> is_none(int)
False
Expand All @@ -592,7 +595,7 @@ def is_none(typ) -> bool:
PRIMITIVES = [int, float, bool, str]


def is_enum(typ) -> bool:
def is_enum(typ: Type[Any]) -> bool:
"""
Test if the type is `enum.Enum`.
"""
Expand All @@ -602,7 +605,7 @@ def is_enum(typ) -> bool:
return isinstance(typ, enum.Enum)


def is_primitive(typ) -> bool:
def is_primitive(typ: Type[Any]) -> bool:
"""
Test if the type is primitive.
Expand All @@ -619,7 +622,7 @@ def is_primitive(typ) -> bool:
return is_new_type_primitive(typ)


def is_new_type_primitive(typ) -> bool:
def is_new_type_primitive(typ: Type[Any]) -> bool:
"""
Test if the type is a NewType of primitives.
"""
Expand All @@ -630,7 +633,7 @@ def is_new_type_primitive(typ) -> bool:
return any(isinstance(typ, ty) for ty in PRIMITIVES)


def is_generic(typ) -> bool:
def is_generic(typ: Type[Any]) -> bool:
"""
Test if the type is derived from `typing.Generic`.
Expand All @@ -646,7 +649,7 @@ def is_generic(typ) -> bool:
return origin is not None and Generic in getattr(origin, "__bases__", ())


def is_literal(typ) -> bool:
def is_literal(typ: Type[Any]) -> bool:
"""
Test if the type is derived from `typing.Literal`.
Expand All @@ -661,39 +664,39 @@ def is_literal(typ) -> bool:
origin = get_origin(typ)
if sys.version_info[:2] == (3, 7):
return origin is typing_extensions.Literal
return origin is typing.Literal
return origin is not None and origin is typing.Literal


def is_any(typ) -> bool:
def is_any(typ: Type[Any]) -> bool:
"""
Test if the type is `typing.Any`.
"""
return typ is Any


def is_str_serializable(typ) -> bool:
def is_str_serializable(typ: Type[Any]) -> bool:
"""
Test if the type is serializable to `str`.
"""
return typ in StrSerializableTypes


def is_datetime(typ) -> bool:
def is_datetime(typ: Type[Any]) -> bool:
"""
Test if the type is any of the datetime types..
"""
return typ in DateTimeTypes


def is_str_serializable_instance(obj) -> bool:
def is_str_serializable_instance(obj: Any) -> bool:
return isinstance(obj, StrSerializableTypes)


def is_datetime_instance(obj) -> bool:
def is_datetime_instance(obj: Any) -> bool:
return isinstance(obj, DateTimeTypes)


def find_generic_arg(cls, field) -> int:
def find_generic_arg(cls: Type[Any], field: TypeVar) -> int:
"""
Find a type in generic parameters.
Expand Down Expand Up @@ -724,7 +727,7 @@ def find_generic_arg(cls, field) -> int:
return -1


def get_generic_arg(typ, index):
def get_generic_arg(typ: Any, index: int) -> Any:
"""
Get generic type argument by index.
Expand All @@ -746,7 +749,7 @@ def get_generic_arg(typ, index):
return args[index]


def has_default(field) -> bool:
def has_default(field: dataclasses.Field) -> bool:
"""
Test if the field has default value.
Expand All @@ -762,7 +765,7 @@ def has_default(field) -> bool:
return not isinstance(field.default, dataclasses._MISSING_TYPE)


def has_default_factory(field) -> bool:
def has_default_factory(field: dataclasses.Field) -> bool:
"""
Test if the field has default factory.
Expand Down
3 changes: 2 additions & 1 deletion serde/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import logging
import re
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, TypeVar, Union

import casefy
import jinja2
from typing_extensions import Type

from .compat import (
SerdeError,
Expand Down
3 changes: 2 additions & 1 deletion serde/de.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import functools
import typing
from dataclasses import dataclass, is_dataclass
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar
from typing import Any, Callable, Dict, List, Optional, TypeVar

import jinja2
from typing_extensions import Type

from .compat import (
Literal,
Expand Down
3 changes: 2 additions & 1 deletion serde/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import logging
import os
import sys
from typing import Type

from typing_extensions import Type

from .core import SERDE_SCOPE, SerdeScope, init, logger

Expand Down
Loading

0 comments on commit 3ac1510

Please sign in to comment.