Skip to content

Commit

Permalink
Closes #14
Browse files Browse the repository at this point in the history
  • Loading branch information
sg495 committed Jan 30, 2024
1 parent 04cb356 commit 3d29107
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 11 deletions.
4 changes: 4 additions & 0 deletions test/test_00_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,10 @@ def test_numpy_array() -> None:
]
]
])
validate(val, npt.NDArray[np.number[typing.Any]])
validate(val, npt.NDArray[np.integer[typing.Any]])
validate(val, npt.NDArray[np.unsignedinteger[typing.Any]])
validate(val, npt.NDArray[np.generic])

def test_numpy_array_error() -> None:
# pylint: disable = import-outside-toplevel
Expand Down
49 changes: 39 additions & 10 deletions typing_validation/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def _type_alias_error(t_alias: str, cause: TypeError) -> TypeError:
validation_failure._t = t_alias
return cause


def _numpy_dtype_error(val: Any, t: Any) -> TypeError:
"""
Type error arising from ``val`` not being an instance of NumPy array
Expand All @@ -273,6 +274,7 @@ def _numpy_dtype_error(val: Any, t: Any) -> TypeError:
setattr(error, "validation_failure", validation_failure)
return error


def _missing_args_msg(t: Any) -> str:
"""Error message for missing :attr:`__args__` attribute on a type ``t``."""
return (
Expand Down Expand Up @@ -484,6 +486,7 @@ def _validate_typed_dict(val: Any, t: type) -> None:
except TypeError as e:
raise _key_type_error(val, t, e, key=k) from None


def _validate_user_class(val: Any, t: Any) -> None:
assert hasattr(t, "__args__"), _missing_args_msg(t)
assert isinstance(
Expand All @@ -498,27 +501,50 @@ def _validate_user_class(val: Any, t: Any) -> None:
_validate_type(val, t.__origin__)
# Generic type arguments cannot be validated


def _extract_dtypes(t: Any) -> typing.Sequence[Any]:
if t is Any:
return [Any]
if (UnionType is not None and isinstance(t, UnionType)
or hasattr(t, "__origin__") and t.__origin__ is Union):
if (
UnionType is not None
and isinstance(t, UnionType)
or hasattr(t, "__origin__")
and t.__origin__ is Union
):
return [
dtype
for member in t.__args__
for dtype in _extract_dtypes(member)
dtype for member in t.__args__ for dtype in _extract_dtypes(member)
]
import numpy as np # pylint: disable = import-outside-toplevel
import numpy as np # pylint: disable = import-outside-toplevel

if isinstance(t, type) and issubclass(t, np.generic):
return [t]
if hasattr(t, "__origin__"):
t_origin = t.__origin__
if t_origin in {
np.number,
np.inexact,
np.floating,
np.complexfloating,
np.integer,
np.signedinteger,
np.unsignedinteger,
}:
if t == t_origin[Any]:
return [t_origin]
# TODO: add broader support for np.NBitBase subtypes
raise TypeError()


def _validate_numpy_array(val: Any, t: Any) -> None:
assert hasattr(t, "__args__"), _missing_args_msg(t)
assert len(t.__args__) == 2, _wrong_args_num_msg(t, 2)
dtype_t_container = t.__args__[1]
assert hasattr(dtype_t_container, "__args__"), _missing_args_msg(dtype_t_container)
assert len(dtype_t_container.__args__) == 1, _wrong_args_num_msg(dtype_t_container, 1)
assert hasattr(dtype_t_container, "__args__"), _missing_args_msg(
dtype_t_container
)
assert len(dtype_t_container.__args__) == 1, _wrong_args_num_msg(
dtype_t_container, 1
)
dtype_t = dtype_t_container.__args__[0]
try:
dtypes = _extract_dtypes(dtype_t)
Expand All @@ -535,13 +561,15 @@ def _validate_numpy_array(val: Any, t: Any) -> None:
for arg in t.__args__:
validate(val, arg)
return
import numpy as np # pylint: disable = import-outside-toplevel
import numpy as np # pylint: disable = import-outside-toplevel

assert isinstance(val, np.ndarray)
val_dtype = val.dtype
if any(dtype is Any or np.issubdtype(val_dtype, dtype) for dtype in dtypes):
return
raise _numpy_dtype_error(val, t)


# def _validate_callable(val: Any, t: Any) -> None:
# """
# Callable validation
Expand Down Expand Up @@ -680,7 +708,8 @@ def validate(val: Any, t: Any) -> Literal[True]:
return True
elif isinstance(t.__origin__, type):
try:
import numpy as np # pylint: disable = import-outside-toplevel
import numpy as np # pylint: disable = import-outside-toplevel

if issubclass(t.__origin__, np.ndarray):
_validate_numpy_array(val, t)
return True
Expand Down
3 changes: 2 additions & 1 deletion typing_validation/validation_failure.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,8 @@ def __new__(
*,
type_aliases: Optional[Mapping[str, Any]] = None,
) -> Self:
import numpy as np # pylint: disable = import-outside-toplevel
import numpy as np # pylint: disable = import-outside-toplevel

assert isinstance(val, np.ndarray)
instance = super().__new__(cls, val, t, type_aliases=type_aliases)
return instance
Expand Down

0 comments on commit 3d29107

Please sign in to comment.