Skip to content

Commit

Permalink
Revert "misc cleanup"
Browse files Browse the repository at this point in the history
This reverts commit 7f48a45.
  • Loading branch information
rmorshea committed Jun 13, 2024
1 parent 7f48a45 commit 33bad5c
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 266 deletions.
228 changes: 218 additions & 10 deletions src/ninject/_private/_inspect.py → src/ninject/_private.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import sys
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from functools import wraps
from inspect import Parameter, isasyncgenfunction, iscoroutinefunction, isfunction, isgeneratorfunction, signature
from typing import (
Annotated,
Expand All @@ -19,22 +22,18 @@
Literal,
Mapping,
ParamSpec,
Sequence,
TypeAlias,
TypedDict,
TypeVar,
cast,
get_args,
get_origin,
get_type_hints,
)
from weakref import WeakKeyDictionary

import ninject
from ninject._private._contexts import (
AsyncUniformContext,
SyncUniformContext,
asyncfunctioncontextmanager,
syncfunctioncontextmanager,
)
from ninject._private._global import has_dependency
from ninject.types import AsyncContextProvider, SyncContextProvider

P = ParamSpec("P")
Expand All @@ -44,15 +43,23 @@
INJECTED = cast(Any, (type("INJECTED", (), {"__repr__": lambda _: "INJECTED"}))())


def add_dependency(cls: type) -> None:
_DEPENDENCIES.add(cls)


def is_dependency(cls: type) -> bool:
return cls in _DEPENDENCIES


def get_dependency_type_info(cls: type) -> tuple[Literal["attr", "item"] | None, dict[int | str, Any]]:
if has_dependency(cls):
if is_dependency(cls):
return None, {}
elif get_origin(cls) is tuple:
return "item", {i: t for i, t in enumerate(get_args(cls)) if has_dependency(t)}
return "item", {i: t for i, t in enumerate(get_args(cls)) if is_dependency(t)}
else:
return (
"item" if issubclass(cls, dict) and TypedDict in getattr(cls, "__orig_bases__", []) else "attr",
{k: t for k, t in get_type_hints(cls).items() if has_dependency(t)},
{k: t for k, t in get_type_hints(cls).items() if is_dependency(t)},
)


Expand Down Expand Up @@ -83,6 +90,121 @@ def get_injected_dependency_types_from_callable(func: Callable[..., Any]) -> Map
return dependency_types


class SyncUniformContext(ContextManager[R], AsyncContextManager[R]):
def __init__(
self,
var: ContextVar[R],
context_provider: SyncContextProvider[R],
dependencies: Sequence[type],
):
self.var = var
self.context_provider = context_provider
self.token = None
self.dependencies = dependencies
self.dependency_contexts: list[UniformContext] = []

def __enter__(self) -> R:
try:
return self.var.get()
except LookupError:
for cls in self.dependencies:
(dependency_context := get_context_provider(cls)()).__enter__()
self.dependency_contexts.append(dependency_context)
self.context = context = self.context_provider()
self.token = self.var.set(context.__enter__())
return self.var.get()

def __exit__(self, etype: Any, evalue: Any, atrace: Any, /) -> None:
if self.token is not None:
try:
self.var.reset(self.token)
finally:
try:
self.context.__exit__(etype, evalue, atrace)
finally:
exhaust_exits(self.dependency_contexts)

async def __aenter__(self) -> R:
try:
return self.var.get()
except LookupError:
for var in self.dependencies:
await (dependency_context := get_context_provider(var)()).__aenter__()
self.dependency_contexts.append(dependency_context)
self.context = context = self.context_provider()
self.token = self.var.set(context.__enter__())
return self.var.get()

async def __aexit__(self, etype: Any, evalue: Any, atrace: Any, /) -> None:
if self.token is not None:
try:
self.var.reset(self.token)
finally:
try:
self.context.__exit__(etype, evalue, atrace)
finally:
await async_exhaust_exits(self.dependency_contexts)

def __repr__(self) -> str:
wrapped = _get_wrapped(self.context_provider)
provider_str = getattr(wrapped, "__qualname__", str(wrapped))
return f"{self.__class__.__name__}({self.var.name}, {provider_str})"


class AsyncUniformContext(ContextManager[R], AsyncContextManager[R]):
def __init__(
self,
var: ContextVar[R],
context_provider: AsyncContextProvider[R],
dependencies: Sequence[type],
):
self.var = var
self.context_provider = context_provider
self.token = None
self.dependencies = dependencies
self.dependency_contexts: list[UniformContext[Any]] = []

def __enter__(self) -> R:
try:
return self.var.get()
except LookupError:
msg = f"Cannot use an async provider {self.var.name} in a sync context"
raise RuntimeError(msg) from None

def __exit__(self, etype: Any, evalue: Any, atrace: Any, /) -> None:
pass

