Skip to content

Commit

Permalink
chore: avoid checking instance on each stream call (#529)
Browse files Browse the repository at this point in the history
* chore: avoid checking instance on each stream call

* fixed indentation

* added check for unary call

* fixed type check

* fixed tests

* fixed coverage

* added exception to test class

* added comment to test

* 🦉 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

---------

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
  • Loading branch information
daniel-sanche and gcf-owl-bot[bot] authored May 3, 2024
1 parent 7d87462 commit ab22afd
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 29 deletions.
24 changes: 10 additions & 14 deletions google/api_core/grpc_helpers_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ class _WrappedStreamStreamCall(

def _wrap_unary_errors(callable_):
"""Map errors for Unary-Unary async callables."""
grpc_helpers._patch_callable_name(callable_)

@functools.wraps(callable_)
def error_remapped_callable(*args, **kwargs):
Expand All @@ -169,23 +168,13 @@ def error_remapped_callable(*args, **kwargs):
return error_remapped_callable


def _wrap_stream_errors(callable_):
def _wrap_stream_errors(callable_, wrapper_type):
"""Map errors for streaming RPC async callables."""
grpc_helpers._patch_callable_name(callable_)

@functools.wraps(callable_)
async def error_remapped_callable(*args, **kwargs):
call = callable_(*args, **kwargs)

if isinstance(call, aio.UnaryStreamCall):
call = _WrappedUnaryStreamCall().with_call(call)
elif isinstance(call, aio.StreamUnaryCall):
call = _WrappedStreamUnaryCall().with_call(call)
elif isinstance(call, aio.StreamStreamCall):
call = _WrappedStreamStreamCall().with_call(call)
else:
raise TypeError("Unexpected type of call %s" % type(call))

call = wrapper_type().with_call(call)
await call.wait_for_connection()
return call

Expand All @@ -207,10 +196,17 @@ def wrap_errors(callable_):
Returns: Callable: The wrapped gRPC callable.
"""
grpc_helpers._patch_callable_name(callable_)
if isinstance(callable_, aio.UnaryUnaryMultiCallable):
return _wrap_unary_errors(callable_)
elif isinstance(callable_, aio.UnaryStreamMultiCallable):
return _wrap_stream_errors(callable_, _WrappedUnaryStreamCall)
elif isinstance(callable_, aio.StreamUnaryMultiCallable):
return _wrap_stream_errors(callable_, _WrappedStreamUnaryCall)
elif isinstance(callable_, aio.StreamStreamMultiCallable):
return _wrap_stream_errors(callable_, _WrappedStreamStreamCall)
else:
return _wrap_stream_errors(callable_)
raise TypeError("Unexpected type of callable: {}".format(type(callable_)))


def create_channel(
Expand Down
78 changes: 63 additions & 15 deletions tests/asyncio/test_grpc_helpers_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,40 @@ async def test_common_methods_in_wrapped_call():
assert mock_call.wait_for_connection.call_count == 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
"callable_type,expected_wrapper_type",
[
(grpc.aio.UnaryStreamMultiCallable, grpc_helpers_async._WrappedUnaryStreamCall),
(grpc.aio.StreamUnaryMultiCallable, grpc_helpers_async._WrappedStreamUnaryCall),
(
grpc.aio.StreamStreamMultiCallable,
grpc_helpers_async._WrappedStreamStreamCall,
),
],
)
async def test_wrap_errors_w_stream_type(callable_type, expected_wrapper_type):
class ConcreteMulticallable(callable_type):
def __call__(self, *args, **kwargs):
raise NotImplementedError("Should not be called")

with mock.patch.object(
grpc_helpers_async, "_wrap_stream_errors"
) as wrap_stream_errors:
callable_ = ConcreteMulticallable()
grpc_helpers_async.wrap_errors(callable_)
assert wrap_stream_errors.call_count == 1
wrap_stream_errors.assert_called_once_with(callable_, expected_wrapper_type)


@pytest.mark.asyncio
async def test_wrap_stream_errors_unary_stream():
mock_call = mock.Mock(aio.UnaryStreamCall, autospec=True)
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedUnaryStreamCall
)

await wrapped_callable(1, 2, three="four")
multicallable.assert_called_once_with(1, 2, three="four")
Expand All @@ -114,7 +142,9 @@ async def test_wrap_stream_errors_stream_unary():
mock_call = mock.Mock(aio.StreamUnaryCall, autospec=True)
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamUnaryCall
)

await wrapped_callable(1, 2, three="four")
multicallable.assert_called_once_with(1, 2, three="four")
Expand All @@ -126,22 +156,26 @@ async def test_wrap_stream_errors_stream_stream():
mock_call = mock.Mock(aio.StreamStreamCall, autospec=True)
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamStreamCall
)

await wrapped_callable(1, 2, three="four")
multicallable.assert_called_once_with(1, 2, three="four")
assert mock_call.wait_for_connection.call_count == 1


@pytest.mark.asyncio
async def test_wrap_stream_errors_type_error():
async def test_wrap_errors_type_error():
"""
If wrap_errors is called with an unexpected type, it should raise a TypeError.
"""
mock_call = mock.Mock()
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)

with pytest.raises(TypeError):
await wrapped_callable()
with pytest.raises(TypeError) as exc:
grpc_helpers_async.wrap_errors(multicallable)
assert "Unexpected type" in str(exc.value)


@pytest.mark.asyncio
Expand All @@ -151,7 +185,9 @@ async def test_wrap_stream_errors_raised():
mock_call.wait_for_connection = mock.AsyncMock(side_effect=[grpc_error])
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamStreamCall
)

with pytest.raises(exceptions.InvalidArgument):
await wrapped_callable()
Expand All @@ -166,7 +202,9 @@ async def test_wrap_stream_errors_read():
mock_call.read = mock.AsyncMock(side_effect=grpc_error)
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamStreamCall
)

wrapped_call = await wrapped_callable(1, 2, three="four")
multicallable.assert_called_once_with(1, 2, three="four")
Expand All @@ -189,7 +227,9 @@ async def test_wrap_stream_errors_aiter():
mock_call.__aiter__ = mock.Mock(return_value=mocked_aiter)
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamStreamCall
)
wrapped_call = await wrapped_callable()

with pytest.raises(exceptions.InvalidArgument) as exc_info:
Expand All @@ -210,7 +250,9 @@ async def test_wrap_stream_errors_aiter_non_rpc_error():
mock_call.__aiter__ = mock.Mock(return_value=mocked_aiter)
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamStreamCall
)
wrapped_call = await wrapped_callable()

with pytest.raises(TypeError) as exc_info:
Expand All @@ -224,7 +266,9 @@ async def test_wrap_stream_errors_aiter_called_multiple_times():
mock_call = mock.Mock(aio.StreamStreamCall, autospec=True)
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamStreamCall
)
wrapped_call = await wrapped_callable()

assert wrapped_call.__aiter__() == wrapped_call.__aiter__()
Expand All @@ -239,7 +283,9 @@ async def test_wrap_stream_errors_write():
mock_call.done_writing = mock.AsyncMock(side_effect=[None, grpc_error])
multicallable = mock.Mock(return_value=mock_call)

wrapped_callable = grpc_helpers_async._wrap_stream_errors(multicallable)
wrapped_callable = grpc_helpers_async._wrap_stream_errors(
multicallable, grpc_helpers_async._WrappedStreamStreamCall
)

wrapped_call = await wrapped_callable()

Expand Down Expand Up @@ -295,7 +341,9 @@ def test_wrap_errors_streaming(wrap_stream_errors):
result = grpc_helpers_async.wrap_errors(callable_)

assert result == wrap_stream_errors.return_value
wrap_stream_errors.assert_called_once_with(callable_)
wrap_stream_errors.assert_called_once_with(
callable_, grpc_helpers_async._WrappedUnaryStreamCall
)


@pytest.mark.parametrize(
Expand Down

0 comments on commit ab22afd

Please sign in to comment.