Skip to content

Commit

Permalink
session.run_in_transaction returns the callback's return value. (#3753)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukesneeringer authored Aug 8, 2017
1 parent b96eaca commit fe757be
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 22 deletions.
10 changes: 5 additions & 5 deletions spanner/google/cloud/spanner/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,9 @@ def run_in_transaction(self, func, *args, **kw):
If passed, "timeout_secs" will be removed and used to
override the default timeout.
:rtype: :class:`datetime.datetime`
:returns: timestamp of committed transaction
:rtype: Any
:returns: The return value of ``func``.
:raises Exception:
reraises any non-ABORT execptions raised by ``func``.
"""
Expand All @@ -284,7 +285,7 @@ def run_in_transaction(self, func, *args, **kw):
if txn._transaction_id is None:
txn.begin()
try:
func(txn, *args, **kw)
return_value = func(txn, *args, **kw)
except GaxError as exc:
_delay_until_retry(exc, deadline)
del self._transaction
Expand All @@ -299,8 +300,7 @@ def run_in_transaction(self, func, *args, **kw):
_delay_until_retry(exc, deadline)
del self._transaction
else:
committed = txn.committed
return committed
return return_value


# pylint: disable=misplaced-bare-raise
Expand Down
4 changes: 2 additions & 2 deletions spanner/tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from google.cloud.spanner import __version__


def _make_credentials():
def _make_credentials(): # pragma: NO COVER
import google.auth.credentials

class _CredentialsWithScopes(
Expand Down Expand Up @@ -223,7 +223,7 @@ def __init__(self, scopes=(), source=None):
self._scopes = scopes
self._source = source

def requires_scopes(self):
def requires_scopes(self): # pragma: NO COVER
return True

def with_scopes(self, scopes):
Expand Down
24 changes: 9 additions & 15 deletions spanner/tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,16 +513,16 @@ def test_run_in_transaction_w_args_w_kwargs_wo_abort(self):
def unit_of_work(txn, *args, **kw):
called_with.append((txn, args, kw))
txn.insert(TABLE_NAME, COLUMNS, VALUES)
return 42

committed = session.run_in_transaction(
return_value = session.run_in_transaction(
unit_of_work, 'abc', some_arg='def')

self.assertEqual(committed, now)
self.assertIsNone(session._transaction)
self.assertEqual(len(called_with), 1)
txn, args, kw = called_with[0]
self.assertIsInstance(txn, Transaction)
self.assertEqual(txn.committed, committed)
self.assertEqual(return_value, 42)
self.assertEqual(args, ('abc',))
self.assertEqual(kw, {'some_arg': 'def'})

Expand Down Expand Up @@ -561,18 +561,15 @@ def test_run_in_transaction_w_abort_no_retry_metadata(self):
def unit_of_work(txn, *args, **kw):
called_with.append((txn, args, kw))
txn.insert(TABLE_NAME, COLUMNS, VALUES)
return 'answer'

committed = session.run_in_transaction(
return_value = session.run_in_transaction(
unit_of_work, 'abc', some_arg='def')

self.assertEqual(committed, now)
self.assertEqual(len(called_with), 2)
for index, (txn, args, kw) in enumerate(called_with):
self.assertIsInstance(txn, Transaction)
if index == 1:
self.assertEqual(txn.committed, committed)
else:
self.assertIsNone(txn.committed)
self.assertEqual(return_value, 'answer')
self.assertEqual(args, ('abc',))
self.assertEqual(kw, {'some_arg': 'def'})

Expand Down Expand Up @@ -621,17 +618,15 @@ def unit_of_work(txn, *args, **kw):
time_module = _FauxTimeModule()

with _Monkey(MUT, time=time_module):
committed = session.run_in_transaction(
unit_of_work, 'abc', some_arg='def')
session.run_in_transaction(unit_of_work, 'abc', some_arg='def')

self.assertEqual(time_module._slept,
RETRY_SECONDS + RETRY_NANOS / 1.0e9)
self.assertEqual(committed, now)
self.assertEqual(len(called_with), 2)
for index, (txn, args, kw) in enumerate(called_with):
self.assertIsInstance(txn, Transaction)
if index == 1:
self.assertEqual(txn.committed, committed)
self.assertEqual(txn.committed, now)
else:
self.assertIsNone(txn.committed)
self.assertEqual(args, ('abc',))
Expand Down Expand Up @@ -688,9 +683,8 @@ def unit_of_work(txn, *args, **kw):
time_module = _FauxTimeModule()

with _Monkey(MUT, time=time_module):
committed = session.run_in_transaction(unit_of_work)
session.run_in_transaction(unit_of_work)

self.assertEqual(committed, now)
self.assertEqual(time_module._slept,
RETRY_SECONDS + RETRY_NANOS / 1.0e9)
self.assertEqual(len(called_with), 2)
Expand Down

0 comments on commit fe757be

Please sign in to comment.