Skip to content

Commit

Permalink
Merge pull request #1145 from weaviate/better_type_checking
Browse files Browse the repository at this point in the history
Add validation for types that are not installed
  • Loading branch information
dirkkul committed Jul 3, 2024
2 parents 92a5f76 + 903d262 commit e59bf43
Show file tree
Hide file tree
Showing 11 changed files with 332 additions and 162 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -291,17 +291,17 @@ jobs:
$WEAVIATE_126
]
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Download build artifact to append to release
uses: actions/download-artifact@v4
with:
name: weaviate-python-client-wheel
- run: |
pip install weaviate_client-*.whl
pip install pytest pytest-benchmark pytest-profiling grpcio grpcio-tools pytest-xdist
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
pip install -r requirements-devel.txt # install test dependencies
- name: free space
run: sudo rm -rf /usr/local/lib/android
- run: rm -r weaviate
Expand Down
104 changes: 1 addition & 103 deletions integration/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest

import weaviate.classes as wvc
from integration.conftest import CollectionFactory, CollectionFactoryGet, _sanitize_collection_name
from integration.constants import WEAVIATE_LOGO_OLD_ENCODED, WEAVIATE_LOGO_NEW_ENCODED
from weaviate.collections.classes.batch import ErrorObject
Expand Down Expand Up @@ -51,8 +52,6 @@
)
from weaviate.types import UUID, UUIDS

import weaviate.classes as wvc

UUID1 = uuid.UUID("806827e0-2b31-43ca-9269-24fa95a221f9")
UUID2 = uuid.UUID("8ad0d33c-8db1-4437-87f3-72161ca2a51a")
UUID3 = uuid.UUID("83d99755-9deb-4b16-8431-d1dff4ab0a75")
Expand Down Expand Up @@ -863,107 +862,6 @@ def test_query_properties(collection_factory: CollectionFactory) -> None:
assert len(objects) == 0


