Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix batch objects where filter #443

Merged
merged 6 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/docker-compose-azure.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ services:
- --scheme
- http
- --write-timeout=600s
image: semitechnologies/weaviate:1.21.0
image: semitechnologies/weaviate:1.21.1
ports:
- 8081:8081
restart: on-failure:0
Expand Down
4 changes: 2 additions & 2 deletions ci/docker-compose-cluster.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
version: '3.4'
services:
weaviate-node-1:
image: semitechnologies/weaviate:1.21.0
image: semitechnologies/weaviate:1.21.1
restart: on-failure:0
ports:
- "8087:8080"
Expand All @@ -25,7 +25,7 @@ services:
- '8080'
- --scheme
- http
image: semitechnologies/weaviate:1.21.0
image: semitechnologies/weaviate:1.21.1
ports:
- 8088:8080
- 6061:6060
Expand Down
2 changes: 1 addition & 1 deletion ci/docker-compose-okta-cc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ services:
- --scheme
- http
- --write-timeout=600s
image: semitechnologies/weaviate:1.21.0
image: semitechnologies/weaviate:1.21.1
ports:
- 8082:8082
restart: on-failure:0
Expand Down
2 changes: 1 addition & 1 deletion ci/docker-compose-okta-users.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ services:
- --scheme
- http
- --write-timeout=600s
image: semitechnologies/weaviate:1.21.0
image: semitechnologies/weaviate:1.21.1
ports:
- 8083:8083
restart: on-failure:0
Expand Down
2 changes: 1 addition & 1 deletion ci/docker-compose-openai.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ services:
- '8086'
- --scheme
- http
image: semitechnologies/weaviate:1.21.0
image: semitechnologies/weaviate:1.21.1
ports:
- 8086:8086
restart: on-failure:0
Expand Down
2 changes: 1 addition & 1 deletion ci/docker-compose-wcs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ services:
- --scheme
- http
- --write-timeout=600s
image: semitechnologies/weaviate:1.21.0
image: semitechnologies/weaviate:1.21.1
ports:
- 8085:8085
restart: on-failure:0
Expand Down
2 changes: 1 addition & 1 deletion ci/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ services:
- --scheme
- http
- --write-timeout=600s
image: semitechnologies/weaviate:1.21.0
image: semitechnologies/weaviate:1.21.1
ports:
- "8080:8080"
- "50051:50051"
Expand Down
55 changes: 54 additions & 1 deletion integration/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ def client():
client.schema.create_class(
{
"class": "Test",
"properties": [{"name": "test", "dataType": ["Test"]}],
"properties": [
{"name": "test", "dataType": ["Test"]},
{"name": "name", "dataType": ["string"]},
],
"vectorizer": "none",
}
)
Expand All @@ -79,6 +82,56 @@ def test_add_data_object(client: weaviate.Client, uuid: Optional[UUID], vector:
assert has_batch_errors(response) is False, str(response)


def test_delete_objects(client: weaviate.Client):
with client.batch as batch:
batch.add_data_object(data_object={"name": "one"}, class_name="Test")
batch.add_data_object(data_object={"name": "two"}, class_name="Test")
batch.add_data_object(data_object={"name": "three"}, class_name="Test")
batch.add_data_object(data_object={"name": "four"}, class_name="Test")
batch.add_data_object(data_object={"name": "five"}, class_name="Test")

with client.batch as batch:
batch.delete_objects(
"Test",
where={
"path": ["name"],
"operator": "Equal",
"valueText": "one",
},
)
res = client.data_object.get()
names = [obj["properties"]["name"] for obj in res["objects"]]
assert "one" not in names

with client.batch as batch:
batch.delete_objects(
"Test",
where={
"path": ["name"],
"operator": "ContainsAny",
"valueTextArray": ["two", "three"],
},
)
res = client.data_object.get()
names = [obj["properties"]["name"] for obj in res["objects"]]
assert "two" not in names
assert "three" not in names

with client.batch as batch:
batch.delete_objects(
"Test",
where={
"path": ["name"],
"operator": "ContainsAll",
"valueTextArray": ["four", "five"],
},
)
res = client.data_object.get()
names = [obj["properties"]["name"] for obj in res["objects"]]
assert "four" in names
assert "five" in names


@pytest.mark.parametrize("from_object_uuid", [uuid.uuid4(), str(uuid.uuid4()), uuid.uuid4().hex])
@pytest.mark.parametrize("to_object_uuid", [uuid.uuid4().hex, uuid.uuid4(), str(uuid.uuid4())])
@pytest.mark.parametrize("to_object_class_name", [None, "Test"])
Expand Down
4 changes: 2 additions & 2 deletions integration/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import weaviate

GIT_HASH = "8172acb"
SERVER_VERSION = "1.21.0"
GIT_HASH = "5f2df4d"
SERVER_VERSION = "1.21.1"
NODE_NAME = "node1"
NUM_OBJECT = 10

Expand Down
4 changes: 2 additions & 2 deletions integration/test_graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_get_data(client: weaviate.Client):

def test_get_data_with_where_contains_any(client: weaviate.Client):
"""Test GraphQL's Get clause with where filter."""
where_filter = {"path": ["size"], "operator": "ContainsAny", "valueIntList": [5]}
where_filter = {"path": ["size"], "operator": "ContainsAny", "valueIntArray": [5]}
result = client.query.get("Ship", ["name", "size"]).with_where(where_filter).do()
objects = get_objects_from_result(result)
assert len(objects) == 1 and objects[0]["name"] == "HMS British Name"
Expand All @@ -133,7 +133,7 @@ def test_get_data_with_where_contains_all(
where_filter = {
"path": ["description"],
"operator": "ContainsAll",
"valueStringList": value_string_list,
"valueStringArray": value_string_list,
}
result = client.query.get("Ship", ["name"]).with_where(where_filter).do()
objects = get_objects_from_result(result)
Expand Down
6 changes: 4 additions & 2 deletions test/data/test_crud_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,11 +455,13 @@ def test_get_by_id(self):

mock_get = Mock(return_value="Test")
data_object.get = mock_get
data_object.get_by_id(uuid="UUID", additional_properties=["Test", "list"], with_vector=True)
data_object.get_by_id(
uuid="UUID", additional_properties=["Test", "Array"], with_vector=True
)
mock_get.assert_called_with(
uuid="UUID",
class_name=None,
additional_properties=["Test", "list"],
additional_properties=["Test", "Array"],
with_vector=True,
node_name=None,
consistency_level=None,
Expand Down
14 changes: 7 additions & 7 deletions test/gql/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def test___str__(self):
test_filter = {
"path": ["name"],
"operator": "ContainsAny",
"valueTextList": ["A", "B\n"],
"valueTextArray": ["A", "B\n"],
}
result = str(Where(test_filter))
self.assertEqual(
Expand All @@ -743,15 +743,15 @@ def test___str__(self):
test_filter = {
"path": ["name"],
"operator": "ContainsAll",
"valueStringList": ["A", '"B"'],
"valueStringArray": ["A", '"B"'],
}
result = str(Where(test_filter))
self.assertEqual(
'where: {path: ["name"] operator: ContainsAll valueString: ["A","\\"B\\""]} ',
str(result),
)

test_filter = {"path": ["name"], "operator": "ContainsAny", "valueIntList": [1, 2]}
test_filter = {"path": ["name"], "operator": "ContainsAny", "valueIntArray": [1, 2]}
result = str(Where(test_filter))
self.assertEqual(
'where: {path: ["name"] operator: ContainsAny valueInt: [1, 2]} ', str(result)
Expand All @@ -760,20 +760,20 @@ def test___str__(self):
test_filter = {
"path": ["name"],
"operator": "ContainsAny",
"valueStringList": "A",
"valueStringArray": "A",
}
with self.assertRaises(TypeError) as error:
str(Where(test_filter))
check_error_message(self, error, value_is_not_list_err("A", "valueStringList"))
check_error_message(self, error, value_is_not_list_err("A", "valueStringArray"))

test_filter = {
"path": ["name"],
"operator": "ContainsAll",
"valueTextList": "A",
"valueTextArray": "A",
}
with self.assertRaises(TypeError) as error:
str(Where(test_filter))
check_error_message(self, error, value_is_not_list_err("A", "valueTextList"))
check_error_message(self, error, value_is_not_list_err("A", "valueTextArray"))

test_filter = {
"path": ["name"],
Expand Down
53 changes: 52 additions & 1 deletion weaviate/batch/crud_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from weaviate.connect import Connection
from weaviate.data.replication import ConsistencyLevel
from weaviate.gql.filter import _find_value_type
from weaviate.types import UUID
from .requests import BatchRequest, ObjectsBatchRequest, ReferenceBatchRequest, BatchResponse
from ..cluster import Cluster
Expand Down Expand Up @@ -1308,7 +1309,7 @@ def delete_objects(
payload = {
"match": {
"class": class_name,
"where": where,
"where": _clean_delete_objects_where(where),
},
"output": output,
"dryRun": dry_run,
Expand Down Expand Up @@ -1795,3 +1796,53 @@ def _batch_create_error_handler(retry: int, max_retries: int, error: Exception)
flush=True,
)
time.sleep((retry + 1) * 2)


def _clean_delete_objects_where(where: dict) -> dict:
"""Converts the Python-defined where filter type into the Weaviate-defined
where filter type used in the Batch REST request endpoint.

Parameters
----------
where : dict
The Python-defined where filter.

Returns
-------
dict
The Weaviate-defined where filter.
"""
py_value_type = _find_value_type(where)
weaviate_value_type = _convert_value_type(py_value_type)
where[weaviate_value_type] = where.pop(py_value_type)
return where


def _convert_value_type(_type: str) -> str:
"""Converts the Python-defined where filter type into the Weaviate-defined
where filter type used in the Batch REST request endpoint.

Parameters
----------
_type : str
The Python-defined where filter type.

Returns
-------
str
The Weaviate-defined where filter type.
"""
if _type == "valueTextList":
return "valueTextArray"
elif _type == "valueStringList":
return "valueStringArray"
elif _type == "valueIntList":
return "valueIntArray"
elif _type == "valueNumberList":
return "valueNumberArray"
elif _type == "valueBooleanList":
return "valueBooleanList"
elif _type == "valueDateList":
return "valueDateArray"
else:
return _type
34 changes: 21 additions & 13 deletions weaviate/gql/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
from weaviate.util import get_vector, _sanitize_str

VALUE_LIST_TYPES = {
"valueStringArray",
"valueTextArray",
"valueIntArray",
"valueNumberArray",
"valueBooleanArray",
"valueDateArray",
"valueStringList",
"valueTextList",
"valueIntList",
Expand Down Expand Up @@ -839,23 +845,23 @@ def __str__(self):
if self.value_type in ["valueInt", "valueNumber"]:
_check_is_not_list(self.value, self.value_type)
gql += f"{self.value}}}"
elif self.value_type in ["valueIntList", "valueNumberList"]:
elif self.value_type in ["valueIntArray", "valueNumberArray"]:
_check_is_list(self.value, self.value_type)
gql += f"{self.value}}}"
elif self.value_type in ["valueText", "valueString"]:
_check_is_not_list(self.value, self.value_type)
gql += f"{_sanitize_str(self.value)}}}"
elif self.value_type in ["valueTextList", "valueStringList"]:
elif self.value_type in ["valueTextArray", "valueStringArray"]:
_check_is_list(self.value, self.value_type)
val = [_sanitize_str(v) for v in self.value]
gql += f"{_render_list(val)}}}"
elif self.value_type == "valueBoolean":
_check_is_not_list(self.value, self.value_type)
gql += f"{_bool_to_str(self.value)}}}"
elif self.value_type == "valueBooleanList":
elif self.value_type == "valueBooleanArray":
_check_is_list(self.value, self.value_type)
gql += f"{_render_list(self.value)}}}"
elif self.value_type == "valueDateList":
elif self.value_type == "valueDateArray":
_check_is_list(self.value, self.value_type)
gql += f"{_render_list(self.value)}}}"
elif self.value_type == "valueGeoRange":
Expand All @@ -875,29 +881,31 @@ def __str__(self):


def _convert_value_type(_type: str) -> str:
"""Convert the value type to match `json` formatting required by Weaviate.
"""Convert the value type to match `json` formatting required by the Weaviate-defined
GraphQL endpoints. NOTE: This is crucially different to the Batch REST endpoints wherein
the where filter is also used.

Parameters
----------
_type : str
The type to be converted.
The Python-defined type to be converted.

Returns
-------
str
The string interpretation of the type in `json` format.
The string interpretation of the type in Weaviate-defined `json` format.
"""
if _type == "valueTextList":
if _type == "valueTextArray" or _type == "valueTextList":
return "valueText"
elif _type == "valueStringList":
elif _type == "valueStringArray" or _type == "valueStringList":
return "valueString"
elif _type == "valueIntList":
elif _type == "valueIntArray" or _type == "valueIntList":
return "valueInt"
elif _type == "valueNumberList":
elif _type == "valueNumberArray" or _type == "valueNumberList":
return "valueNumber"
elif _type == "valueBooleanList":
elif _type == "valueBooleanArray" or _type == "valueBooleanList":
return "valueBoolean"
elif _type == "valueDateList":
elif _type == "valueDateArray" or _type == "valueDateList":
return "valueDate"
else:
return _type
Expand Down