Skip to content

Commit

Permalink
add AsyncSingleton provider (#129)
Browse files Browse the repository at this point in the history
* add AsyncSingleton provider

* fix AsyncResource test

* fix typing issues

* use `is` operator in async provider test

* remove redundant `_threading_lock`
  • Loading branch information
zerlok authored Nov 21, 2024
1 parent 2fc167f commit 8a90e8e
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 2 deletions.
48 changes: 47 additions & 1 deletion tests/providers/test_singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from that_depends import BaseContainer, providers


@dataclasses.dataclass(kw_only=True, slots=True)
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
class SingletonFactory:
dep1: str

Expand All @@ -21,9 +21,15 @@ class Settings(pydantic.BaseModel):
other_setting: str = "other_value"


async def create_async_obj(value: str) -> SingletonFactory:
await asyncio.sleep(0.001)
return SingletonFactory(dep1=f"async {value}")


class DIContainer(BaseContainer):
settings: Settings = providers.Singleton(Settings).cast
singleton = providers.Singleton(SingletonFactory, dep1=settings.some_setting)
singleton_async = providers.AsyncSingleton(create_async_obj, value=settings.some_setting)


async def test_singleton_provider() -> None:
Expand All @@ -39,6 +45,46 @@ async def test_singleton_provider() -> None:
await DIContainer.tear_down()


async def test_singleton_async_provider() -> None:
singleton1 = await DIContainer.singleton_async()
singleton2 = await DIContainer.singleton_async()
singleton3 = await DIContainer.singleton_async.async_resolve()
await DIContainer.singleton_async.tear_down()
singleton4 = await DIContainer.singleton_async.async_resolve()

assert singleton1 is singleton2 is singleton3
assert singleton4 is not singleton1

await DIContainer.tear_down()


async def test_singleton_async_provider_override() -> None:
singleton_async = providers.AsyncSingleton(create_async_obj, "foo")
singleton_async.override(SingletonFactory(dep1="bar"))

result = await singleton_async.async_resolve()
assert result == SingletonFactory(dep1="bar")


async def test_singleton_async_provider_concurrent() -> None:
singleton_async = providers.AsyncSingleton(create_async_obj, "foo")

results = await asyncio.gather(
singleton_async(),
singleton_async(),
singleton_async(),
singleton_async(),
singleton_async(),
)

assert all(val is results[0] for val in results)


async def test_singleton_async_provider_sync_resolve() -> None:
with pytest.raises(RuntimeError, match="AsyncSingleton cannot be resolved in an sync context."):
DIContainer.singleton_async.sync_resolve()


async def test_singleton_attr_getter() -> None:
singleton1 = await DIContainer.singleton()

Expand Down
3 changes: 2 additions & 1 deletion that_depends/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from that_depends.providers.object import Object
from that_depends.providers.resources import AsyncResource, Resource
from that_depends.providers.selector import Selector
from that_depends.providers.singleton import Singleton
from that_depends.providers.singleton import AsyncSingleton, Singleton


__all__ = [
Expand All @@ -28,5 +28,6 @@
"Resource",
"Selector",
"Singleton",
"AsyncSingleton",
"container_context",
]
41 changes: 41 additions & 0 deletions that_depends/providers/singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,44 @@ def sync_resolve(self) -> T_co:
async def tear_down(self) -> None:
if self._instance is not None:
self._instance = None


class AsyncSingleton(AbstractProvider[T_co]):
__slots__ = "_factory", "_args", "_kwargs", "_override", "_instance", "_asyncio_lock"

def __init__(self, factory: typing.Callable[P, typing.Awaitable[T_co]], *args: P.args, **kwargs: P.kwargs) -> None:
super().__init__()
self._factory: typing.Final[typing.Callable[P, typing.Awaitable[T_co]]] = factory
self._args: typing.Final[P.args] = args
self._kwargs: typing.Final[P.kwargs] = kwargs
self._instance: T_co | None = None
self._asyncio_lock: typing.Final = asyncio.Lock()

async def async_resolve(self) -> T_co:
if self._override is not None:
return typing.cast(T_co, self._override)

if self._instance is not None:
return self._instance

# lock to prevent resolving several times
async with self._asyncio_lock:
if self._instance is not None:
return self._instance

self._instance = await self._factory(
*[await x.async_resolve() if isinstance(x, AbstractProvider) else x for x in self._args],
**{
k: await v.async_resolve() if isinstance(v, AbstractProvider) else v
for k, v in self._kwargs.items()
},
)
return self._instance

def sync_resolve(self) -> typing.NoReturn:
msg = "AsyncSingleton cannot be resolved in an sync context."
raise RuntimeError(msg)

async def tear_down(self) -> None:
if self._instance is not None:
self._instance = None

0 comments on commit 8a90e8e

Please sign in to comment.