async def __aenter__(self) -> R:
try:
return self.var.get()
except LookupError:
for cls in self.dependencies:
await (dependency_context := get_context_provider(cls)()).__aenter__()
self.dependency_contexts.append(dependency_context)
self.context = context = self.context_provider()
self.token = self.var.set(await context.__aenter__())
return self.var.get()

async def __aexit__(self, etype: Any, evalue: Any, atrace: Any, /) -> None:
if self.token is not None:
try:
self.var.reset(self.token)
finally:
try:
await self.context.__aexit__(etype, evalue, atrace)
finally:
await async_exhaust_exits(self.dependency_contexts)

def __repr__(self) -> str:
wrapped = _get_wrapped(self.context_provider)
provider_str = getattr(wrapped, "__qualname__", str(wrapped))
return f"{self.__class__.__name__}({self.var.name}, {provider_str})"


UniformContext: TypeAlias = "SyncUniformContext[R] | AsyncUniformContext[R]"
UniformContextProvider: TypeAlias = "Callable[[], UniformContext[R]]"


@dataclass(kw_only=True)
class SyncProviderInfo(Generic[R]):
sync: Literal[True] = field(default=True, init=False)
Expand All @@ -102,6 +224,85 @@ class AsyncProviderInfo(Generic[R]):
ProviderInfo = SyncProviderInfo[R] | AsyncProviderInfo[R]


def asyncfunctioncontextmanager(func: Callable[[], Awaitable[R]]) -> AsyncContextProvider[R]:
return wraps(func)(lambda: AsyncFunctionContextManager(func))


def syncfunctioncontextmanager(func: Callable[[], R]) -> SyncContextProvider[R]:
return wraps(func)(lambda: SyncFunctionContextManager(func))


class AsyncFunctionContextManager(AsyncContextManager[R]):
def __init__(self, func: Callable[[], Awaitable[R]]) -> None:
self.func = func

async def __aenter__(self) -> R:
return await self.func()

async def __aexit__(self, etype: Any, evalue: Any, atrace: Any, /) -> None:
pass


class SyncFunctionContextManager(ContextManager[R]):
def __init__(self, func: Callable[[], R]) -> None:
self.func = func

def __enter__(self) -> R:
return self.func()

def __exit__(self, etype: Any, evalue: Any, atrace: Any, /) -> None:
pass


def exhaust_exits(ctxts: Sequence[ContextManager]) -> None:
if not ctxts:
return
try:
c, *ctxts = ctxts
c.__exit__(*sys.exc_info())
except Exception:
exhaust_exits(ctxts)
raise
else:
exhaust_exits(ctxts)


async def async_exhaust_exits(ctxts: Sequence[AsyncContextManager[Any]]) -> None:
if not ctxts:
return
try:
c, *ctxts = ctxts
await c.__aexit__(*sys.exc_info())
except Exception:
await async_exhaust_exits(ctxts)
raise
else:
await async_exhaust_exits(ctxts)


def set_context_provider(cls: type[R], provider: UniformContextProvider[R]) -> Callable[[], None]:
if not (context_provider_var := _CONTEXT_PROVIDER_VARS_BY_TYPE.get(cls)):
context_provider_var = _CONTEXT_PROVIDER_VARS_BY_TYPE[cls] = ContextVar(f"{cls.__name__}_provider")

token = context_provider_var.set(provider)
return lambda: context_provider_var.reset(token)


def get_context_provider(cls: type[R]) -> UniformContextProvider[R]:
try:
context_provider_var = _CONTEXT_PROVIDER_VARS_BY_TYPE[cls]
except KeyError:
msg = f"No provider declared for {cls}"
raise RuntimeError(msg) from None
return context_provider_var.get()


def setdefault_context_var(cls: type[R]) -> ContextVar:
if not (context_var := _DEPENDENCY_VARS_BY_TYPE.get(cls)):
context_var = _DEPENDENCY_VARS_BY_TYPE[cls] = ContextVar(f"{cls.__name__}_dependency")
return context_var


def _get_wrapped(func: Callable[P, R]) -> Callable[P, R]:
while maybe_func := getattr(func, "__wrapped__", None):
func = maybe_func
Expand Down Expand Up @@ -248,3 +449,10 @@ def _unwrap_annotated(anno: Any) -> Any:
if get_origin(anno) is Annotated:
return get_args(anno)[0]
return anno


_DEPENDENCY_VARS_BY_TYPE: WeakKeyDictionary[type, ContextVar] = WeakKeyDictionary()
_CONTEXT_PROVIDER_VARS_BY_TYPE: WeakKeyDictionary[type, ContextVar[UniformContextProvider]] = WeakKeyDictionary()


_DEPENDENCIES: set[type] = set()
Empty file removed src/ninject/_private/__init__.py
Empty file.
Loading

0 comments on commit 33bad5c

Please sign in to comment.