Skip to content

Commit

Permalink
Adds error handling for the async client (#671)
Browse files Browse the repository at this point in the history
* Adds error handling for the async client

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Clean up

Signed-off-by: Elena Kolevska <elena@kolevska.com>

* Ruff

Signed-off-by: Elena Kolevska <elena@kolevska.com>

---------

Signed-off-by: Elena Kolevska <elena@kolevska.com>
Co-authored-by: Bernd Verst <github@bernd.dev>
  • Loading branch information
elena-kolevska and berndverst authored Feb 9, 2024
1 parent e73b731 commit 23fc4b1
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 25 deletions.
83 changes: 62 additions & 21 deletions dapr/aio/clients/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@
UnaryStreamClientInterceptor,
StreamUnaryClientInterceptor,
StreamStreamClientInterceptor,
AioRpcError,
)

from dapr.clients.exceptions import DaprInternalError
from dapr.clients.exceptions import DaprInternalError, DaprGrpcError
from dapr.clients.grpc._state import StateOptions, StateItem
from dapr.clients.grpc._helpers import getWorkflowRuntimeStatus
from dapr.conf.helpers import GrpcEndpoint
Expand Down Expand Up @@ -446,9 +447,12 @@ async def publish_event(
metadata=publish_metadata,
)

call = self._stub.PublishEvent(req, metadata=metadata)
# response is google.protobuf.Empty
await call
try:
call = self._stub.PublishEvent(req, metadata=metadata)
# response is google.protobuf.Empty
await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

return DaprResponse(await call.initial_metadata())

