Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix batch retry with tenants in v3 #1122

Merged
merged 10 commits into from
Jun 24, 2024
5 changes: 4 additions & 1 deletion integration/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def test_authentication_user_pw(
"""Test authentication using Resource Owner Password Credentials Grant (User + PW)."""
# testing for warnings can be flaky without this as there are open SSL connections
warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)
warnings.filterwarnings(action="ignore", message="Dep005", category=DeprecationWarning)

url = "http://127.0.0.1:" + port
assert is_auth_enabled(url)
Expand Down Expand Up @@ -211,6 +212,7 @@ def test_client_with_authentication_with_anon_weaviate(recwarn):
"""Test that we warn users when their client has auth enabled, but weaviate has only anon access."""
# testing for warnings can be flaky without this as there are open SSL connections
warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)
warnings.filterwarnings(action="ignore", message="Dep005", category=DeprecationWarning)

url = "http://127.0.0.1:" + ANON_PORT
assert not is_auth_enabled(url)
Expand All @@ -234,6 +236,7 @@ def test_bearer_token_without_refresh(recwarn):

# testing for warnings can be flaky without this as there are open SSL connections
warnings.filterwarnings(action="ignore", message="unclosed", category=ResourceWarning)
warnings.filterwarnings(action="ignore", message="Dep005", category=DeprecationWarning)

url = "http://127.0.0.1:" + WCS_PORT
assert is_auth_enabled(url)
Expand All @@ -250,7 +253,7 @@ def test_bearer_token_without_refresh(recwarn):
)
client.schema.delete_all() # no exception, client works

assert len(recwarn) == 1
assert len(recwarn) == 1, [wrn.message for wrn in recwarn]
w = recwarn.pop()
assert issubclass(w.category, UserWarning)
assert str(w.message).startswith("Auth002")
Expand Down
8 changes: 4 additions & 4 deletions mock_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ def weaviate_mock(ready_mock):


@pytest.fixture(scope="function")
def weaviate_no_auth_mock(ready_mock):
ready_mock.expect_request("/v1/meta").respond_with_json({"version": "1.16"})
ready_mock.expect_request("/v1/.well-known/openid-configuration").respond_with_response(
def weaviate_no_auth_mock(weaviate_mock):
weaviate_mock.expect_request("/v1/meta").respond_with_json({"version": "1.16"})
weaviate_mock.expect_request("/v1/.well-known/openid-configuration").respond_with_response(
Response(json.dumps({}), status=404)
)

yield ready_mock
yield weaviate_mock


@pytest.fixture(scope="function")
Expand Down
32 changes: 31 additions & 1 deletion mock_tests/test_automatic_retries.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
import uuid
from typing import Optional

import pytest
import uuid
from werkzeug.wrappers import Request, Response

import weaviate
Expand Down Expand Up @@ -283,3 +283,33 @@ def callback_print_all(results: Optional[BatchResponse]):
# callback output for each object
print_output, err = capfd.readouterr()
assert print_output.count("\n") == n


def test_retries_with_tenant(weaviate_no_auth_mock):
    """Check that objects re-sent by the batch error-retry logic keep their tenant.

    The mocked ``/v1/batch/objects`` endpoint reports an error for the very
    first object it sees, which makes the client retry that object. Every
    object in every request (including the retried one) must carry the
    tenant it was originally added with; the handler asserts this.
    """
    tenant = "tenant"
    first_try = True

    def handler(request: Request):
        nonlocal first_try
        objects = request.json["objects"]
        for obj in objects:
            # Fails the request (and therefore ``check_assertions``) if the
            # tenant was dropped, e.g. when the object is re-added for a retry.
            assert obj["tenant"] == tenant
            obj["deprecations"] = None
            if first_try:
                # Only the first object of the first request is reported as
                # failed, so exactly one object goes through the retry path.
                obj["result"] = {"errors": {"error": [{"message": "I'm an error message"}]}}
                first_try = False
            else:
                obj["result"] = {}
        return Response(json.dumps(objects))

    weaviate_no_auth_mock.expect_request("/v1/batch/objects").respond_with_handler(handler)

    client = weaviate.Client(url=MOCK_SERVER_URL)

    n = 10
    with client.batch(
        weaviate_error_retries=WeaviateErrorRetryConf(number_retries=1),
    ) as batch:
        for i in range(n):
            batch.add_data_object({"name": "test" + str(i)}, "test", uuid.uuid4(), tenant=tenant)
    # Assertions raised inside the handler are collected by the mock server
    # and surface here, after the batch has been flushed on context exit.
    weaviate_no_auth_mock.check_assertions()
20 changes: 14 additions & 6 deletions weaviate/batch/crud_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,9 +805,13 @@ def _readd_objects_after_timeout(
new_batch = ObjectsBatchRequest()
for obj in batch_request.get_request_body()["objects"]:
class_name = obj["class"]
tenant = obj.get("tenant", None)
uuid = obj["id"]
params = {"tenant": tenant} if tenant is not None else None

response_head = self._connection.head(
path="/objects/" + class_name + "/" + uuid,
params=params,
)

if response_head.status_code == 404:
Expand All @@ -822,6 +826,7 @@ def _readd_objects_after_timeout(
# object might already exist and needs to be overwritten in case of an update
response = self._connection.get(
path="/objects/" + class_name + "/" + uuid,
params=params,
)

obj_weav = _decode_json_response_dict(response, "Re-add objects")
Expand All @@ -834,6 +839,7 @@ def _readd_objects_after_timeout(
data_object=obj["properties"],
uuid=uuid,
vector=obj.get("vector", None),
tenant=tenant,
)
return new_batch

Expand Down Expand Up @@ -967,7 +973,9 @@ class NonExistingClass not present"
self._objects_throughput_frame
)

self._recommended_num_objects = max(round(obj_per_second * self._creation_time), 1)
self._recommended_num_objects = max(
round(obj_per_second * float(self._creation_time)), 1
)

res = _decode_json_response_list(response, "batch add objects")
assert res is not None
Expand Down Expand Up @@ -1064,7 +1072,7 @@ def create_references(self) -> list:
self._references_throughput_frame
)

self._recommended_num_references = round(ref_per_sec * self._creation_time)
self._recommended_num_references = round(ref_per_sec * float(self._creation_time))

res = _decode_json_response_list(response, "Create references")
assert res is not None
Expand Down Expand Up @@ -1171,7 +1179,7 @@ def _send_batch_requests(self, force_wait: bool) -> None:
)
self._recommended_num_objects = max(
min(
round(obj_per_second * self._creation_time),
round(obj_per_second * float(self._creation_time)),
self._recommended_num_objects + 250,
),
1,
Expand Down Expand Up @@ -1209,7 +1217,7 @@ def _send_batch_requests(self, force_wait: bool) -> None:
self._references_throughput_frame
)
self._recommended_num_references = min(
round(ref_per_sec * self._creation_time),
round(ref_per_sec * float(self._creation_time)),
self._recommended_num_references * 2,
)

