Skip to content

Commit

Permalink
fix type hints + other minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rmorshea committed Dec 7, 2024
1 parent a5aa7f8 commit 777fb18
Show file tree
Hide file tree
Showing 12 changed files with 147 additions and 73 deletions.
4 changes: 2 additions & 2 deletions docs/src/recipes.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def transaction_provider() -> Iterator[Transaction]:
yield session
```

## NameError in Type Hints
## Type Hint NameError

!!! note

Expand Down Expand Up @@ -271,7 +271,7 @@ def query_database(*, conn: Connection = required) -> None: ...
Type checkers should still be able to check the return type using the `provides`
argument so it may not be necessary to annotate it in the function signature.

## Pytest-Asyncio Issue
## Pytest-Asyncio

Under the hood, PyBooster uses `contextvars` to manage the state of providers and
injectors. If you use `pytest-asyncio` to write async tests it's likely you'll run into
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ dependencies = [
"paramorator>=1.0.2,<2",
"rustworkx>=0.15,<0.16",
]
[project.optional-dependencies]
sqlalchemy = ["sqlalchemy[asyncio]>=2,<3"]

[project.urls]
Source = "https://github.com/rmorshea/pybooster"
Expand Down Expand Up @@ -85,6 +87,7 @@ line-length = 100

[tool.ruff.format]
docstring-code-format = true
docstring-code-line-length = 80
quote-style = "double"
indent-style = "space"

Expand Down Expand Up @@ -165,6 +168,7 @@ ban-relative-imports = "all"
"TRY003", # Avoid specifying long messages outside the exception class
]


[tool.yamlfix]
line_length = 100

Expand Down
19 changes: 10 additions & 9 deletions src/pybooster/_private/_injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pybooster._private._solution import Solution
from pybooster._private._utils import start_future
from pybooster._private._utils import undefined
from pybooster.types import Hint

if TYPE_CHECKING:
from pybooster._private._provider import AsyncProviderInfo
Expand Down Expand Up @@ -94,11 +95,11 @@ async def async_inject_into_params(
def _inject_params_into_current_values(
param_vals: dict[str, Any],
param_deps: HintMap,
current_values: dict[type, Any],
current_values: dict[Hint, Any],
solution: Solution,
) -> None:
to_update: dict[type, Any] = {}
to_invalidate: set[type] = set()
to_update: dict[Hint, Any] = {}
to_invalidate: set[Hint] = set()
for name in param_deps.keys() & param_vals:
cls = param_deps[name]
if current_values.get(cls, undefined) is not (new_val := param_vals[name]):
Expand All @@ -115,7 +116,7 @@ def _inject_params_into_current_values(
def _inject_current_values_into_params(
param_vals: dict[str, Any],
missing_params: HintDict,
current_values: Mapping[type, Any],
current_values: Mapping[Hint, Any],
) -> None:
for name, cls in tuple(missing_params.items()):
if cls in current_values:
Expand All @@ -127,7 +128,7 @@ def _sync_inject_from_provider_values(
stack: FastStack,
param_vals: dict[str, Any],
missing_params: HintDict,
current_values: dict[type, Any],
current_values: dict[Hint, Any],
solution: Solution[SyncProviderInfo],
) -> None:
for exe_group in solution.execution_order_for(missing_params.values(), current_values):
Expand All @@ -140,7 +141,7 @@ async def _async_inject_from_provider_values(
stack: AsyncFastStack,
param_vals: dict[str, Any],
missing_params: HintDict,
current_values: dict[type, Any],
current_values: dict[Hint, Any],
solution: Solution[ProviderInfo],
) -> None:
for exe_group in solution.execution_order_for(missing_params.values(), current_values):
Expand Down Expand Up @@ -189,7 +190,7 @@ async def _async_inject_from_provider_values(
def _sync_enter_provider(
stack: FastStack | AsyncFastStack,
info: SyncProviderInfo,
current_values: Mapping[type, Any],
current_values: Mapping[Hint, Any],
) -> Any:
kwargs = {n: current_values[c] for n, c in info["required_parameters"].items()}
return info["getter"](stack.enter_context(info["producer"](**kwargs)))
Expand All @@ -198,10 +199,10 @@ def _sync_enter_provider(
async def _async_enter_provider(
stack: AsyncFastStack,
info: AsyncProviderInfo,
current_values: Mapping[type, Any],
current_values: Mapping[Hint, Any],
) -> Any:
kwargs = {n: current_values[c] for n, c in info["required_parameters"].items()}
return info["getter"](await stack.enter_async_context(info["producer"](**kwargs)))


_CURRENT_VALUES = ContextVar[Mapping[type, Any]]("CURRENT_VALUES", default={})
_CURRENT_VALUES = ContextVar[Mapping[Hint, Any]]("CURRENT_VALUES", default={})
40 changes: 20 additions & 20 deletions src/pybooster/_private/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from pybooster._private._utils import is_type
from pybooster.types import AsyncContextManagerCallable
from pybooster.types import ContextManagerCallable
from pybooster.types import Hint
from pybooster.types import HintMap
from pybooster.types import InferHint

if TYPE_CHECKING:
from collections.abc import Callable
Expand All @@ -35,27 +37,25 @@
class SyncProviderInfo(TypedDict):
is_sync: Literal[True]
producer: ContextManagerCallable[[], Any]
provides: type
provides: Hint
required_parameters: HintMap
getter: Callable[[Any], Any]


class AsyncProviderInfo(TypedDict):
is_sync: Literal[False]
producer: AsyncContextManagerCallable[[], Any]
provides: type
provides: Hint
required_parameters: HintMap
getter: Callable[[Any], Any]


ProviderInfo = SyncProviderInfo | AsyncProviderInfo


def get_provides_type(
provides: type[R] | Callable[..., type[R]], *args: Any, **kwargs: Any
) -> type[R]:
def get_provides_type(provides: Hint | Callable[..., Hint], *args: Any, **kwargs: Any) -> Hint:
if is_type(provides):
return cast("type[R]", provides)
return provides
elif callable(provides):
return provides(*args, **kwargs)
else:
Expand All @@ -66,31 +66,31 @@ def get_provides_type(

@overload
def get_provider_info(
producer: ContextManagerCallable[[], R],
provides: type[R] | Callable[[], type[R]],
producer: ContextManagerCallable[[], Any],
provides: Hint | InferHint,
required_params: HintMap,
*,
is_sync: Literal[True],
) -> Mapping[type, SyncProviderInfo]: ...
) -> Mapping[Hint, SyncProviderInfo]: ...


@overload
def get_provider_info(
producer: AsyncContextManagerCallable[[], R],
provides: type[R] | Callable[[], type[R]],
producer: AsyncContextManagerCallable[[], Any],
provides: Hint | InferHint,
required_params: HintMap,
*,
is_sync: Literal[False],
) -> Mapping[type, AsyncProviderInfo]: ...
) -> Mapping[Hint, AsyncProviderInfo]: ...


def get_provider_info(
producer: ContextManagerCallable[[], R] | AsyncContextManagerCallable[[], R],
provides: type[R] | Callable[[], type[R]],
producer: ContextManagerCallable[[], Any] | AsyncContextManagerCallable[[], Any],
provides: Hint | InferHint,
required_params: HintMap,
*,
is_sync: bool,
) -> Mapping[type, ProviderInfo]:
) -> Mapping[Hint, ProviderInfo]:
provides_type = get_provides_type(provides)
if get_origin(provides_type) is tuple:
return _get_tuple_provider_infos(producer, provides_type, required_params, is_sync=is_sync)
Expand All @@ -99,12 +99,12 @@ def get_provider_info(


def _get_tuple_provider_infos(
producer: AnyContextManagerCallable[R],
provides: type[R],
producer: AnyContextManagerCallable,
provides: Hint,
required_parameters: HintMap,
*,
is_sync: bool,
) -> dict[type, ProviderInfo]:
) -> dict[Hint, ProviderInfo]:
infos_list = (
_get_scalar_provider_infos(producer, provides, required_parameters, is_sync=is_sync),
*(
Expand All @@ -123,12 +123,12 @@ def _get_tuple_provider_infos(

def _get_scalar_provider_infos(
producer: AnyContextManagerCallable[R],
provides: type[R],
provides: Hint,
required_parameters: HintMap,
*,
is_sync: bool,
getter: Callable[[R], Any] = lambda x: x,
) -> dict[type, ProviderInfo]:
) -> dict[Hint, ProviderInfo]:
if get_origin(provides) is Union:
msg = f"Cannot provide a union type {provides}."
raise TypeError(msg)
Expand Down
27 changes: 14 additions & 13 deletions src/pybooster/_private/_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pybooster._private._provider import ProviderInfo
from pybooster._private._provider import SyncProviderInfo
from pybooster._private._utils import frozenclass
from pybooster.types import Hint
from pybooster.types import InjectionError
from pybooster.types import SolutionError

Expand All @@ -28,13 +29,13 @@

P = TypeVar("P", bound=ProviderInfo)

DependencySet = Set[type]
DependencyMap = Mapping[type, DependencySet]
DependencySet = Set[Hint]
DependencyMap = Mapping[Hint, DependencySet]


def set_solutions(
sync_infos: Mapping[type, SyncProviderInfo],
async_infos: Mapping[type, AsyncProviderInfo],
sync_infos: Mapping[Hint, SyncProviderInfo],
async_infos: Mapping[Hint, AsyncProviderInfo],
) -> Callable[[], None]:
full_infos = {**sync_infos, **async_infos}

Expand All @@ -48,7 +49,7 @@ def reset() -> None:
return reset


def _set_solution(var: ContextVar[Solution[P]], infos: Mapping[type, P]) -> Token[Solution[P]]:
def _set_solution(var: ContextVar[Solution[P]], infos: Mapping[Hint, P]) -> Token[Solution[P]]:
dep_map = {cls: set(info["required_parameters"].values()) for cls, info in infos.items()}
return var.set(Solution.from_infos_and_dependency_map(infos, dep_map))

Expand All @@ -57,25 +58,25 @@ def _set_solution(var: ContextVar[Solution[P]], infos: Mapping[type, P]) -> Toke
class Solution(Generic[P]):
"""A solution to the dependency graph."""

type_by_index: Mapping[int, type]
type_by_index: Mapping[int, Hint]
"""Mapping graph index to types."""
index_ordering: Sequence[Set[int]]
r"""Topologically sorted generations of type IDs."""
index_by_type: Mapping[type, int]
index_by_type: Mapping[Hint, int]
"""Mapping types to graph index."""
index_graph: PyDiGraph
"""A directed graph of type IDs."""
infos_by_index: Mapping[int, P]
"""Mapping graph index to provider infos."""
infos_by_type: Mapping[type, P]
infos_by_type: Mapping[Hint, P]
"""Mapping types to provider infos."""

@classmethod
def from_infos_and_dependency_map(
cls, infos_by_type: Mapping[type, P], deps_by_type: DependencyMap
cls, infos_by_type: Mapping[Hint, P], deps_by_type: DependencyMap
) -> Self:
type_by_index: dict[int, type] = {}
index_by_type: dict[type, int] = {}
type_by_index: dict[int, Hint] = {}
index_by_type: dict[Hint, int] = {}

index_graph = PyDiGraph()
for tp in deps_by_type:
Expand Down Expand Up @@ -104,14 +105,14 @@ def from_infos_and_dependency_map(
infos_by_type=infos_by_type,
)

def descendant_types(self, cls: type) -> Set[type]:
def descendant_types(self, cls: Hint) -> Set[Hint]:
type_by_index = self.type_by_index # avoid extra attribute accesses
if cls not in self.index_by_type:
return set()
return {type_by_index[i] for i in descendants(self.index_graph, self.index_by_type[cls])}

def execution_order_for(
self, include_types: Collection[type], exclude_types: Collection[type]
self, include_types: Collection[Hint], exclude_types: Collection[Hint]
) -> Sequence[Sequence[P]]:
index_by_type = self.index_by_type # avoid extra attribute accesses
try:
Expand Down
21 changes: 13 additions & 8 deletions src/pybooster/_private/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from anyio.abc import TaskGroup

from pybooster.types import Hint
from pybooster.types import HintMap
from pybooster.types import HintSeq

Expand Down Expand Up @@ -88,12 +89,16 @@ def get_required_parameters(
) -> HintMap:
match dependencies:
case None:
return _get_required_parameters(func)
return _get_required_parameter_types(func)
case Mapping():
params = _get_required_sig_parameters(func)
if (lpar := len(params)) > (ldep := len(dependencies)):
msg = f"Could not match {ldep} dependencies to {lpar} required parameters."
raise TypeError(msg)
return dependencies
case Sequence():
params = _get_required_sig_parameters(func)
if (lpar := len(params)) != (ldep := len(dependencies)):
if (lpar := len(params)) > (ldep := len(dependencies)):
msg = f"Could not match {ldep} dependencies to {lpar} required parameters."
raise TypeError(msg)
return dict(zip((p.name for p in params), dependencies, strict=False))
Expand All @@ -102,8 +107,8 @@ def get_required_parameters(
raise TypeError(msg)


def _get_required_parameters(func: Callable[P, R]) -> HintMap:
required_params: dict[str, type] = {}
def _get_required_parameter_types(func: Callable[P, R]) -> HintMap:
required_params: dict[str, Hint] = {}
hints = get_type_hints(func, include_extras=True)
for param in _get_required_sig_parameters(func):
check_is_required_type(hint := hints[param.name])
Expand Down Expand Up @@ -161,7 +166,7 @@ def is_builtin_type(anno: RawAnnotation) -> bool:


class DependencyInfo(TypedDict):
type: type
type: Hint
new: bool


Expand All @@ -182,14 +187,14 @@ def _recurse_type(cls: Any) -> Iterator[Any]:
yield from _recurse_type(arg)


def get_callable_return_type(func: Callable) -> type:
def get_callable_return_type(func: Callable) -> Hint:
anno = get_type_hints(func, include_extras=True).get("return", Any)
raw_anno = get_raw_annotation(anno)
check_is_not_builtin_type(raw_anno)
return anno


def get_coroutine_return_type(func: Callable) -> type:
def get_coroutine_return_type(func: Callable) -> Hint:
return_type = get_callable_return_type(func)
if get_origin(return_type) is Coroutine:
try:
Expand All @@ -201,7 +206,7 @@ def get_coroutine_return_type(func: Callable) -> type:
return return_type


def get_iterator_yield_type(func: Callable, *, sync: bool) -> type:
def get_iterator_yield_type(func: Callable, *, sync: bool) -> Hint:
return_type = get_callable_return_type(func)
if sync:
if get_origin(return_type) is not Iterator:
Expand Down
Loading

0 comments on commit 777fb18

Please sign in to comment.