Skip to content

Commit

Permalink
fix: await on to_wrap in AsyncTransactional (#147)
Browse files Browse the repository at this point in the history
  • Loading branch information
rafilong authored Aug 7, 2020
1 parent 55da695 commit e640e66
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
18 changes: 9 additions & 9 deletions google/cloud/firestore_v1/async_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,31 +188,31 @@ class _AsyncTransactional(_BaseTransactional):
:func:`~google.cloud.firestore_v1.async_transaction.transactional`.
Args:
to_wrap (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Any]):
A callable that should be run (and retried) in a transaction.
to_wrap (Coroutine[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Any]):
A coroutine that should be run (and retried) in a transaction.
"""

def __init__(self, to_wrap) -> None:
super(_AsyncTransactional, self).__init__(to_wrap)

async def _pre_commit(self, transaction, *args, **kwargs) -> Coroutine:
"""Begin transaction and call the wrapped callable.
"""Begin transaction and call the wrapped coroutine.
If the callable raises an exception, the transaction will be rolled
If the coroutine raises an exception, the transaction will be rolled
back. If not, the transaction will be "ready" for ``Commit`` (i.e.
it will have staged writes).
Args:
transaction
(:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`):
A transaction to execute the callable within.
A transaction to execute the coroutine within.
args (Tuple[Any, ...]): The extra positional arguments to pass
along to the wrapped callable.
along to the wrapped coroutine.
kwargs (Dict[str, Any]): The extra keyword arguments to pass
along to the wrapped callable.
along to the wrapped coroutine.
Returns:
Any: result of the wrapped callable.
Any: result of the wrapped coroutine.
Raises:
Exception: Any failure caused by ``to_wrap``.
Expand All @@ -226,7 +226,7 @@ async def _pre_commit(self, transaction, *args, **kwargs) -> Coroutine:
if self.retry_id is None:
self.retry_id = self.current_id
try:
return self.to_wrap(transaction, *args, **kwargs)
return await self.to_wrap(transaction, *args, **kwargs)
except: # noqa
# NOTE: If ``rollback`` fails this will lose the information
# from the original failure.
Expand Down
14 changes: 7 additions & 7 deletions tests/unit/v1/test_async_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def test_constructor(self):

@pytest.mark.asyncio
async def test__pre_commit_success(self):
to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[])
to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[])
wrapped = self._make_one(to_wrap)

txn_id = b"totes-began"
Expand Down Expand Up @@ -368,7 +368,7 @@ async def test__pre_commit_success(self):
async def test__pre_commit_retry_id_already_set_success(self):
from google.cloud.firestore_v1.types import common

to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[])
to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[])
wrapped = self._make_one(to_wrap)
txn_id1 = b"already-set"
wrapped.retry_id = txn_id1
Expand Down Expand Up @@ -401,7 +401,7 @@ async def test__pre_commit_retry_id_already_set_success(self):
@pytest.mark.asyncio
async def test__pre_commit_failure(self):
exc = RuntimeError("Nope not today.")
to_wrap = mock.Mock(side_effect=exc, spec=[])
to_wrap = AsyncMock(side_effect=exc, spec=[])
wrapped = self._make_one(to_wrap)

txn_id = b"gotta-fail"
Expand Down Expand Up @@ -438,7 +438,7 @@ async def test__pre_commit_failure_with_rollback_failure(self):
from google.api_core import exceptions

exc1 = ValueError("I will not be only failure.")
to_wrap = mock.Mock(side_effect=exc1, spec=[])
to_wrap = AsyncMock(side_effect=exc1, spec=[])
wrapped = self._make_one(to_wrap)

txn_id = b"both-will-fail"
Expand Down Expand Up @@ -614,7 +614,7 @@ async def test__maybe_commit_failure_cannot_retry(self):

@pytest.mark.asyncio
async def test___call__success_first_attempt(self):
to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[])
to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[])
wrapped = self._make_one(to_wrap)

txn_id = b"whole-enchilada"
Expand Down Expand Up @@ -650,7 +650,7 @@ async def test___call__success_second_attempt(self):
from google.cloud.firestore_v1.types import firestore
from google.cloud.firestore_v1.types import write

to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[])
to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[])
wrapped = self._make_one(to_wrap)

txn_id = b"whole-enchilada"
Expand Down Expand Up @@ -707,7 +707,7 @@ async def test___call__failure(self):
_EXCEED_ATTEMPTS_TEMPLATE,
)

to_wrap = mock.Mock(return_value=mock.sentinel.result, spec=[])
to_wrap = AsyncMock(return_value=mock.sentinel.result, spec=[])
wrapped = self._make_one(to_wrap)

txn_id = b"only-one-shot"
Expand Down

0 comments on commit e640e66

Please sign in to comment.