Expand Down Expand Up @@ -491,9 +495,15 @@ async def get_state(

if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0:
raise ValueError('State store name cannot be empty')

req = api_v1.GetStateRequest(store_name=store_name, key=key, metadata=state_metadata)
call = self._stub.GetState(req, metadata=metadata)
response = await call

try:
call = self._stub.GetState(req, metadata=metadata)
response = await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

return StateResponse(
data=response.data, etag=response.etag, headers=await call.initial_metadata()
)
Expand Down Expand Up @@ -542,8 +552,12 @@ async def get_bulk_state(
req = api_v1.GetBulkStateRequest(
store_name=store_name, keys=keys, parallelism=parallelism, metadata=states_metadata
)
call = self._stub.GetBulkState(req, metadata=metadata)
response = await call

try:
call = self._stub.GetBulkState(req, metadata=metadata)
response = await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

items = []
for item in response.items:
Expand Down Expand Up @@ -601,8 +615,12 @@ async def query_state(
if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0:
raise ValueError('State store name cannot be empty')
req = api_v1.QueryStateRequest(store_name=store_name, query=query, metadata=states_metadata)
call = self._stub.QueryStateAlpha1(req)
response = await call

try:
call = self._stub.QueryStateAlpha1(req)
response = await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

results = []
for item in response.results:
Expand Down Expand Up @@ -691,9 +709,12 @@ async def save_state(
)

req = api_v1.SaveStateRequest(store_name=store_name, states=[state])
call = self._stub.SaveState(req, metadata=metadata)
await call
return DaprResponse(headers=await call.initial_metadata())
try:
call = self._stub.SaveState(req, metadata=metadata)
await call
return DaprResponse(headers=await call.initial_metadata())
except AioRpcError as e:
raise DaprInternalError(e.details()) from e

async def save_bulk_state(
self, store_name: str, states: List[StateItem], metadata: Optional[MetadataTuple] = None
Expand Down Expand Up @@ -749,8 +770,13 @@ async def save_bulk_state(
]

req = api_v1.SaveStateRequest(store_name=store_name, states=req_states)
call = self._stub.SaveState(req, metadata=metadata)
await call

try:
call = self._stub.SaveState(req, metadata=metadata)
await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

return DaprResponse(headers=await call.initial_metadata())

async def execute_state_transaction(
Expand Down Expand Up @@ -815,8 +841,13 @@ async def execute_state_transaction(
req = api_v1.ExecuteStateTransactionRequest(
storeName=store_name, operations=req_ops, metadata=transactional_metadata
)
call = self._stub.ExecuteStateTransaction(req, metadata=metadata)
await call

try:
call = self._stub.ExecuteStateTransaction(req, metadata=metadata)
await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

return DaprResponse(headers=await call.initial_metadata())

async def delete_state(
Expand Down Expand Up @@ -880,8 +911,13 @@ async def delete_state(
options=state_options,
metadata=state_metadata,
)
call = self._stub.DeleteState(req, metadata=metadata)
await call

try:
call = self._stub.DeleteState(req, metadata=metadata)
await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

return DaprResponse(headers=await call.initial_metadata())

async def get_secret(
Expand Down Expand Up @@ -1522,8 +1558,13 @@ async def get_metadata(self) -> GetMetadataResponse:
information about supported features in the form of component
capabilities.
"""
call = self._stub.GetMetadata(GrpcEmpty())
_resp = await call

try:
call = self._stub.GetMetadata(GrpcEmpty())
_resp = await call
except AioRpcError as err:
raise DaprGrpcError(err) from err

response: api_v1.GetMetadataResponse = _resp # type alias
# Convert to more pythonic formats
active_actors_count = {
Expand Down
69 changes: 65 additions & 4 deletions tests/clients/test_dapr_async_grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@

from unittest.mock import patch

from google.rpc import status_pb2, code_pb2

from dapr.aio.clients.grpc.client import DaprGrpcClientAsync
from dapr.aio.clients import DaprClient
from dapr.clients.exceptions import DaprGrpcError
from dapr.proto import common_v1
from .fake_dapr_server import FakeDaprSidecar
from dapr.conf import settings
Expand Down Expand Up @@ -202,10 +205,18 @@ async def test_invoke_binding_no_create(self):

async def test_publish_event(self):
dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
resp = await dapr.publish_event(pubsub_name='pubsub', topic_name='example', data=b'haha')
resp = await dapr.publish_event(
pubsub_name='pubsub', topic_name='example', data=b'test_data'
)

self.assertEqual(2, len(resp.headers))
self.assertEqual(['haha'], resp.headers['hdata'])
self.assertEqual(['test_data'], resp.headers['hdata'])

self._fake_dapr_server.raise_exception_on_next_call(
status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message')
)
with self.assertRaises(DaprGrpcError):
await dapr.publish_event(pubsub_name='pubsub', topic_name='example', data=b'test_data')

async def test_publish_event_with_content_type(self):
dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
Expand Down Expand Up @@ -292,12 +303,19 @@ async def test_get_save_delete_state(self):
self.assertEqual(resp.data, b'')
self.assertEqual(resp.etag, '')

# Check a DaprGrpcError is raised
self._fake_dapr_server.raise_exception_on_next_call(
status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message')
)
with self.assertRaises(DaprGrpcError) as context:
await dapr.get_state(store_name='my_statestore', key='key||')

await dapr.delete_state(store_name='statestore', key=key)
resp = await dapr.get_state(store_name='statestore', key=key)
self.assertEqual(resp.data, b'')
self.assertEqual(resp.etag, '')

with self.assertRaises(Exception) as context:
with self.assertRaises(DaprGrpcError) as context:
await dapr.delete_state(
store_name='statestore', key=key, state_metadata={'must_delete': '1'}
)
Expand Down Expand Up @@ -359,7 +377,20 @@ async def test_transaction_then_get_states(self):
self.assertEqual(resp.items[1].key, another_key)
self.assertEqual(resp.items[1].data, to_bytes(another_value.upper()))

async def test_save_then_get_states(self):
self._fake_dapr_server.raise_exception_on_next_call(
status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message')
)
with self.assertRaises(DaprGrpcError):
await dapr.execute_state_transaction(
store_name='statestore',
operations=[
TransactionalStateOperation(key=key, data=value, etag='foo'),
TransactionalStateOperation(key=another_key, data=another_value),
],
transactional_metadata={'metakey': 'metavalue'},
)

async def test_bulk_save_then_get_states(self):
dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')

key = str(uuid.uuid4())
Expand Down Expand Up @@ -394,6 +425,27 @@ async def test_save_then_get_states(self):
self.assertEqual(resp.items[1].etag, '1')
self.assertEqual(resp.items[1].data, to_bytes(another_value.upper()))

self._fake_dapr_server.raise_exception_on_next_call(
status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message')
)
with self.assertRaises(DaprGrpcError):
await dapr.save_bulk_state(
store_name='statestore',
states=[
StateItem(key=key, value=value, metadata={'capitalize': '1'}),
StateItem(key=another_key, value=another_value, etag='1'),
],
metadata=(('metakey', 'metavalue'),),
)

self._fake_dapr_server.raise_exception_on_next_call(
status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message')
)
with self.assertRaises(DaprGrpcError):
await dapr.get_bulk_state(
store_name='statestore', keys=[key, another_key], states_metadata={'upper': '1'}
)

async def test_get_secret(self):
dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
key1 = 'key_1'
Expand Down Expand Up @@ -512,6 +564,15 @@ async def test_query_state(self):
self.assertEqual(resp.results[0].key, '3')
self.assertEqual(len(resp.results), 3)

self._fake_dapr_server.raise_exception_on_next_call(
status_pb2.Status(code=code_pb2.INVALID_ARGUMENT, message='my invalid argument message')
)
with self.assertRaises(DaprGrpcError):
await dapr.query_state(
store_name='statestore',
query=json.dumps({'filter': {}, 'page': {'limit': 2}}),
)

async def test_shutdown(self):
dapr = DaprGrpcClientAsync(f'{self.scheme}localhost:{self.server_port}')
await dapr.shutdown()
Expand Down

0 comments on commit 23fc4b1

Please sign in to comment.