Skip to content

Commit

Permalink
Add comment and tests with different libraries
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkkul committed Jul 3, 2024
1 parent 71df74b commit 903d262
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
11 changes: 9 additions & 2 deletions integration/test_collection_near_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import Any

import numpy as np
import pandas as pd
import polars as pl
import pytest

from integration.conftest import CollectionFactory
Expand Down Expand Up @@ -121,7 +123,9 @@ def test_near_vector_group_by_argument(collection_factory: CollectionFactory) ->
assert ret.objects[3].belongs_to_group == "Mountain"


@pytest.mark.parametrize("near_vector", [[1, 0], [1.0, 0.0], np.array([1, 0])])
@pytest.mark.parametrize(
"near_vector", [[1, 0], [1.0, 0.0], np.array([1, 0]), pl.Series([1, 0]), pd.Series([1, 0])]
)
def test_near_vector_with_other_input(
collection_factory: CollectionFactory, near_vector: Any
) -> None:
Expand All @@ -143,8 +147,11 @@ def test_near_vector_with_other_input(
[
{"first": [1, 0], "second": [1, 0, 0]},
{"first": np.array([1, 0]), "second": [1, 0, 0]},
{"first": np.array([1, 0]), "second": [1, 0, 0]},
{"first": pl.Series([1, 0]), "second": [1, 0, 0]},
{"first": pd.Series([1, 0]), "second": [1, 0, 0]},
[np.array([1, 0]), [1, 0, 0]],
[pl.Series([1, 0]), [1, 0, 0]],
[pd.Series([1, 0]), [1, 0, 0]],
{"first": [1.0, 0.0], "second": [1.0, 0.0, 0.0]},
],
)
Expand Down
1 change: 1 addition & 0 deletions requirements-devel.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pytest-httpserver>=1.0.8

numpy>=1.24.4,<2.0.0
pandas>=2.0.3,<3.0.0
pandas-stubs>=2.0.3,<3.0.0
polars>=0.20.26,<0.21.0

mypy>=1.9.0,<2.0.0
Expand Down
3 changes: 3 additions & 0 deletions weaviate/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def _validate_input(inputs: Union[List[_ValidateArgument], _ValidateArgument]) -
def _is_valid(expected: Any, value: Any) -> bool:
if expected is None:
return value is None

# check for types that are not installed
# https://stackoverflow.com/questions/12569452/how-to-identify-numpy-types-in-python
if isinstance(expected, _ExtraTypes):
return expected.value in type(value).__module__

Expand Down

0 comments on commit 903d262

Please sign in to comment.