def test_near_vector(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[Property(name="Name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)
uuid_banana = collection.data.insert({"Name": "Banana"})
collection.data.insert({"Name": "Fruit"})
collection.data.insert({"Name": "car"})
collection.data.insert({"Name": "Mountain"})

banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True)

full_objects = collection.query.near_vector(
banana.vector["default"], return_metadata=MetadataQuery(distance=True, certainty=True)
).objects
assert len(full_objects) == 4

objects_distance = collection.query.near_vector(
banana.vector["default"], distance=full_objects[2].metadata.distance
).objects
assert len(objects_distance) == 3

objects_distance = collection.query.near_vector(
banana.vector["default"], certainty=full_objects[2].metadata.certainty
).objects
assert len(objects_distance) == 3


def test_near_vector_limit(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[Property(name="Name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)
uuid_banana = collection.data.insert({"Name": "Banana"})
collection.data.insert({"Name": "Fruit"})
collection.data.insert({"Name": "car"})
collection.data.insert({"Name": "Mountain"})

banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True)

objs = collection.query.near_vector(banana.vector["default"], limit=2).objects
assert len(objs) == 2


def test_near_vector_offset(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[Property(name="Name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)
uuid_banana = collection.data.insert({"Name": "Banana"})
uuid_fruit = collection.data.insert({"Name": "Fruit"})
collection.data.insert({"Name": "car"})
collection.data.insert({"Name": "Mountain"})

banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True)

objs = collection.query.near_vector(banana.vector["default"], offset=1).objects
assert len(objs) == 3
assert objs[0].uuid == uuid_fruit


def test_near_vector_group_by_argument(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[
Property(name="Name", data_type=DataType.TEXT),
Property(name="Count", data_type=DataType.INT),
],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)
uuid_banana1 = collection.data.insert({"Name": "Banana", "Count": 51})
collection.data.insert({"Name": "Banana", "Count": 72})
collection.data.insert({"Name": "car", "Count": 12})
collection.data.insert({"Name": "Mountain", "Count": 1})

banana1 = collection.query.fetch_object_by_id(uuid_banana1, include_vector=True)

ret = collection.query.near_vector(
banana1.vector["default"],
group_by=GroupBy(
prop="name",
number_of_groups=4,
objects_per_group=10,
),
return_metadata=MetadataQuery(distance=True, certainty=True),
)

assert len(ret.objects) == 4
assert ret.objects[0].belongs_to_group == "Banana"
assert ret.objects[1].belongs_to_group == "Banana"
assert ret.objects[2].belongs_to_group == "car"
assert ret.objects[3].belongs_to_group == "Mountain"


def test_near_object(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[Property(name="Name", data_type=DataType.TEXT)],
Expand Down
177 changes: 177 additions & 0 deletions integration/test_collection_near_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import uuid
from typing import Any

import numpy as np
import pandas as pd
import polars as pl
import pytest

from integration.conftest import CollectionFactory
from weaviate.collections.classes.config import (
Configure,
DataType,
Property,
)
from weaviate.collections.classes.grpc import (
GroupBy,
MetadataQuery,
)

UUID1 = uuid.UUID("806827e0-2b31-43ca-9269-24fa95a221f9")
UUID2 = uuid.UUID("8ad0d33c-8db1-4437-87f3-72161ca2a51a")
UUID3 = uuid.UUID("83d99755-9deb-4b16-8431-d1dff4ab0a75")


def test_near_vector(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[Property(name="Name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)
uuid_banana = collection.data.insert({"Name": "Banana"})
collection.data.insert({"Name": "Fruit"})
collection.data.insert({"Name": "car"})
collection.data.insert({"Name": "Mountain"})

banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True)

full_objects = collection.query.near_vector(
banana.vector["default"], return_metadata=MetadataQuery(distance=True, certainty=True)
).objects
assert len(full_objects) == 4

objects_distance = collection.query.near_vector(
banana.vector["default"], distance=full_objects[2].metadata.distance
).objects
assert len(objects_distance) == 3

objects_distance = collection.query.near_vector(
banana.vector["default"], certainty=full_objects[2].metadata.certainty
).objects
assert len(objects_distance) == 3


def test_near_vector_limit(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[Property(name="Name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)
uuid_banana = collection.data.insert({"Name": "Banana"})
collection.data.insert({"Name": "Fruit"})
collection.data.insert({"Name": "car"})
collection.data.insert({"Name": "Mountain"})

banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True)

objs = collection.query.near_vector(banana.vector["default"], limit=2).objects
assert len(objs) == 2


def test_near_vector_offset(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[Property(name="Name", data_type=DataType.TEXT)],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)
uuid_banana = collection.data.insert({"Name": "Banana"})
uuid_fruit = collection.data.insert({"Name": "Fruit"})
collection.data.insert({"Name": "car"})
collection.data.insert({"Name": "Mountain"})

banana = collection.query.fetch_object_by_id(uuid_banana, include_vector=True)

objs = collection.query.near_vector(banana.vector["default"], offset=1).objects
assert len(objs) == 3
assert objs[0].uuid == uuid_fruit


def test_near_vector_group_by_argument(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
properties=[
Property(name="Name", data_type=DataType.TEXT),
Property(name="Count", data_type=DataType.INT),
],
vectorizer_config=Configure.Vectorizer.text2vec_contextionary(
vectorize_collection_name=False
),
)
uuid_banana1 = collection.data.insert({"Name": "Banana", "Count": 51})
collection.data.insert({"Name": "Banana", "Count": 72})
collection.data.insert({"Name": "car", "Count": 12})
collection.data.insert({"Name": "Mountain", "Count": 1})

banana1 = collection.query.fetch_object_by_id(uuid_banana1, include_vector=True)

ret = collection.query.near_vector(
banana1.vector["default"],
group_by=GroupBy(
prop="name",
number_of_groups=4,
objects_per_group=10,
),
return_metadata=MetadataQuery(distance=True, certainty=True),
)

assert len(ret.objects) == 4
assert ret.objects[0].belongs_to_group == "Banana"
assert ret.objects[1].belongs_to_group == "Banana"
assert ret.objects[2].belongs_to_group == "car"
assert ret.objects[3].belongs_to_group == "Mountain"


@pytest.mark.parametrize(
"near_vector", [[1, 0], [1.0, 0.0], np.array([1, 0]), pl.Series([1, 0]), pd.Series([1, 0])]
)
def test_near_vector_with_other_input(
collection_factory: CollectionFactory, near_vector: Any
) -> None:
collection = collection_factory(vectorizer_config=Configure.Vectorizer.none())

uuid1 = collection.data.insert({}, vector=[1, 0])
collection.data.insert({}, vector=[0, 1])

ret = collection.query.near_vector(
near_vector,
distance=0.1,
)
assert len(ret.objects) == 1
assert ret.objects[0].uuid == uuid1


@pytest.mark.parametrize(
"near_vector",
[
{"first": [1, 0], "second": [1, 0, 0]},
{"first": np.array([1, 0]), "second": [1, 0, 0]},
{"first": pl.Series([1, 0]), "second": [1, 0, 0]},
{"first": pd.Series([1, 0]), "second": [1, 0, 0]},
[np.array([1, 0]), [1, 0, 0]],
[pl.Series([1, 0]), [1, 0, 0]],
[pd.Series([1, 0]), [1, 0, 0]],
{"first": [1.0, 0.0], "second": [1.0, 0.0, 0.0]},
],
)
def test_near_vector_with_named_vector_other_input(
collection_factory: CollectionFactory, near_vector: Any
) -> None:
dummy = collection_factory("dummy")
if dummy._connection._weaviate_version.is_lower_than(1, 26, 0):
pytest.skip("Named vectors are supported in versions higher than 1.26.0")

collection = collection_factory(
vectorizer_config=[
Configure.NamedVectors.none("first"),
Configure.NamedVectors.none("second"),
]
)

uuid1 = collection.data.insert({}, vector={"first": [1, 0], "second": [1, 0, 0]})
collection.data.insert({}, vector={"first": [0, 1], "second": [0, 0, 1]})

ret = collection.query.near_vector(near_vector, distance=0.1, target_vector=["first", "second"])
assert len(ret.objects) == 1
assert ret.objects[0].uuid == uuid1
1 change: 1 addition & 0 deletions requirements-devel.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pytest-httpserver>=1.0.8

numpy>=1.24.4,<2.0.0
pandas>=2.0.3,<3.0.0
pandas-stubs>=2.0.3,<3.0.0
polars>=0.20.26,<0.21.0

mypy>=1.9.0,<2.0.0
Expand Down
38 changes: 38 additions & 0 deletions test/collection/test_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Any, List

import numpy as np
import pandas as pd
import polars as pl
import pytest

from weaviate.exceptions import WeaviateInvalidInputError
from weaviate.validator import _validate_input, _ValidateArgument, _ExtraTypes


@pytest.mark.parametrize(
"inputs,expected,error",
[
(1, [int], False),
(1.0, [int], True),
([1, 1], [List], False),
(np.array([1, 2, 3]), [_ExtraTypes.NUMPY], False),
(np.array([1, 2, 3]), [_ExtraTypes.NUMPY, List], False),
(np.array([1, 2, 3]), [List], True),
([1, 1], [List, _ExtraTypes.NUMPY], False),
(pd.array([1, 1]), [_ExtraTypes.PANDAS, List], False),
(pd.Series([1, 1]), [_ExtraTypes.PANDAS, List], False),
(pl.Series([1, 1]), [_ExtraTypes.POLARS, List], False),
(
pl.Series([1, 1]),
[_ExtraTypes.POLARS, _ExtraTypes.PANDAS, _ExtraTypes.NUMPY, List],
False,
),
(pl.Series([1, 1]), [_ExtraTypes.PANDAS, _ExtraTypes.NUMPY, List], True),
],
)
def test_validator(inputs: Any, expected: List[Any], error: bool) -> None:
if error:
with pytest.raises(WeaviateInvalidInputError):
_validate_input(_ValidateArgument(expected=expected, name="test", value=inputs))
else:
_validate_input(_ValidateArgument(expected=expected, name="test", value=inputs))
2 changes: 1 addition & 1 deletion weaviate/collections/classes/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from weaviate.collections.classes.types import _WeaviateInput
from weaviate.proto.v1 import search_get_pb2
from weaviate.str_enum import BaseEnum
from weaviate.types import INCLUDE_VECTOR, UUID
from weaviate.util import BaseEnum


class HybridFusion(str, BaseEnum):
Expand Down
Loading

0 comments on commit e59bf43

Please sign in to comment.