Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parse the correct timeouts to httpx client #1206

Merged
merged 4 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions mock_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import time
from concurrent import futures
from typing import Generator, Mapping

Expand All @@ -8,11 +9,19 @@
from grpc_health.v1.health_pb2 import HealthCheckResponse, HealthCheckRequest
from grpc_health.v1.health_pb2_grpc import HealthServicer, add_HealthServicer_to_server
from pytest_httpserver import HTTPServer, HeaderValueMatcher
from werkzeug.wrappers import Response
from werkzeug.wrappers import Request, Response

import weaviate
from weaviate.connect.base import ConnectionParams, ProtocolParams
from weaviate.proto.v1 import properties_pb2, tenants_pb2, search_get_pb2, weaviate_pb2_grpc
from weaviate.proto.v1 import (
batch_pb2,
properties_pb2,
tenants_pb2,
search_get_pb2,
weaviate_pb2_grpc,
)

from mock_tests.mock_data import mock_class

MOCK_IP = "127.0.0.1"
MOCK_PORT = 23536
Expand Down Expand Up @@ -76,6 +85,26 @@ def weaviate_auth_mock(weaviate_mock: HTTPServer):
yield weaviate_mock


@pytest.fixture(scope="function")
def weaviate_timeouts_mock(weaviate_no_auth_mock: HTTPServer):
def slow_get(request: Request) -> Response:
time.sleep(1)
return Response(json.dumps({"doesn't": "matter"}), content_type="application/json")

def slow_post(request: Request) -> Response:
time.sleep(2)
return Response(json.dumps({"doesn't": "matter"}), content_type="application/json")

weaviate_no_auth_mock.expect_request(
f"/v1/schema/{mock_class['class']}", method="GET"
).respond_with_handler(slow_get)
weaviate_no_auth_mock.expect_request("/v1/objects", method="POST").respond_with_handler(
slow_post
)

yield weaviate_no_auth_mock


