diff --git a/typing_validation/validation.py b/typing_validation/validation.py index ed09b63..53ce88f 100644 --- a/typing_validation/validation.py +++ b/typing_validation/validation.py @@ -515,9 +515,6 @@ def _extract_dtypes(t: Any) -> typing.Sequence[Any]: dtype for member in t.__args__ for dtype in _extract_dtypes(member) ] import numpy as np # pylint: disable = import-outside-toplevel - - if isinstance(t, type) and issubclass(t, np.generic): - return [t] if hasattr(t, "__origin__"): t_origin = t.__origin__ if t_origin in { @@ -532,6 +529,8 @@ def _extract_dtypes(t: Any) -> typing.Sequence[Any]: if t == t_origin[Any]: return [t_origin] # TODO: add broader support for np.NBitBase subtypes + if isinstance(t, type) and issubclass(t, np.generic): + return [t] raise TypeError()