Skip to content

Commit

Permalink
Merge branch 'separate-locking-errors'
Browse files Browse the repository at this point in the history
  • Loading branch information
alisaifee committed Apr 19, 2024
2 parents 224b926 + 9bb6009 commit 0179c65
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 9 deletions.
13 changes: 13 additions & 0 deletions coredis/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ class LockError(RedisError, ValueError):
# This was originally chosen to behave like threading.Lock.


class LockAcquisitionError(LockError):
"""Errors acquiring a lock"""


class LockReleaseError(LockError):
"""Errors releasing a lock"""


class LockExtensionError(LockError):
"""Errors extending a lock"""


class RedisClusterException(Exception):
"""Base exception for the RedisCluster client"""

Expand Down Expand Up @@ -316,6 +328,7 @@ class UnknownCommandError(ResponseError):

def __init__(self, message: str) -> None:
command_match = self.ERROR_REGEX.findall(message)

if command_match:
self.command = command_match.pop()
super().__init__(message)
Expand Down
40 changes: 31 additions & 9 deletions coredis/recipes/locks/lua_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@

from coredis.client import Redis, RedisCluster
from coredis.commands import Script
from coredis.exceptions import LockError, ReplicationError
from coredis.exceptions import (
LockAcquisitionError,
LockError,
LockExtensionError,
LockReleaseError,
ReplicationError,
)
from coredis.tokens import PureToken
from coredis.typing import AnyStr, Generic, KeyT, Optional, StringT, Type, Union

Expand Down Expand Up @@ -121,6 +127,7 @@ def __init__(
self.blocking = blocking
self.blocking_timeout = blocking_timeout
self.local = contextvars.ContextVar[Optional[StringT]]("token", default=None)

if self.timeout and self.sleep > self.timeout:
raise LockError("'sleep' must be less than 'timeout'")

Expand All @@ -129,7 +136,7 @@ async def __aenter__(
) -> LuaLock[AnyStr]:
if await self.acquire():
return self
raise LockError("Could not acquire lock")
raise LockAcquisitionError("Could not acquire lock")

async def __aexit__(
self,
Expand All @@ -155,14 +162,19 @@ async def acquire(
blocking = self.blocking
blocking_timeout = self.blocking_timeout
stop_trying_at = None

if blocking_timeout is not None:
stop_trying_at = time.time() + blocking_timeout

while True:
if await self.__acquire(token, stop_trying_at):
self.local.set(token)

return True

if not blocking:
return False

if stop_trying_at is not None and time.time() > stop_trying_at:
return False
await asyncio.sleep(self.sleep)
Expand All @@ -171,11 +183,12 @@ async def release(self) -> None:
"""
Releases the already acquired lock
:raises: :exc:`~coredis.exceptions.LockError`
:raises: :exc:`~coredis.exceptions.LockReleaseError`
"""
expected_token = self.local.get()

if expected_token is None:
raise LockError("Cannot release an unlocked lock")
raise LockReleaseError("Cannot release an unlocked lock")
self.local.set(None)
await self.__release(expected_token)

Expand All @@ -186,12 +199,15 @@ async def extend(self, additional_time: float) -> bool:
:param additional_time: can be specified as an integer or a float, both
representing the number of seconds to add.
:raises: :exc:`~coredis.exceptions.LockError`
:raises: :exc:`~coredis.exceptions.LockExtensionError`
"""

if self.local.get() is None:
raise LockError("Cannot extend an unlocked lock")
raise LockExtensionError("Cannot extend an unlocked lock")

if self.timeout is None:
raise LockError("Cannot extend a lock with no timeout")
raise LockExtensionError("Cannot extend a lock with no timeout")

return await self.__extend(additional_time)

@property
Expand All @@ -200,8 +216,10 @@ def replication_factor(self) -> int:
Number of replicas the lock needs to replicate to, to be
considered acquired.
"""

if isinstance(self.client, RedisCluster):
return math.ceil(self.client.num_replicas_per_shard / 2)

return 0

async def __acquire(self, token: StringT, stop_trying_at: Optional[float]) -> bool:
Expand All @@ -228,6 +246,7 @@ async def __acquire(self, token: StringT, stop_trying_at: Optional[float]) -> bo
category=RuntimeWarning,
)
await self.client.delete([self.name])

return False
else:
return await self.client.set(
Expand All @@ -245,12 +264,14 @@ async def __release(self, expected_token: StringT) -> None:
expected_token,
)
):
raise LockError("Cannot release a lock that's no longer owned")
raise LockReleaseError("Cannot release a lock that's no longer owned")

async def __extend(self, additional_time: float) -> bool:
additional_time = int(additional_time * 1000)

if additional_time < 0:
return True

if not bool(
await self.lua_extend(
self.client,
Expand All @@ -259,5 +280,6 @@ async def __extend(self, additional_time: float) -> bool:
additional_time,
)
):
raise LockError("Cannot extend a lock that's no longer owned")
raise LockExtensionError("Cannot extend a lock that's no longer owned")

return True
6 changes: 6 additions & 0 deletions docs/source/api/errors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ General Exceptions
:no-inherited-members:
.. autoexception:: coredis.exceptions.LockError
:no-inherited-members:
.. autoexception:: coredis.exceptions.LockAcquisitionError
:no-inherited-members:
.. autoexception:: coredis.exceptions.LockReleaseError
:no-inherited-members:
.. autoexception:: coredis.exceptions.LockExtensionError
:no-inherited-members:
.. autoexception:: coredis.exceptions.NoKeyError
:no-inherited-members:
.. autoexception:: coredis.exceptions.PersistenceError
Expand Down

0 comments on commit 0179c65

Please sign in to comment.