@pytest.fixture(scope="function")
def start_grpc_server() -> Generator[grpc.Server, None, None]:
# Create a gRPC server
Expand Down Expand Up @@ -110,6 +139,22 @@ def weaviate_client(
client.close()


@pytest.fixture(scope="function")
def weaviate_timeouts_client(
weaviate_timeouts_mock: HTTPServer, start_grpc_server: grpc.Server
) -> Generator[weaviate.WeaviateClient, None, None]:
client = weaviate.connect_to_local(
host=MOCK_IP,
port=MOCK_PORT,
grpc_port=MOCK_PORT_GRPC,
additional_config=weaviate.classes.init.AdditionalConfig(
timeout=weaviate.classes.init.Timeout(query=0.5, insert=1.5)
),
)
yield client
client.close()


@pytest.fixture(scope="function")
def tenants_collection(
weaviate_client: weaviate.WeaviateClient, start_grpc_server: grpc.Server
Expand Down Expand Up @@ -184,3 +229,24 @@ def Search(

weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server)
return weaviate_client.collections.get("YearZeroCollection")


@pytest.fixture(scope="function")
def timeouts_collection(
weaviate_timeouts_client: weaviate.WeaviateClient, start_grpc_server: grpc.Server
) -> weaviate.collections.Collection:
class MockWeaviateService(weaviate_pb2_grpc.WeaviateServicer):
def Search(
self, request: search_get_pb2.SearchRequest, context: grpc.ServicerContext
) -> search_get_pb2.SearchReply:
time.sleep(1)
return search_get_pb2.SearchReply()

def BatchObjects(
self, request: batch_pb2.BatchObjectsRequest, context: grpc.ServicerContext
) -> batch_pb2.BatchObjectsReply:
time.sleep(2)
return batch_pb2.BatchObjectsReply()

weaviate_pb2_grpc.add_WeaviateServicer_to_server(MockWeaviateService(), start_grpc_server)
return weaviate_timeouts_client.collections.get(mock_class["class"])
69 changes: 69 additions & 0 deletions mock_tests/mock_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
mock_class = {
"class": "Something",
"description": "It's something!",
"invertedIndexConfig": {
"bm25": {"b": 0.8, "k1": 1.3},
"cleanupIntervalSeconds": 61,
"indexPropertyLength": True,
"indexTimestamps": True,
"stopwords": {"additions": None, "preset": "en", "removals": ["the"]},
},
"moduleConfig": {
"generative-openai": {},
"text2vec-contextionary": {"vectorizeClassName": True},
},
"multiTenancyConfig": {
"autoTenantActivation": False,
"autoTenantCreation": False,
"enabled": False,
},
"properties": [
{
"dataType": ["text[]"],
"indexFilterable": True,
"indexRangeFilters": False,
"indexSearchable": True,
"moduleConfig": {
"text2vec-contextionary": {"skip": False, "vectorizePropertyName": False}
},
"name": "names",
"tokenization": "word",
}
],
"replicationConfig": {"asyncEnabled": False, "factor": 1},
"shardingConfig": {
"virtualPerPhysical": 128,
"desiredCount": 1,
"actualCount": 1,
"desiredVirtualCount": 128,
"actualVirtualCount": 128,
"key": "_id",
"strategy": "hash",
"function": "murmur3",
},
"vectorIndexConfig": {
"skip": True,
"cleanupIntervalSeconds": 300,
"maxConnections": 64,
"efConstruction": 128,
"ef": -2,
"dynamicEfMin": 101,
"dynamicEfMax": 501,
"dynamicEfFactor": 9,
"vectorCacheMaxObjects": 1000000000001,
"flatSearchCutoff": 40001,
"distance": "cosine",
"pq": {
"enabled": True,
"bitCompression": True,
"segments": 1,
"centroids": 257,
"trainingLimit": 100001,
"encoder": {"type": "tile", "distribution": "normal"},
},
"bq": {"enabled": False},
"sq": {"enabled": False, "trainingLimit": 100000, "rescoreLimit": 20},
},
"vectorIndexType": "hnsw",
"vectorizer": "text2vec-contextionary",
}
25 changes: 25 additions & 0 deletions mock_tests/test_timeouts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
import weaviate
from weaviate.exceptions import WeaviateTimeoutError, WeaviateQueryError


def test_timeout_rest_query(timeouts_collection: weaviate.collections.Collection):
with pytest.raises(WeaviateTimeoutError):
timeouts_collection.config.get()


def test_timeout_rest_insert(timeouts_collection: weaviate.collections.Collection):
with pytest.raises(WeaviateTimeoutError):
timeouts_collection.data.insert(properties={"what": "ever"})


def test_timeout_grpc_query(timeouts_collection: weaviate.collections.Collection):
with pytest.raises(WeaviateQueryError) as recwarn:
timeouts_collection.query.fetch_objects()
assert "DEADLINE_EXCEEDED" in str(recwarn)


def test_timeout_grpc_insert(timeouts_collection: weaviate.collections.Collection):
with pytest.raises(WeaviateQueryError) as recwarn:
timeouts_collection.data.insert_many([{"what": "ever"}])
assert "DEADLINE_EXCEEDED" in str(recwarn)
1 change: 1 addition & 0 deletions weaviate/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ async def graphql_raw_query(self, gql_query: str) -> _RawGQLReturn:
weaviate_object=json_query,
error_msg="Raw GQL query failed",
status_codes=_ExpectedStatusCodes(ok_in=[200], error="GQL query"),
is_gql_query=True,
)

res = _decode_json_response_dict(response, "GQL query")
Expand Down
6 changes: 4 additions & 2 deletions weaviate/collections/batch/grpc_batch_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def pack_vector(vector: Any) -> bytes:
for obj in objects
]

async def objects(self, objects: List[_BatchObject], timeout: int) -> BatchObjectReturn:
async def objects(
self, objects: List[_BatchObject], timeout: Union[int, float]
) -> BatchObjectReturn:
"""Insert multiple objects into Weaviate through the gRPC API.

Parameters:
Expand Down Expand Up @@ -131,7 +133,7 @@ async def objects(self, objects: List[_BatchObject], timeout: int) -> BatchObjec
)

async def __send_batch(
self, batch: List[batch_pb2.BatchObject], timeout: int
self, batch: List[batch_pb2.BatchObject], timeout: Union[int, float]
) -> Dict[int, str]:
metadata = self._get_metadata()
try:
Expand Down
6 changes: 3 additions & 3 deletions weaviate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def __post_init__(self) -> None:
class Timeout(BaseModel):
"""Timeouts for the different operations in the client."""

query: int = Field(default=30, ge=0)
insert: int = Field(default=90, ge=0)
init: int = Field(default=2, ge=0)
query: Union[int, float] = Field(default=30, ge=0)
insert: Union[int, float] = Field(default=90, ge=0)
init: Union[int, float] = Field(default=2, ge=0)


class Proxies(BaseModel):
Expand Down
35 changes: 32 additions & 3 deletions weaviate/connect/v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
HTTPStatusError,
Limits,
ReadError,
ReadTimeout,
RemoteProtocolError,
RequestError,
Response,
Expand Down Expand Up @@ -55,6 +56,7 @@
WeaviateConnectionError,
WeaviateGRPCUnavailableError,
WeaviateStartUpError,
WeaviateTimeoutError,
)
from weaviate.proto.v1 import weaviate_pb2_grpc
from weaviate.util import (
Expand Down Expand Up @@ -219,9 +221,6 @@ def __make_mounts(self) -> Dict[str, AsyncHTTPTransport]:
def __make_async_client(self) -> AsyncClient:
return AsyncClient(
headers=self._headers,
timeout=Timeout(
None, connect=self.timeout_config.query, read=self.timeout_config.insert
),
mounts=self.__make_mounts(),
)

Expand Down Expand Up @@ -406,12 +405,37 @@ def __get_latest_headers(self) -> Dict[str, str]:
copied_headers.update({"authorization": self.get_current_bearer_token()})
return copied_headers

def __get_timeout(
self, method: Literal["DELETE", "GET", "HEAD", "PATCH", "POST", "PUT"], is_gql_query: bool
) -> Timeout:
"""
In this way, the client waits the `httpx` default of 5s when connecting to a socket (connect), writing chunks (write), and
acquiring a connection from the pool (pool), but a custom amount as specified for reading the response (read).