Expand Down Expand Up @@ -1741,11 +1749,11 @@ def creation_time(self, value: Real) -> None:
_check_positive_num(value, "creation_time", Real)
if self._recommended_num_references is not None:
self._recommended_num_references = round(
self._recommended_num_references * value / self._creation_time
self._recommended_num_references * float(value) / float(self._creation_time)
)
if self._recommended_num_objects is not None:
self._recommended_num_objects = round(
self._recommended_num_objects * value / self._creation_time
self._recommended_num_objects * float(value) / float(self._creation_time)
)
self._creation_time = value
if self._batching_type:
Expand Down
6 changes: 3 additions & 3 deletions weaviate/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Dict, Optional

import requests
import validators # type: ignore
import validators

from weaviate import exceptions
from weaviate.exceptions import WeaviateStartUpError
Expand Down Expand Up @@ -184,8 +184,8 @@ def wait_till_listening(self) -> None:
def check_supported_platform() -> None:
if platform.system() in ["Windows"]:
raise WeaviateStartUpError(
f"{platform.system()} is not supported with EmbeddedDB. Please upvote this feature request if "
f"you want this: https://github.com/weaviate/weaviate/issues/3315"
f"""{platform.system()} is not supported with EmbeddedDB. Please upvote this feature request if you want
this: https://github.com/weaviate/weaviate/issues/3315""" # noqa: E231
)

def start(self) -> None:
Expand Down
10 changes: 6 additions & 4 deletions weaviate/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
import json
import os
import re
import uuid as uuid_lib
from enum import Enum, EnumMeta
from io import BufferedReader
from typing import Union, Sequence, Any, Optional, List, Dict, Tuple, cast

import requests
import uuid as uuid_lib
import validators # type: ignore
import validators
from requests.exceptions import JSONDecodeError

from weaviate.exceptions import (
Expand Down Expand Up @@ -199,8 +199,10 @@ def generate_local_beacon(
raise TypeError("Expected to_object_uuid of type str or uuid.UUID")

if class_name is None:
return {"beacon": f"weaviate://localhost/{uuid}"}
return {"beacon": f"weaviate://localhost/{_capitalize_first_letter(class_name)}/{uuid}"}
return {"beacon": f"weaviate://localhost/{uuid}"} # noqa: E231
return {
"beacon": f"weaviate://localhost/{_capitalize_first_letter(class_name)}/{uuid}" # noqa: E231
}


def _get_dict_from_object(object_: Union[str, dict]) -> dict:
Expand Down
Loading