diff --git a/pyramid_tm/__init__.py b/pyramid_tm/__init__.py index a4eabfc..7708d95 100644 --- a/pyramid_tm/__init__.py +++ b/pyramid_tm/__init__.py @@ -76,8 +76,8 @@ def tm_tween(request): except: exc_info = sys.exc_info() try: - manager.abort() retryable = manager._retryable(*exc_info[:-1]) + manager.abort() if (number <= 0) or (not retryable): reraise(*exc_info) finally: diff --git a/pyramid_tm/tests.py b/pyramid_tm/tests.py index 6ede661..2ac6779 100644 --- a/pyramid_tm/tests.py +++ b/pyramid_tm/tests.py @@ -56,7 +56,7 @@ def setUp(self): self.request = DummyRequest() self.response = DummyResponse() self.registry = DummyRegistry() - + def _callFUT(self, handler=None, registry=None, request=None, txn=None): if handler is None: def handler(request): @@ -114,7 +114,7 @@ class Conflict(TransientError): def handler(request, count=count): raise Conflict self.assertRaises(Conflict, self._callFUT, handler=handler) - + def test_handler_isdoomed(self): txn = DummyTransaction(True) self._callFUT(txn=txn) @@ -236,13 +236,15 @@ def __init__(self, doomed=False, retryable=False): self.committed = 0 self.aborted = 0 self.retryable = retryable + self.active = False @property def manager(self): return self def _retryable(self, t, v): - return self.retryable + if self.active: + return self.retryable def get(self): return self @@ -255,12 +257,14 @@ def isDoomed(self): def begin(self): self.began+=1 + self.active = True return self def commit(self): self.committed+=1 def abort(self): + self.active = False self.aborted+=1 def note(self, value):