From the PoV of the user, a request is considered to be timed out if no response is received within the specified time.
They specify the times depending on how they expect Weaviate to behave. For example, a query might take longer than an insert or vice versa
but, in either case, the user only cares about how long it takes for a response to be received.

https://www.python-httpx.org/advanced/timeouts/
"""
timeout = None
if method == "DELETE" or method == "PATCH" or method == "PUT":
timeout = self.timeout_config.insert
elif method == "GET" or method == "HEAD":
timeout = self.timeout_config.query
elif method == "POST" and is_gql_query:
timeout = self.timeout_config.query
elif method == "POST" and not is_gql_query:
timeout = self.timeout_config.insert
return Timeout(timeout=5.0, read=timeout)

async def __send(
self,
method: Literal["DELETE", "GET", "HEAD", "PATCH", "POST", "PUT"],
url: str,
error_msg: str,
status_codes: Optional[_ExpectedStatusCodes],
is_gql_query: bool = False,
weaviate_object: Optional[JSONPayload] = None,
params: Optional[Dict[str, Any]] = None,
) -> Response:
Expand All @@ -427,6 +451,7 @@ async def __send(
json=weaviate_object,
params=params,
headers=self.__get_latest_headers(),
timeout=self.__get_timeout(method, is_gql_query),
)
res = await self._client.send(req)
if status_codes is not None and res.status_code not in status_codes.ok:
Expand All @@ -436,6 +461,8 @@ async def __send(
raise WeaviateClosedClientError() from e
except ConnectError as conn_err:
raise WeaviateConnectionError(error_msg) from conn_err
except ReadTimeout as read_err:
raise WeaviateTimeoutError(error_msg) from read_err
except Exception as e:
raise e

Expand Down Expand Up @@ -480,6 +507,7 @@ async def post(
params: Optional[Dict[str, Any]] = None,
error_msg: str = "",
status_codes: Optional[_ExpectedStatusCodes] = None,
is_gql_query: bool = False,
) -> Response:
return await self.__send(
"POST",
Expand All @@ -488,6 +516,7 @@ async def post(
params=params,
error_msg=error_msg,
status_codes=status_codes,
is_gql_query=is_gql_query,
)

async def put(
Expand Down
10 changes: 9 additions & 1 deletion weaviate/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ class WeaviateConnectionError(WeaviateBaseError):
"""Is raised when the connection to Weaviate fails."""

def __init__(self, message: str = "") -> None:
msg = f"""Connection to Weaviate failed. {message}"""
msg = f"""Connection to Weaviate failed. Details: {message}"""
super().__init__(msg)


Expand All @@ -327,3 +327,11 @@ class WeaviateUnsupportedFeatureError(WeaviateBaseError):
def __init__(self, feature: str, current: str, minimum: str) -> None:
msg = f"""{feature} is not supported by your connected server's Weaviate version. The current version is {current}, but the feature requires at least version {minimum}."""
super().__init__(msg)


class WeaviateTimeoutError(WeaviateBaseError):
"""Is raised when a request to Weaviate times out."""

def __init__(self, message: str = "") -> None:
msg = f"""The request to Weaviate timed out while awaiting a response. Try adjusting the timeout config for your client. Details: {message}"""
super().__init__(msg)