Skip to content

Commit

Permalink
more test cov
Browse files Browse the repository at this point in the history
  • Loading branch information
rmorshea committed Jun 19, 2024
1 parent 5dac19f commit f14ea97
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 228 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ line-length = 120
skip-string-normalization = true

[tool.ruff]
preview = true
target-version = "py37"
line-length = 120

Expand Down Expand Up @@ -129,7 +130,7 @@ ban-relative-imports = "all"

[tool.ruff.lint.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"tests/**/*" = ["PLR2004", "S101", "TID252"]
"tests/**/*" = ["PLC2701", "S101", "TID252", "RUF029", "PLC2801"]

[tool.coverage.run]
source_pkgs = ["ninject", "tests"]
Expand All @@ -147,7 +148,6 @@ exclude_lines = [
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
'\.\.\.',
"raise NotImplementedError",
]
fail_under = 100
show_missing = true
Expand Down
208 changes: 93 additions & 115 deletions src/ninject/_private.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,16 @@
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from functools import wraps
from inspect import Parameter, isasyncgenfunction, iscoroutinefunction, isfunction, isgeneratorfunction, signature
from functools import cached_property, wraps
from inspect import (
Parameter,
currentframe,
isasyncgenfunction,
iscoroutinefunction,
isfunction,
isgeneratorfunction,
signature,
)
from typing import (
Annotated,
Any,
Expand All @@ -15,7 +23,6 @@
Awaitable,
Callable,
ContextManager,
Coroutine,
Generator,
Generic,
Iterator,
Expand Down Expand Up @@ -43,6 +50,18 @@
INJECTED = cast(Any, (type("INJECTED", (), {"__repr__": lambda _: "INJECTED"}))())


def get_caller_module_name(depth: int = 1) -> str | None:
frame = currentframe()
for _ in range(depth + 1):
try:
frame = frame.f_back
except AttributeError: # nocov
break
if frame is None:
return None # nocov
return frame.f_globals.get("__name__")


def add_dependency(cls: type) -> None:
_DEPENDENCIES.add(cls)

Expand All @@ -51,18 +70,6 @@ def is_dependency(cls: type) -> bool:
return cls in _DEPENDENCIES


def get_dependency_type_info(cls: type) -> tuple[Literal["attr", "item"] | None, dict[int | str, Any]]:
if is_dependency(cls):
return None, {}
elif get_origin(cls) is tuple:
return "item", {i: t for i, t in enumerate(get_args(cls)) if is_dependency(t)}
else:
return (
"item" if issubclass(cls, dict) and TypedDict in getattr(cls, "__orig_bases__", []) else "attr",
{k: t for k, t in get_type_hints(cls).items() if is_dependency(t)},
)


def get_provider_info(provider: Callable, provides_type: Any | None = None) -> ProviderInfo:
if provides_type is None:
return _infer_provider_info(provider)
Expand Down Expand Up @@ -90,7 +97,18 @@ def get_injected_dependency_types_from_callable(func: Callable[..., Any]) -> Map
return dependency_types


class SyncUniformContext(ContextManager[R], AsyncContextManager[R]):
class _BaseUniformContext:

var: ContextVar[R]
context_provider: SyncContextProvider[R]

def __repr__(self) -> str:
wrapped = _get_wrapped(self.context_provider)
provider_str = getattr(wrapped, "__qualname__", str(wrapped))
return f"{self.__class__.__name__}({self.var.name}, {provider_str})"


class SyncUniformContext(ContextManager[R], AsyncContextManager[R], _BaseUniformContext):
def __init__(
self,
var: ContextVar[R],
Expand Down Expand Up @@ -145,13 +163,8 @@ async def __aexit__(self, etype: Any, evalue: Any, atrace: Any, /) -> None:
finally:
await async_exhaust_exits(self.dependency_contexts)

def __repr__(self) -> str:
wrapped = _get_wrapped(self.context_provider)
provider_str = getattr(wrapped, "__qualname__", str(wrapped))
return f"{self.__class__.__name__}({self.var.name}, {provider_str})"


class AsyncUniformContext(ContextManager[R], AsyncContextManager[R]):
class AsyncUniformContext(ContextManager[R], AsyncContextManager[R], _BaseUniformContext):
def __init__(
self,
var: ContextVar[R],
Expand Down Expand Up @@ -195,35 +208,58 @@ async def __aexit__(self, etype: Any, evalue: Any, atrace: Any, /) -> None:
finally:
await async_exhaust_exits(self.dependency_contexts)

def __repr__(self) -> str:
wrapped = _get_wrapped(self.context_provider)
provider_str = getattr(wrapped, "__qualname__", str(wrapped))
return f"{self.__class__.__name__}({self.var.name}, {provider_str})"


UniformContext: TypeAlias = "SyncUniformContext[R] | AsyncUniformContext[R]"
UniformContextProvider: TypeAlias = "Callable[[], UniformContext[R]]"


class _BaseProviderInfo(Generic[R]):

type: type[R]

@cached_property
def container_info(self) -> ContainerInfo | None:
if is_dependency(self.type):
return None
elif get_origin(self.type) is tuple:
container_type = "map"
dependencies = {i: t for i, t in enumerate(get_args(self.type)) if is_dependency(t)}
else:
container_type = "map" if _is_typed_dict(self.type) else "obj"
dependencies = {k: t for k, t in get_type_hints(self.type).items() if is_dependency(t)}

if not dependencies:
msg = f"Provided container {self.type} must contain at least one dependency"
raise TypeError(msg)

return ContainerInfo(kind=container_type, dependencies=dependencies)


@dataclass(kw_only=True)
class SyncProviderInfo(Generic[R]):
class SyncProviderInfo(_BaseProviderInfo[R]):
sync: Literal[True] = field(default=True, init=False)
uniform_context_type: type[SyncUniformContext[R]] = field(default=SyncUniformContext, init=False)
provides_type: type[R]
type: type[R]
context_provider: SyncContextProvider[R]


@dataclass(kw_only=True)
class AsyncProviderInfo(Generic[R]):
class AsyncProviderInfo(_BaseProviderInfo[R]):
sync: Literal[False] = field(default=False, init=False)
uniform_context_type: type[AsyncUniformContext[R]] = field(default=AsyncUniformContext, init=False)
provides_type: type[R]
type: type[R]
context_provider: AsyncContextProvider[R]


ProviderInfo = SyncProviderInfo[R] | AsyncProviderInfo[R]


@dataclass(kw_only=True)
class ContainerInfo:
kind: Literal["map", "obj"]
dependencies: dict[Any, type]


def asyncfunctioncontextmanager(func: Callable[[], Awaitable[R]]) -> AsyncContextProvider[R]:
return wraps(func)(lambda: AsyncFunctionContextManager(func))

Expand Down Expand Up @@ -320,65 +356,53 @@ def _get_context_manager_type(cls: type[ContextManager | AsyncContextManager]) -
provides_type = base_args[0]
break
else:
provides_type = get_provider_info(getattr(cls, method_name)).provides_type
provides_type = get_provider_info(getattr(cls, method_name)).type
return provides_type


def _get_provider_info(provider: Callable, provides_type: Any) -> ProviderInfo:
if isinstance(provider, type):
if issubclass(provider, ContextManager):
return SyncProviderInfo(provides_type=provides_type, context_provider=provider)
return SyncProviderInfo(type=provides_type, context_provider=provider)
elif issubclass(provider, AsyncContextManager):
return AsyncProviderInfo(provides_type=provides_type, context_provider=provider)
else:
msg = f"Unsupported provider type: {provider!r}"
raise TypeError(msg)
return AsyncProviderInfo(type=provides_type, context_provider=provider)
elif iscoroutinefunction(provider):
return AsyncProviderInfo(
provides_type=provides_type,
type=provides_type,
context_provider=asyncfunctioncontextmanager(ninject.inject(provider)),
)
elif isasyncgenfunction(provider):
return AsyncProviderInfo(
provides_type=provides_type,
type=provides_type,
context_provider=asynccontextmanager(ninject.inject(provider)),
)
elif isgeneratorfunction(provider):
return SyncProviderInfo(
provides_type=provides_type,
type=provides_type,
context_provider=contextmanager(ninject.inject(provider)),
)
elif isfunction(provider):
return SyncProviderInfo(
provides_type=provides_type,
type=provides_type,
context_provider=syncfunctioncontextmanager(ninject.inject(provider)),
)
else:
msg = f"Unsupported provider type: {provider!r}"
raise TypeError(msg)
msg = f"Unsupported provider type {provides_type!r} - expected a callable or context manager."
raise TypeError(msg)


def _infer_provider_info(provider: Any) -> ProviderInfo:
if isinstance(provider, type):
if issubclass(provider, ContextManager):
return SyncProviderInfo(
provides_type=_get_context_manager_type(provider),
context_provider=provider,
)
elif issubclass(provider, AsyncContextManager):
return AsyncProviderInfo(
provides_type=_get_context_manager_type(provider),
context_provider=provider,
)
if issubclass(provider, (ContextManager, AsyncContextManager)):
return _get_provider_info(provider, _get_context_manager_type(provider))
else:
msg = f"Unsupported provider type: {provider!r}"
msg = f"Unsupported provider type {provider!r} - expected a callable or context manager."
raise TypeError(msg)

try:
type_hints = get_type_hints(provider)
except TypeError:
msg = f"Expected a function or class, got {provider!r}"
raise TypeError(msg) from None
except TypeError as error: # nocov
msg = f"Unsupported provider type {provider!r} - expected a callable or context manager."
raise TypeError(msg) from error

return_type = _unwrap_annotated(type_hints.get("return"))
return_type_origin = get_origin(return_type)
Expand All @@ -388,71 +412,25 @@ def _infer_provider_info(provider: Any) -> ProviderInfo:
raise TypeError(msg)

if return_type_origin is None:
if iscoroutinefunction(provider):
return AsyncProviderInfo(
provides_type=return_type,
context_provider=asyncfunctioncontextmanager(ninject.inject(provider)),
)
else:
return SyncProviderInfo(
provides_type=return_type,
context_provider=syncfunctioncontextmanager(ninject.inject(provider)),
)
elif issubclass(return_type_origin, (AsyncIterator, AsyncGenerator)):
return AsyncProviderInfo(
provides_type=get_args(return_type)[0],
context_provider=asynccontextmanager(ninject.inject(provider)),
)
elif issubclass(return_type_origin, (Iterator, Generator)):
return SyncProviderInfo(
provides_type=get_args(return_type)[0],
context_provider=contextmanager(ninject.inject(provider)),
)
elif issubclass(return_type_origin, Awaitable):
return AsyncProviderInfo(
provides_type=get_args(return_type)[0],
context_provider=asyncfunctioncontextmanager(ninject.inject(provider)),
)
elif issubclass(return_type_origin, Coroutine):
coro_yield_type, _, cor_return_type = get_args(return_type)
if _unwrap_annotated(coro_yield_type) in (None, Any):
return AsyncProviderInfo(
provides_type=get_args(return_type)[0],
context_provider=asynccontextmanager(ninject.inject(provider)),
)
else:
return AsyncProviderInfo(
provides_type=cor_return_type,
context_provider=asyncfunctioncontextmanager(ninject.inject(provider)),
)
elif issubclass(return_type_origin, ContextManager):
return SyncProviderInfo(
provides_type=get_provider_info(return_type).provides_type,
context_provider=syncfunctioncontextmanager(ninject.inject(provider)),
)
elif issubclass(return_type_origin, AsyncContextManager):
return AsyncProviderInfo(
provides_type=get_provider_info(return_type).provides_type,
context_provider=asyncfunctioncontextmanager(ninject.inject(provider)),
)
elif isfunction(provider):
return SyncProviderInfo(
provides_type=return_type,
context_provider=syncfunctioncontextmanager(ninject.inject(provider)),
)
return _get_provider_info(provider, return_type)
elif issubclass(return_type_origin, (AsyncIterator, AsyncGenerator, Iterator, Generator)):
return _get_provider_info(provider, get_args(return_type)[0])
elif issubclass(return_type_origin, (ContextManager, AsyncContextManager)):
return _get_provider_info(provider, get_provider_info(return_type).type)
else:
msg = f"Unsupported provider type: {provider!r}"
raise TypeError(msg)
return _get_provider_info(provider, return_type)


def _unwrap_annotated(anno: Any) -> Any:
if get_origin(anno) is Annotated:
return get_args(anno)[0]
return anno
return get_args(anno)[0] if get_origin(anno) is Annotated else anno


_DEPENDENCY_VARS_BY_TYPE: WeakKeyDictionary[type, ContextVar] = WeakKeyDictionary()
_CONTEXT_PROVIDER_VARS_BY_TYPE: WeakKeyDictionary[type, ContextVar[UniformContextProvider]] = WeakKeyDictionary()


_DEPENDENCIES: set[type] = set()


def _is_typed_dict(t: type) -> bool:
return isinstance(t, type) and issubclass(t, dict) and TypedDict in getattr(t, "__orig_bases__", [])
Loading

0 comments on commit f14ea97

Please sign in to comment.