diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 0ac68a830..474e14397 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -291,17 +291,17 @@ jobs: $WEAVIATE_126 ] steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 - name: Download build artifact to append to release uses: actions/download-artifact@v4 with: name: weaviate-python-client-wheel - run: | pip install weaviate_client-*.whl - pip install pytest pytest-benchmark pytest-profiling grpcio grpcio-tools pytest-xdist - - name: Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 + pip install -r requirements-devel.txt # install test dependencies - name: free space run: sudo rm -rf /usr/local/lib/android - run: rm -r weaviate diff --git a/integration/test_collection.py b/integration/test_collection.py index 12f155865..5f705a302 100644 --- a/integration/test_collection.py +++ b/integration/test_collection.py @@ -7,6 +7,7 @@ import pytest +import weaviate.classes as wvc from integration.conftest import CollectionFactory, CollectionFactoryGet, _sanitize_collection_name from integration.constants import WEAVIATE_LOGO_OLD_ENCODED, WEAVIATE_LOGO_NEW_ENCODED from weaviate.collections.classes.batch import ErrorObject @@ -51,8 +52,6 @@ ) from weaviate.types import UUID, UUIDS -import weaviate.classes as wvc - UUID1 = uuid.UUID("806827e0-2b31-43ca-9269-24fa95a221f9") UUID2 = uuid.UUID("8ad0d33c-8db1-4437-87f3-72161ca2a51a") UUID3 = uuid.UUID("83d99755-9deb-4b16-8431-d1dff4ab0a75") @@ -863,107 +862,6 @@ def test_query_properties(collection_factory: CollectionFactory) -> None: assert len(objects) == 0 -def test_near_vector(collection_factory: CollectionFactory) -> None: - collection = collection_factory( - properties=[Property(name="Name", data_type=DataType.TEXT)], - vectorizer_config=Configure.Vectorizer.text2vec_contextionary( - vectorize_collection_name=False - ), - ) - uuid_banana = collection.data.insert({"Name": "Banana"}) - collection.data.insert({"Name": "Fruit"}) - collection.data.insert({"Name": "car"}) - collection.data.insert({"Name": "Mountain"}) - - banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) - - full_objects = collection.query.near_vector( - banana.vector["default"], return_metadata=MetadataQuery(distance=True, certainty=True) - ).objects - assert len(full_objects) == 4 - - objects_distance = collection.query.near_vector( - banana.vector["default"], distance=full_objects[2].metadata.distance - ).objects - assert len(objects_distance) == 3 - - objects_distance = collection.query.near_vector( - banana.vector["default"], certainty=full_objects[2].metadata.certainty - ).objects - assert len(objects_distance) == 3 - - -def test_near_vector_limit(collection_factory: CollectionFactory) -> None: - collection = collection_factory( - properties=[Property(name="Name", data_type=DataType.TEXT)], - vectorizer_config=Configure.Vectorizer.text2vec_contextionary( - vectorize_collection_name=False - ), - ) - uuid_banana = collection.data.insert({"Name": "Banana"}) - collection.data.insert({"Name": "Fruit"}) - collection.data.insert({"Name": "car"}) - collection.data.insert({"Name": "Mountain"}) - - banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) - - objs = collection.query.near_vector(banana.vector["default"], limit=2).objects - assert len(objs) == 2 - - -def test_near_vector_offset(collection_factory: CollectionFactory) -> None: - collection = collection_factory( - properties=[Property(name="Name", data_type=DataType.TEXT)], - vectorizer_config=Configure.Vectorizer.text2vec_contextionary( - vectorize_collection_name=False - ), - ) - uuid_banana = collection.data.insert({"Name": "Banana"}) - uuid_fruit = collection.data.insert({"Name": "Fruit"}) - collection.data.insert({"Name": "car"}) - collection.data.insert({"Name": "Mountain"}) - - banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) - - objs = collection.query.near_vector(banana.vector["default"], offset=1).objects - assert len(objs) == 3 - assert objs[0].uuid == uuid_fruit - - -def test_near_vector_group_by_argument(collection_factory: CollectionFactory) -> None: - collection = collection_factory( - properties=[ - Property(name="Name", data_type=DataType.TEXT), - Property(name="Count", data_type=DataType.INT), - ], - vectorizer_config=Configure.Vectorizer.text2vec_contextionary( - vectorize_collection_name=False - ), - ) - uuid_banana1 = collection.data.insert({"Name": "Banana", "Count": 51}) - collection.data.insert({"Name": "Banana", "Count": 72}) - collection.data.insert({"Name": "car", "Count": 12}) - collection.data.insert({"Name": "Mountain", "Count": 1}) - - banana1 = collection.query.fetch_object_by_id(uuid_banana1, include_vector=True) - - ret = collection.query.near_vector( - banana1.vector["default"], - group_by=GroupBy( - prop="name", - number_of_groups=4, - objects_per_group=10, - ), - return_metadata=MetadataQuery(distance=True, certainty=True), - ) - - assert len(ret.objects) == 4 - assert ret.objects[0].belongs_to_group == "Banana" - assert ret.objects[1].belongs_to_group == "Banana" - assert ret.objects[2].belongs_to_group == "car" - assert ret.objects[3].belongs_to_group == "Mountain" - - def test_near_object(collection_factory: CollectionFactory) -> None: collection = collection_factory( properties=[Property(name="Name", data_type=DataType.TEXT)], diff --git a/integration/test_collection_near_vector.py b/integration/test_collection_near_vector.py new file mode 100644 index 000000000..1b866c12e --- /dev/null +++ b/integration/test_collection_near_vector.py @@ -0,0 +1,177 @@ +import uuid +from typing import Any + +import numpy as np +import pandas as pd +import polars as pl +import pytest + +from integration.conftest import CollectionFactory +from weaviate.collections.classes.config import ( + Configure, + DataType, + Property, +) +from weaviate.collections.classes.grpc import ( + GroupBy, + MetadataQuery, +) + +UUID1 = uuid.UUID("806827e0-2b31-43ca-9269-24fa95a221f9") +UUID2 = uuid.UUID("8ad0d33c-8db1-4437-87f3-72161ca2a51a") +UUID3 = uuid.UUID("83d99755-9deb-4b16-8431-d1dff4ab0a75") + + +def test_near_vector(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + properties=[Property(name="Name", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + uuid_banana = collection.data.insert({"Name": "Banana"}) + collection.data.insert({"Name": "Fruit"}) + collection.data.insert({"Name": "car"}) + collection.data.insert({"Name": "Mountain"}) + + banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) + + full_objects = collection.query.near_vector( + banana.vector["default"], return_metadata=MetadataQuery(distance=True, certainty=True) + ).objects + assert len(full_objects) == 4 + + objects_distance = collection.query.near_vector( + banana.vector["default"], distance=full_objects[2].metadata.distance + ).objects + assert len(objects_distance) == 3 + + objects_distance = collection.query.near_vector( + banana.vector["default"], certainty=full_objects[2].metadata.certainty + ).objects + assert len(objects_distance) == 3 + + +def test_near_vector_limit(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + properties=[Property(name="Name", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + uuid_banana = collection.data.insert({"Name": "Banana"}) + collection.data.insert({"Name": "Fruit"}) + collection.data.insert({"Name": "car"}) + collection.data.insert({"Name": "Mountain"}) + + banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) + + objs = collection.query.near_vector(banana.vector["default"], limit=2).objects + assert len(objs) == 2 + + +def test_near_vector_offset(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + properties=[Property(name="Name", data_type=DataType.TEXT)], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + uuid_banana = collection.data.insert({"Name": "Banana"}) + uuid_fruit = collection.data.insert({"Name": "Fruit"}) + collection.data.insert({"Name": "car"}) + collection.data.insert({"Name": "Mountain"}) + + banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True) + + objs = collection.query.near_vector(banana.vector["default"], offset=1).objects + assert len(objs) == 3 + assert objs[0].uuid == uuid_fruit + + +def test_near_vector_group_by_argument(collection_factory: CollectionFactory) -> None: + collection = collection_factory( + properties=[ + Property(name="Name", data_type=DataType.TEXT), + Property(name="Count", data_type=DataType.INT), + ], + vectorizer_config=Configure.Vectorizer.text2vec_contextionary( + vectorize_collection_name=False + ), + ) + uuid_banana1 = collection.data.insert({"Name": "Banana", "Count": 51}) + collection.data.insert({"Name": "Banana", "Count": 72}) + collection.data.insert({"Name": "car", "Count": 12}) + collection.data.insert({"Name": "Mountain", "Count": 1}) + + banana1 = collection.query.fetch_object_by_id(uuid_banana1, include_vector=True) + + ret = collection.query.near_vector( + banana1.vector["default"], + group_by=GroupBy( + prop="name", + number_of_groups=4, + objects_per_group=10, + ), + return_metadata=MetadataQuery(distance=True, certainty=True), + ) + + assert len(ret.objects) == 4 + assert ret.objects[0].belongs_to_group == "Banana" + assert ret.objects[1].belongs_to_group == "Banana" + assert ret.objects[2].belongs_to_group == "car" + assert ret.objects[3].belongs_to_group == "Mountain" + + +@pytest.mark.parametrize( + "near_vector", [[1, 0], [1.0, 0.0], np.array([1, 0]), pl.Series([1, 0]), pd.Series([1, 0])] +) +def test_near_vector_with_other_input( + collection_factory: CollectionFactory, near_vector: Any +) -> None: + collection = collection_factory(vectorizer_config=Configure.Vectorizer.none()) + + uuid1 = collection.data.insert({}, vector=[1, 0]) + collection.data.insert({}, vector=[0, 1]) + + ret = collection.query.near_vector( + near_vector, + distance=0.1, + ) + assert len(ret.objects) == 1 + assert ret.objects[0].uuid == uuid1 + + +@pytest.mark.parametrize( + "near_vector", + [ + {"first": [1, 0], "second": [1, 0, 0]}, + {"first": np.array([1, 0]), "second": [1, 0, 0]}, + {"first": pl.Series([1, 0]), "second": [1, 0, 0]}, + {"first": pd.Series([1, 0]), "second": [1, 0, 0]}, + [np.array([1, 0]), [1, 0, 0]], + [pl.Series([1, 0]), [1, 0, 0]], + [pd.Series([1, 0]), [1, 0, 0]], + {"first": [1.0, 0.0], "second": [1.0, 0.0, 0.0]}, + ], +) +def test_near_vector_with_named_vector_other_input( + collection_factory: CollectionFactory, near_vector: Any +) -> None: + dummy = collection_factory("dummy") + if dummy._connection._weaviate_version.is_lower_than(1, 26, 0): + pytest.skip("Named vectors are supported in versions higher than 1.26.0") + + collection = collection_factory( + vectorizer_config=[ + Configure.NamedVectors.none("first"), + Configure.NamedVectors.none("second"), + ] + ) + + uuid1 = collection.data.insert({}, vector={"first": [1, 0], "second": [1, 0, 0]}) + collection.data.insert({}, vector={"first": [0, 1], "second": [0, 0, 1]}) + + ret = collection.query.near_vector(near_vector, distance=0.1, target_vector=["first", "second"]) + assert len(ret.objects) == 1 + assert ret.objects[0].uuid == uuid1 diff --git a/requirements-devel.txt b/requirements-devel.txt index 43c15edff..bcd110582 100644 --- a/requirements-devel.txt +++ b/requirements-devel.txt @@ -29,6 +29,7 @@ pytest-httpserver>=1.0.8 numpy>=1.24.4,<2.0.0 pandas>=2.0.3,<3.0.0 +pandas-stubs>=2.0.3,<3.0.0 polars>=0.20.26,<0.21.0 mypy>=1.9.0,<2.0.0 diff --git a/test/collection/test_validator.py b/test/collection/test_validator.py new file mode 100644 index 000000000..66cfaafef --- /dev/null +++ b/test/collection/test_validator.py @@ -0,0 +1,38 @@ +from typing import Any, List + +import numpy as np +import pandas as pd +import polars as pl +import pytest + +from weaviate.exceptions import WeaviateInvalidInputError +from weaviate.validator import _validate_input, _ValidateArgument, _ExtraTypes + + +@pytest.mark.parametrize( + "inputs,expected,error", + [ + (1, [int], False), + (1.0, [int], True), + ([1, 1], [List], False), + (np.array([1, 2, 3]), [_ExtraTypes.NUMPY], False), + (np.array([1, 2, 3]), [_ExtraTypes.NUMPY, List], False), + (np.array([1, 2, 3]), [List], True), + ([1, 1], [List, _ExtraTypes.NUMPY], False), + (pd.array([1, 1]), [_ExtraTypes.PANDAS, List], False), + (pd.Series([1, 1]), [_ExtraTypes.PANDAS, List], False), + (pl.Series([1, 1]), [_ExtraTypes.POLARS, List], False), + ( + pl.Series([1, 1]), + [_ExtraTypes.POLARS, _ExtraTypes.PANDAS, _ExtraTypes.NUMPY, List], + False, + ), + (pl.Series([1, 1]), [_ExtraTypes.PANDAS, _ExtraTypes.NUMPY, List], True), + ], +) +def test_validator(inputs: Any, expected: List[Any], error: bool) -> None: + if error: + with pytest.raises(WeaviateInvalidInputError): + _validate_input(_ValidateArgument(expected=expected, name="test", value=inputs)) + else: + _validate_input(_ValidateArgument(expected=expected, name="test", value=inputs)) diff --git a/weaviate/collections/classes/grpc.py b/weaviate/collections/classes/grpc.py index ad38b3b44..9c1f7c7c9 100644 --- a/weaviate/collections/classes/grpc.py +++ b/weaviate/collections/classes/grpc.py @@ -6,8 +6,8 @@ from weaviate.collections.classes.types import _WeaviateInput from weaviate.proto.v1 import search_get_pb2 +from weaviate.str_enum import BaseEnum from weaviate.types import INCLUDE_VECTOR, UUID -from weaviate.util import BaseEnum class HybridFusion(str, BaseEnum): diff --git a/weaviate/collections/grpc/query.py b/weaviate/collections/grpc/query.py index 29175345c..08314170d 100644 --- a/weaviate/collections/grpc/query.py +++ b/weaviate/collections/grpc/query.py @@ -55,8 +55,8 @@ ) from weaviate.proto.v1 import search_get_pb2 from weaviate.types import NUMBER, UUID -from weaviate.util import _get_vector_v4 -from weaviate.validator import _ValidateArgument, _validate_input +from weaviate.util import _get_vector_v4, _is_1d_vector +from weaviate.validator import _ValidateArgument, _validate_input, _ExtraTypes # Can be found in the google.protobuf.internal.well_known_types.pyi stub file but is defined explicitly here for clarity. _PyValue: TypeAlias = Union[ @@ -330,7 +330,18 @@ def near_vector( if self._validate_arguments: _validate_input( [ - _ValidateArgument([List, Dict], "near_vector", near_vector), + _ValidateArgument( + [ + List, + Dict, + _ExtraTypes.PANDAS, + _ExtraTypes.POLARS, + _ExtraTypes.NUMPY, + _ExtraTypes.TF, + ], + "near_vector", + near_vector, + ), _ValidateArgument( [str, None, List, _MultiTargetVectorJoin], "target_vector", target_vector ), @@ -340,7 +351,13 @@ def near_vector( certainty, distance = self.__parse_near_options(certainty, distance) targets, target_vectors = self.__target_vector_to_grpc(target_vector) - + invalid_nv_exception = WeaviateInvalidInputError( + f"""near vector argument can be: + - a list of numbers + - a list of lists of numbers for multi target search + - a dictionary with target names as keys and lists of numbers as values + received: {near_vector}""" + ) if isinstance(near_vector, dict): if targets is None or len(targets.target_vectors) != len(near_vector): raise WeaviateInvalidInputError( @@ -349,17 +366,15 @@ def near_vector( vector_per_target: Dict[str, bytes] = {} for key, value in near_vector.items(): + nv = _get_vector_v4(value) + if ( - not isinstance(value, list) - or len(value) == 0 - or not isinstance(value[0], get_args(NUMBER)) + not isinstance(nv, list) + or len(nv) == 0 + or not isinstance(nv[0], get_args(NUMBER)) ): - raise WeaviateQueryError( - "The value of the near_vector dict must be a lists of numbers", - "GRPC", - ) + raise invalid_nv_exception - nv = _get_vector_v4(value) vector_per_target[key] = struct.pack("{}f".format(len(nv)), *nv) near_vector_grpc = search_get_pb2.NearVector( certainty=certainty, @@ -369,16 +384,13 @@ def near_vector( vector_per_target=vector_per_target, ) else: - if not isinstance(near_vector, list) or len(near_vector) == 0: - raise WeaviateInvalidInputError( - """near vector argument can be: - - a list of numbers - - a list of lists of numbers for multi target search - - a dictionary with target names as keys and lists of numbers as values""" - ) + if len(near_vector) == 0: + raise invalid_nv_exception - if isinstance(near_vector[0], get_args(NUMBER)): + if _is_1d_vector(near_vector): near_vector = _get_vector_v4(near_vector) + if not isinstance(near_vector, list): + raise invalid_nv_exception near_vector_grpc = search_get_pb2.NearVector( certainty=certainty, distance=distance, @@ -393,15 +405,13 @@ def near_vector( "The number of target vectors must be equal to the number of vectors." ) for i, vector in enumerate(near_vector): + nv = _get_vector_v4(vector) if ( - not isinstance(vector, list) - or len(vector) == 0 - or not isinstance(vector[0], get_args(NUMBER)) + not isinstance(nv, list) + or len(nv) == 0 + or not isinstance(nv[0], get_args(NUMBER)) ): - raise WeaviateInvalidInputError( - "The value of the near_vector entry must be a lists of numbers" - ) - nv = _get_vector_v4(vector) + raise invalid_nv_exception vector_per_target_tmp[targets.target_vectors[i]] = struct.pack( "{}f".format(len(nv)), *nv ) diff --git a/weaviate/gql/get.py b/weaviate/gql/get.py index 47060ef85..368d2cbde 100644 --- a/weaviate/gql/get.py +++ b/weaviate/gql/get.py @@ -7,6 +7,8 @@ from json import dumps from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union +import grpc # type: ignore + from weaviate import util from weaviate.connect import Connection from weaviate.data.replication import ConsistencyLevel @@ -28,18 +30,16 @@ MediaType, Sort, ) +from weaviate.proto.v1 import search_get_pb2 +from weaviate.str_enum import BaseEnum +from weaviate.types import UUID from weaviate.util import ( image_encoder_b64, _capitalize_first_letter, get_valid_uuid, file_encoder_b64, - BaseEnum, ) from weaviate.warnings import _Warnings -from weaviate.types import UUID - -from weaviate.proto.v1 import search_get_pb2 -import grpc # type: ignore @dataclass diff --git a/weaviate/str_enum.py b/weaviate/str_enum.py new file mode 100644 index 000000000..78e7f6c06 --- /dev/null +++ b/weaviate/str_enum.py @@ -0,0 +1,19 @@ +# MetaEnum and BaseEnum are required to support `in` statements: +# 'ALL' in ConsistencyLevel == True +# 12345 in ConsistencyLevel == False +from enum import EnumMeta, Enum +from typing import Any + + +class MetaEnum(EnumMeta): + def __contains__(cls, item: Any) -> bool: + try: + # when item is type ConsistencyLevel + return item.name in cls.__members__.keys() + except AttributeError: + # when item is type str + return item in cls.__members__.keys() + + +class BaseEnum(Enum, metaclass=MetaEnum): + pass diff --git a/weaviate/util.py b/weaviate/util.py index 61df9a9e4..4cafc1609 100644 --- a/weaviate/util.py +++ b/weaviate/util.py @@ -9,7 +9,6 @@ import os import re import uuid as uuid_lib -from enum import Enum, EnumMeta from pathlib import Path from typing import Union, Sequence, Any, Optional, List, Dict, Generator, Tuple, cast @@ -26,6 +25,7 @@ WeaviateUnsupportedFeatureError, ) from weaviate.types import NUMBER, UUIDS, TIME +from weaviate.validator import _is_valid, _ExtraTypes from weaviate.warnings import _Warnings PYPI_PACKAGE_URL = "https://pypi.org/pypi/weaviate-client/json" @@ -36,23 +36,6 @@ BYTES_PER_CHUNK = 65535 # The number of bytes to read per chunk when encoding files ~ 64kb -# MetaEnum and BaseEnum are required to support `in` statements: -# 'ALL' in ConsistencyLevel == True -# 12345 in ConsistencyLevel == False -class MetaEnum(EnumMeta): - def __contains__(cls, item: Any) -> bool: - try: - # when item is type ConsistencyLevel - return item.name in cls.__members__.keys() - except AttributeError: - # when item is type str - return item in cls.__members__.keys() - - -class BaseEnum(Enum, metaclass=MetaEnum): - pass - - def image_encoder_b64(image_or_image_path: Union[str, io.BufferedReader]) -> str: """ Encode a image in a Weaviate understandable format from a binary read file or by providing @@ -461,7 +444,7 @@ def get_vector(vector: Sequence) -> List[float]: ) from None -def _get_vector_v4(vector: Sequence) -> List[float]: +def _get_vector_v4(vector: Any) -> List[float]: try: return get_vector(vector) except TypeError as e: @@ -978,3 +961,33 @@ def _datetime_from_weaviate_str(string: str) -> datetime.datetime: "".join(string.rsplit(":", 1) if string[-1] != "Z" else string), "%Y-%m-%dT%H:%M:%S%z", ) + + +def __is_list_type(inputs: Any) -> bool: + try: + if len(inputs) == 0: + return False + except TypeError: + return False + + return any( + _is_valid(types, inputs) + for types in [ + List, + _ExtraTypes.TF, + _ExtraTypes.PANDAS, + _ExtraTypes.NUMPY, + _ExtraTypes.POLARS, + ] + ) + + +def _is_1d_vector(inputs: Any) -> bool: + try: + if len(inputs) == 0: + return False + except TypeError: + return False + if __is_list_type(inputs): + return not __is_list_type(inputs[0]) # 2D vectors are not 1D vectors + return False diff --git a/weaviate/validator.py b/weaviate/validator.py index 1b8b44810..7fe11945c 100644 --- a/weaviate/validator.py +++ b/weaviate/validator.py @@ -2,6 +2,7 @@ from typing import Any, List, Sequence, Union, get_args, get_origin from weaviate.exceptions import WeaviateInvalidInputError +from weaviate.str_enum import BaseEnum @dataclass @@ -11,6 +12,13 @@ class _ValidateArgument: value: Any +class _ExtraTypes(str, BaseEnum): + NUMPY = "numpy" + PANDAS = "pandas" + POLARS = "polars" + TF = "tensorflow" + + def _validate_input(inputs: Union[List[_ValidateArgument], _ValidateArgument]) -> None: """Validate the values of the input arguments in comparison to the expected types defined in _ValidateArgument. @@ -20,15 +28,21 @@ def _validate_input(inputs: Union[List[_ValidateArgument], _ValidateArgument]) - if isinstance(inputs, _ValidateArgument): inputs = [inputs] for validate in inputs: - if not any(__is_valid(exp, validate.value) for exp in validate.expected): + if not any(_is_valid(exp, validate.value) for exp in validate.expected): raise WeaviateInvalidInputError( f"Argument '{validate.name}' must be one of: {validate.expected}, but got {type(validate.value)}" ) -def __is_valid(expected: Any, value: Any) -> bool: +def _is_valid(expected: Any, value: Any) -> bool: if expected is None: return value is None + + # check for types that are not installed + # https://stackoverflow.com/questions/12569452/how-to-identify-numpy-types-in-python + if isinstance(expected, _ExtraTypes): + return expected.value in type(value).__module__ + expected_origin = get_origin(expected) if expected_origin is Union: args = get_args(expected)