Skip to content

Commit

Permalink
Refactor db_injector decorator (#16390)
Browse files Browse the repository at this point in the history
  • Loading branch information
mjpieters authored Dec 15, 2024
1 parent f4a26c5 commit b42d3a1
Show file tree
Hide file tree
Showing 8 changed files with 440 additions and 145 deletions.
300 changes: 248 additions & 52 deletions src/prefect/server/database/dependencies.py

Large diffs are not rendered by default.

13 changes: 7 additions & 6 deletions src/prefect/server/events/ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import pendulum
import sqlalchemy as sa
from cachetools import TTLCache
from typing_extensions import Self

from prefect.logging import get_logger
from prefect.server.database.dependencies import db_injector
Expand Down Expand Up @@ -75,7 +74,9 @@ async def record_event_as_seen(self, event: ReceivedEvent) -> None:
self._seen_events[self.scope][event.id] = True

@db_injector
async def record_follower(db: PrefectDBInterface, self: Self, event: ReceivedEvent):
async def record_follower(
self, db: PrefectDBInterface, event: ReceivedEvent
) -> None:
"""Remember that this event is waiting on another event to arrive"""
assert event.follows

Expand All @@ -92,8 +93,8 @@ async def record_follower(db: PrefectDBInterface, self: Self, event: ReceivedEve

@db_injector
async def forget_follower(
db: PrefectDBInterface, self: Self, follower: ReceivedEvent
):
self, db: PrefectDBInterface, follower: ReceivedEvent
) -> None:
"""Forget that this event is waiting on another event to arrive"""
assert follower.follows

Expand All @@ -107,7 +108,7 @@ async def forget_follower(

@db_injector
async def get_followers(
db: PrefectDBInterface, self: Self, leader: ReceivedEvent
self, db: PrefectDBInterface, leader: ReceivedEvent
) -> List[ReceivedEvent]:
"""Returns events that were waiting on this leader event to arrive"""
async with db.session_context() as session:
Expand All @@ -120,7 +121,7 @@ async def get_followers(
return sorted(followers, key=lambda e: e.occurred)

@db_injector
async def get_lost_followers(db: PrefectDBInterface, self) -> List[ReceivedEvent]:
async def get_lost_followers(self, db: PrefectDBInterface) -> List[ReceivedEvent]:
"""Returns events that were waiting on a leader event that never arrived"""
earlier = pendulum.now("UTC") - PRECEDING_EVENT_LOOKBACK

Expand Down
18 changes: 5 additions & 13 deletions src/prefect/server/services/foreman.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import pendulum
import sqlalchemy as sa
from typing_extensions import Self

from prefect.server import models
from prefect.server.database.dependencies import db_injector
Expand Down Expand Up @@ -70,7 +69,7 @@ def __init__(
)

@db_injector
async def run_once(db: PrefectDBInterface, self: Self) -> None:
async def run_once(self, db: PrefectDBInterface) -> None:
"""
Iterate over workers current marked as online. Mark workers as offline
if they have an old last_heartbeat_time. Marks work pools as not ready
Expand All @@ -85,8 +84,7 @@ async def run_once(db: PrefectDBInterface, self: Self) -> None:

@db_injector
async def _mark_online_workers_without_a_recent_heartbeat_as_offline(
db: PrefectDBInterface,
self: Self,
self, db: PrefectDBInterface
) -> None:
"""
Updates the status of workers that have an old last heartbeat time
Expand Down Expand Up @@ -147,7 +145,7 @@ async def _mark_online_workers_without_a_recent_heartbeat_as_offline(
self.logger.info(f"Marked {result.rowcount} workers as offline.")

@db_injector
async def _mark_work_pools_as_not_ready(db: PrefectDBInterface, self: Self):
async def _mark_work_pools_as_not_ready(self, db: PrefectDBInterface):
"""
Marks a work pool as not ready.
Expand Down Expand Up @@ -185,10 +183,7 @@ async def _mark_work_pools_as_not_ready(db: PrefectDBInterface, self: Self):
self.logger.info(f"Marked work pool {work_pool.id} as NOT_READY.")

@db_injector
async def _mark_deployments_as_not_ready(
db: PrefectDBInterface,
self: Self,
):
async def _mark_deployments_as_not_ready(self, db: PrefectDBInterface) -> None:
"""
Marks a deployment as NOT_READY and emits a deployment status event.
Emits an event and updates any bookkeeping fields on the deployment.
Expand Down Expand Up @@ -231,10 +226,7 @@ async def _mark_deployments_as_not_ready(
)

@db_injector
async def _mark_work_queues_as_not_ready(
db: PrefectDBInterface,
self: Self,
):
async def _mark_work_queues_as_not_ready(self, db: PrefectDBInterface):
"""
Marks work queues as NOT_READY based on their last_polled field.
Expand Down
14 changes: 12 additions & 2 deletions src/prefect/server/utilities/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from .bases import (
PrefectBaseModel,
ActionBaseModel,
IDBaseModel,
ORMBaseModel,
ActionBaseModel,
PrefectBaseModel,
PrefectDescriptorBase,
get_class_fields_only,
)

__all__ = [
"ActionBaseModel",
"IDBaseModel",
"ORMBaseModel",
"PrefectBaseModel",
"PrefectDescriptorBase",
"get_class_fields_only",
]
85 changes: 47 additions & 38 deletions src/prefect/server/utilities/schemas/bases.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,59 @@
import datetime
import os
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Optional,
Set,
Type,
TypeVar,
)
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar, Optional, TypeVar
from uuid import UUID, uuid4

import pendulum
from pydantic import (
BaseModel,
ConfigDict,
Field,
)
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Self

from prefect.types import DateTime

if TYPE_CHECKING:
from pydantic.main import IncEx
from rich.repr import RichReprResult

T = TypeVar("T")
B = TypeVar("B", bound=BaseModel)


def get_class_fields_only(model: Type[BaseModel]) -> set:
def get_class_fields_only(model: type[BaseModel]) -> set[str]:
"""
Gets all the field names defined on the model class but not any parent classes.
Any fields that are on the parent but redefined on the subclass are included.
"""
subclass_class_fields = set(model.__annotations__.keys())
parent_class_fields = set()
# the annotations keys fit all of these criteria without further processing
return set(model.__annotations__)

for base in model.__class__.__bases__:
if issubclass(base, BaseModel):
parent_class_fields.update(base.__annotations__.keys())

return (subclass_class_fields - parent_class_fields) | (
subclass_class_fields & parent_class_fields
)
class PrefectDescriptorBase(ABC):
"""A base class for descriptor objects used with PrefectBaseModel
Pydantic needs to be told about any kind of non-standard descriptor
objects used on a model, in order for these not to be treated as a field
type instead.
This base class is registered as an ignored type with PrefectBaseModel
and any classes that inherit from it will also be ignored. This allows
such descriptors to be used as properties, methods or other bound
descriptor use cases.
"""

@abstractmethod
def __get__(
self, __instance: Optional[Any], __owner: Optional[type[Any]] = None
) -> Any:
"""Base descriptor access.
The default implementation returns itself when the instance is None,
and raises an attribute error when the instance is not not None.
"""
if __instance is not None:
raise AttributeError
return self


class PrefectBaseModel(BaseModel):
Expand All @@ -58,7 +68,7 @@ class PrefectBaseModel(BaseModel):
subtle unintentional testing errors.
"""

_reset_fields: ClassVar[Set[str]] = set()
_reset_fields: ClassVar[set[str]] = set()

model_config = ConfigDict(
ser_json_timedelta="float",
Expand All @@ -68,6 +78,7 @@ class PrefectBaseModel(BaseModel):
and os.getenv("PREFECT_TESTING_TEST_MODE", "0").lower() not in ["true", "1"]
else "forbid"
),
ignored_types=(PrefectDescriptorBase,),
)

def __eq__(self, other: Any) -> bool:
Expand All @@ -84,22 +95,20 @@ def __eq__(self, other: Any) -> bool:
else:
return copy_dict == other

def __rich_repr__(self):
def __rich_repr__(self) -> "RichReprResult":
# Display all of the fields in the model if they differ from the default value
for name, field in self.model_fields.items():
value = getattr(self, name)

# Simplify the display of some common fields
if field.annotation == UUID and value:
if isinstance(value, UUID):
value = str(value)
elif (
isinstance(field.annotation, datetime.datetime)
and name == "timestamp"
and value
):
value = pendulum.instance(value).isoformat()
elif isinstance(field.annotation, datetime.datetime) and value:
value = pendulum.instance(value).diff_for_humans()
elif isinstance(value, datetime.datetime):
value = (
pendulum.instance(value).isoformat()
if name == "timestamp"
else pendulum.instance(value).diff_for_humans()
)

yield name, value, field.get_default()

Expand All @@ -126,7 +135,7 @@ def model_dump_for_orm(
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Prefect extension to `BaseModel.model_dump`. Generate a Python dictionary
representation of the model suitable for passing to SQLAlchemy model
Expand Down Expand Up @@ -179,7 +188,7 @@ class IDBaseModel(PrefectBaseModel):
The ID is reset on copy() and not included in equality comparisons.
"""

_reset_fields: ClassVar[Set[str]] = {"id"}
_reset_fields: ClassVar[set[str]] = {"id"}
id: UUID = Field(default_factory=uuid4)


Expand All @@ -192,7 +201,7 @@ class ORMBaseModel(IDBaseModel):
equality comparisons.
"""

_reset_fields: ClassVar[Set[str]] = {"id", "created", "updated"}
_reset_fields: ClassVar[set[str]] = {"id", "created", "updated"}

model_config = ConfigDict(from_attributes=True)

Expand Down
92 changes: 69 additions & 23 deletions tests/server/database/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
AsyncPostgresConfiguration,
BaseDatabaseConfiguration,
)
from prefect.server.database.dependencies import inject_db
from prefect.server.database.interface import PrefectDBInterface
from prefect.server.database.orm_models import (
AioSqliteORMConfiguration,
Expand Down Expand Up @@ -151,28 +150,6 @@ async def test_injecting_existing_orm_configs(ORMConfig):
assert type(db.orm) == ORMConfig


async def test_inject_db(db):
"""
Regression test for async-mangling behavior of inject_db() decorator.
Previously, when wrapping a coroutine function, the decorator returned
that function's coroutine object, instead of the coroutine function.
This worked fine in most cases because both a coroutine function and a
coroutine object can be awaited, but it broke our Pytest setup because
we were auto-marking coroutine functions as async, and any async test
wrapped by inject_db() was no longer a coroutine function, but instead
a coroutine object, so we skipped marking it.
"""

class Returner:
@inject_db
async def return_1(self, db):
return 1

assert asyncio.iscoroutinefunction(Returner().return_1)


async def test_inject_interface_class():
class TestInterface(PrefectDBInterface):
@property
Expand All @@ -182,3 +159,72 @@ def new_property(self):
with dependencies.temporary_interface_class(TestInterface):
db = dependencies.provide_database_interface()
assert isinstance(db, TestInterface)


class TestDBInject:
@pytest.fixture(autouse=True)
def _setup(self):
self.db: PrefectDBInterface = dependencies.provide_database_interface()

def test_decorated_function(self):
@dependencies.db_injector
def function_with_injected_db(
db: PrefectDBInterface, foo: int
) -> PrefectDBInterface:
"""The documentation is sublime"""
return db

assert function_with_injected_db(42) is self.db

unwrapped = function_with_injected_db.__wrapped__
assert function_with_injected_db.__doc__ == unwrapped.__doc__
function_with_injected_db.__doc__ = "Something else"
assert function_with_injected_db.__doc__ == "Something else"
assert unwrapped.__doc__ == function_with_injected_db.__doc__
del function_with_injected_db.__doc__
assert function_with_injected_db.__doc__ is None
assert unwrapped.__doc__ is function_with_injected_db.__doc__

class SomeClass:
@dependencies.db_injector
def method_with_injected_db(
self, db: PrefectDBInterface, foo: int
) -> PrefectDBInterface:
"""The documentation is sublime"""
return db

def test_decorated_method(self):
instance = self.SomeClass()
assert instance.method_with_injected_db(42) is self.db

def test_unbound_decorated_method(self):
instance = self.SomeClass()
# manually binding the unbound descriptor to an instance
bound = self.SomeClass.method_with_injected_db.__get__(instance)
assert bound(42) is self.db

def test_bound_method_attributes(self):
instance = self.SomeClass()
bound = instance.method_with_injected_db
assert bound.__self__ is instance
assert bound.__func__ is self.SomeClass.method_with_injected_db.__wrapped__

unwrapped = bound.__wrapped__
assert bound.__doc__ == unwrapped.__doc__

before = bound.__doc__
with pytest.raises(AttributeError, match="is not writable$"):
bound.__doc__ = "Something else"
with pytest.raises(AttributeError, match="is not writable$"):
del bound.__doc__
assert unwrapped.__doc__ == before

def test_decorated_coroutine_function(self):
@dependencies.db_injector
async def coroutine_with_injected_db(
db: PrefectDBInterface, foo: int
) -> PrefectDBInterface:
return db

assert asyncio.iscoroutinefunction(coroutine_with_injected_db)
assert asyncio.run(coroutine_with_injected_db(42)) is self.db
Loading

0 comments on commit b42d3a1

Please sign in to comment.