[typing] prefect.server.utilities.database (#16368)

mjpieters authored Dec 14, 2024
1 parent 49a6101 commit df9d64f

Showing 13 changed files with 742 additions and 743 deletions.

6 changes: 4 additions & 2 deletions src/prefect/server/api/run_history.py

@@ -102,12 +102,14 @@ async def run_history(
             # estimated run times only includes positive run times (to avoid any unexpected corner cases)
             "sum_estimated_run_time",
             sa.func.sum(
-                db.greatest(0, sa.extract("epoch", runs.c.estimated_run_time))
+                sa.func.greatest(
+                    0, sa.extract("epoch", runs.c.estimated_run_time)
+                )
             ),
             # estimated lateness is the sum of any positive start time deltas
             "sum_estimated_lateness",
             sa.func.sum(
-                db.greatest(
+                sa.func.greatest(
                     0, sa.extract("epoch", runs.c.estimated_start_time_delta)
                 )
             ),

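The db.greatest indirection could be dropped because a plain sa.func.greatest call can be made portable at compile time rather than at query-construction time. A minimal sketch of that pattern, assuming the utilities module registers a dialect-aware GREATEST (the names below are illustrative, not necessarily the commit's exact implementation):

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction


class greatest(GenericFunction):
    # registering a GenericFunction under this name makes
    # sa.func.greatest(...) resolve to this class
    name = "greatest"
    inherit_cache = True


@compiles(greatest, "sqlite")
def _sqlite_greatest(element, compiler, **kwargs):
    # SQLite has no GREATEST, but its scalar max() accepts multiple arguments
    return f"max({compiler.process(element.clauses, **kwargs)})"

On PostgreSQL the default compilation renders greatest(...) unchanged; only SQLite gets the max() spelling, which is the behavior the removed per-dialect greatest/least methods used to provide.
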
43 changes: 5 additions & 38 deletions src/prefect/server/api/ui/task_runs.py

@@ -1,5 +1,4 @@
-import sys
-from datetime import datetime, timezone
+from datetime import datetime
 from typing import List, Optional, cast

 import pendulum
@@ -37,37 +36,6 @@ def ser_model(self) -> dict:
         }


-def _postgres_bucket_expression(
-    db: PrefectDBInterface, delta: pendulum.Duration, start_datetime: datetime
-):
-    # asyncpg under Python 3.7 doesn't support timezone-aware datetimes for the EXTRACT
-    # function, so we will send it as a naive datetime in UTC
-    if sys.version_info < (3, 8):
-        start_datetime = start_datetime.astimezone(timezone.utc).replace(tzinfo=None)
-
-    return sa.func.floor(
-        (
-            sa.func.extract("epoch", db.TaskRun.start_time)
-            - sa.func.extract("epoch", start_datetime)
-        )
-        / delta.total_seconds()
-    ).label("bucket")
-
-
-def _sqlite_bucket_expression(
-    db: PrefectDBInterface, delta: pendulum.Duration, start_datetime: datetime
-):
-    return sa.func.floor(
-        (
-            (
-                sa.func.strftime("%s", db.TaskRun.start_time)
-                - sa.func.strftime("%s", start_datetime)
-            )
-            / delta.total_seconds()
-        )
-    ).label("bucket")
-
-
 @router.post("/dashboard/counts")
 async def read_dashboard_task_run_counts(
     task_runs: schemas.filters.TaskRunFilter,
@@ -121,11 +89,10 @@ async def read_dashboard_task_run_counts(
         start_time.microsecond,
         start_time.timezone,
     )
-    bucket_expression = (
-        _sqlite_bucket_expression(db, delta, start_datetime)
-        if db.dialect.name == "sqlite"
-        else _postgres_bucket_expression(db, delta, start_datetime)
-    )
+    bucket_expression = sa.func.floor(
+        sa.func.date_diff_seconds(db.TaskRun.start_time, start_datetime)
+        / delta.total_seconds()
+    ).label("bucket")

     raw_counts = (
         (

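The two dialect-specific bucket helpers collapse into one expression because the epoch arithmetic now hides behind sa.func.date_diff_seconds. A plausible sketch of such a function, assuming it is registered as a GenericFunction with per-dialect compilation in prefect.server.utilities.database (the shipped version may differ, for instance by using julianday() on SQLite for sub-second precision):

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import GenericFunction


class date_diff_seconds(GenericFunction):
    # seconds between two datetime expressions
    name = "date_diff_seconds"
    type = sa.Numeric()
    inherit_cache = True


@compiles(date_diff_seconds, "postgresql")
def _pg_date_diff_seconds(element, compiler, **kwargs):
    a, b = (compiler.process(clause, **kwargs) for clause in element.clauses)
    return f"(EXTRACT(EPOCH FROM {a}) - EXTRACT(EPOCH FROM {b}))"


@compiles(date_diff_seconds, "sqlite")
def _sqlite_date_diff_seconds(element, compiler, **kwargs):
    a, b = (compiler.process(clause, **kwargs) for clause in element.clauses)
    # strftime('%s', ...) renders a Unix timestamp in whole seconds
    return f"(strftime('%s', {a}) - strftime('%s', {b}))"

Either way the route no longer branches on db.dialect.name, and the Python 3.7 asyncpg workaround disappears along with the branch.
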
11 changes: 2 additions & 9 deletions src/prefect/server/database/configurations.py

@@ -8,16 +8,9 @@
 from typing import Dict, Hashable, Optional, Tuple

 import sqlalchemy as sa
-
-try:
-    from sqlalchemy import AdaptedConnection
-    from sqlalchemy.pool import ConnectionPoolEntry
-except ImportError:
-    # SQLAlchemy 1.4 equivalents
-    from sqlalchemy.pool import _ConnectionFairy as AdaptedConnection
-    from sqlalchemy.pool.base import _ConnectionRecord as ConnectionPoolEntry
-
+from sqlalchemy import AdaptedConnection
 from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
+from sqlalchemy.pool import ConnectionPoolEntry
 from typing_extensions import Literal

 from prefect.settings import (

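With SQLAlchemy 2.x as the minimum, AdaptedConnection and ConnectionPoolEntry are stable public names, so the 1.4 fallback to private pool classes can go. These types mostly appear in signatures of pool event hooks; a hypothetical example of that use (not code from this commit):

import sqlalchemy as sa
from sqlalchemy import AdaptedConnection
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.pool import ConnectionPoolEntry

engine = create_async_engine("sqlite+aiosqlite:///:memory:")


@sa.event.listens_for(engine.sync_engine, "connect")
def on_connect(
    dbapi_connection: AdaptedConnection,
    connection_record: ConnectionPoolEntry,
) -> None:
    # per-connection setup, e.g. enabling SQLite foreign key enforcement
    cursor = dbapi_connection.cursor()
    cursor.execute("PRAGMA foreign_keys = ON")
    cursor.close()
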
3 changes: 0 additions & 3 deletions src/prefect/server/database/interface.py

@@ -359,9 +359,6 @@ def insert(self, model):
         """INSERTs a model into the database"""
         return self.queries.insert(model)

-    def greatest(self, *values):
-        return self.queries.greatest(*values)
-
     def make_timestamp_intervals(
         self,
         start_time: datetime.datetime,

56 changes: 28 additions & 28 deletions src/prefect/server/database/orm_models.py

@@ -1,22 +1,21 @@
 import datetime
 import uuid
 from abc import ABC, abstractmethod
+from collections.abc import Hashable, Iterable
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
     Any,
     ClassVar,
     Dict,
-    Hashable,
-    Iterable,
     Optional,
     Union,
-    cast,
 )

 import pendulum
 import sqlalchemy as sa
 from sqlalchemy import FetchedValue
+from sqlalchemy.dialects import postgresql
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.orm import (
@@ -46,19 +45,20 @@
     WorkQueueStatus,
 )
 from prefect.server.utilities.database import (
+    CAMEL_TO_SNAKE,
     JSON,
     UUID,
     GenerateUUID,
     Pydantic,
     Timestamp,
-    camel_to_snake,
-    date_diff,
-    interval_add,
-    now,
 )
 from prefect.server.utilities.encryption import decrypt_fernet, encrypt_fernet
 from prefect.utilities.names import generate_slug

+# for 'plain JSON' columns, use the postgresql variant (which comes with an
+# extra operator) and fall back to the generic JSON variant for SQLite
+sa_JSON = postgresql.JSON().with_variant(sa.JSON(), "sqlite")
+

 class Base(DeclarativeBase):
     """
@@ -117,7 +117,7 @@ def __tablename__(cls) -> str:
         into a snake-case table name. Override by providing
         an explicit `__tablename__` class property.
         """
-        return camel_to_snake.sub("_", cls.__name__).lower()
+        return CAMEL_TO_SNAKE.sub("_", cls.__name__).lower()

     id: Mapped[uuid.UUID] = mapped_column(
         primary_key=True,
@@ -126,17 +126,17 @@ def __tablename__(cls) -> str:
     )

     created: Mapped[pendulum.DateTime] = mapped_column(
-        server_default=now(), default=lambda: pendulum.now("UTC")
+        server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
     )

     # onupdate is only called when statements are actually issued
     # against the database. until COMMIT is issued, this column
     # will not be updated
     updated: Mapped[pendulum.DateTime] = mapped_column(
         index=True,
-        server_default=now(),
+        server_default=sa.func.now(),
         default=lambda: pendulum.now("UTC"),
-        onupdate=now(),
+        onupdate=sa.func.now(),
         server_onupdate=FetchedValue(),
     )

@@ -175,7 +175,7 @@ class FlowRunState(Base):
         sa.Enum(schemas.states.StateType, name="state_type"), index=True
     )
     timestamp: Mapped[pendulum.DateTime] = mapped_column(
-        server_default=now(), default=lambda: pendulum.now("UTC")
+        server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
     )
     name: Mapped[str] = mapped_column(index=True)
     message: Mapped[Optional[str]]
@@ -240,7 +240,7 @@ class TaskRunState(Base):
         sa.Enum(schemas.states.StateType, name="state_type"), index=True
     )
     timestamp: Mapped[pendulum.DateTime] = mapped_column(
-        server_default=now(), default=lambda: pendulum.now("UTC")
+        server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
     )
     name: Mapped[str] = mapped_column(index=True)
     message: Mapped[Optional[str]]
@@ -303,11 +303,11 @@ class Artifact(Base):
     flow_run_id: Mapped[Optional[uuid.UUID]] = mapped_column(index=True)

     type: Mapped[Optional[str]]
-    data: Mapped[Optional[Any]] = mapped_column(sa.JSON)
+    data: Mapped[Optional[Any]] = mapped_column(sa_JSON)
     description: Mapped[Optional[str]]

     # Suffixed with underscore as attribute name 'metadata' is reserved for the MetaData instance when using a declarative base class.
-    metadata_: Mapped[Optional[dict[str, str]]] = mapped_column(sa.JSON)
+    metadata_: Mapped[Optional[dict[str, str]]] = mapped_column(sa_JSON)

     @declared_attr.directive
     @classmethod
@@ -342,9 +342,9 @@ class ArtifactCollection(Base):
     flow_run_id: Mapped[Optional[uuid.UUID]]

     type: Mapped[Optional[str]]
-    data: Mapped[Optional[Any]] = mapped_column(sa.JSON)
+    data: Mapped[Optional[Any]] = mapped_column(sa_JSON)
     description: Mapped[Optional[str]]
-    metadata_: Mapped[Optional[dict[str, str]]] = mapped_column(sa.JSON)
+    metadata_: Mapped[Optional[dict[str, str]]] = mapped_column(sa_JSON)

     __table_args__: Any = (
         sa.UniqueConstraint("key"),
@@ -419,9 +419,9 @@ def _estimated_run_time_expression(cls) -> sa.Label[datetime.timedelta]:
             sa.case(
                 (
                     cls.state_type == schemas.states.StateType.RUNNING,
-                    interval_add(
+                    sa.func.interval_add(
                         cls.total_run_time,
-                        date_diff(now(), cls.state_timestamp),
+                        sa.func.date_diff(sa.func.now(), cls.state_timestamp),
                     ),
                 ),
                 else_=cls.total_run_time,
@@ -464,15 +464,15 @@ def _estimated_start_time_delta_expression(
         return sa.case(
             (
                 cls.start_time > cls.expected_start_time,
-                date_diff(cls.start_time, cls.expected_start_time),
+                sa.func.date_diff(cls.start_time, cls.expected_start_time),
             ),
             (
                 sa.and_(
                     cls.start_time.is_(None),
                     cls.state_type.not_in(schemas.states.TERMINAL_STATES),
-                    cls.expected_start_time < now(),
+                    cls.expected_start_time < sa.func.now(),
                 ),
-                date_diff(now(), cls.expected_start_time),
+                sa.func.date_diff(sa.func.now(), cls.expected_start_time),
             ),
             else_=datetime.timedelta(0),
         )
@@ -1165,7 +1165,7 @@ class Worker(Base):

     name: Mapped[str]
     last_heartbeat_time: Mapped[pendulum.DateTime] = mapped_column(
-        server_default=now(), default=lambda: pendulum.now("UTC")
+        server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
     )
     heartbeat_interval_seconds: Mapped[Optional[int]]

@@ -1195,7 +1195,7 @@ class Agent(Base):
     )

     last_activity_time: Mapped[pendulum.DateTime] = mapped_column(
-        server_default=now(), default=lambda: pendulum.now("UTC")
+        server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
     )

     __table_args__: Any = (sa.UniqueConstraint("name"),)
@@ -1277,11 +1277,11 @@ class Automation(Base):
     @classmethod
     def sort_expression(cls, value: AutomationSort) -> sa.ColumnExpressionArgument[Any]:
         """Return an expression used to sort Automations"""
-        sort_mapping = {
+        sort_mapping: dict[AutomationSort, sa.ColumnExpressionArgument[Any]] = {
             AutomationSort.CREATED_DESC: cls.created.desc(),
             AutomationSort.UPDATED_DESC: cls.updated.desc(),
-            AutomationSort.NAME_ASC: cast(sa.Column, cls.name).asc(),
-            AutomationSort.NAME_DESC: cast(sa.Column, cls.name).desc(),
+            AutomationSort.NAME_ASC: cls.name.asc(),
+            AutomationSort.NAME_DESC: cls.name.desc(),
         }
         return sort_mapping[value]

@@ -1439,7 +1439,7 @@ def __tablename__(cls) -> str:
     occurred: Mapped[pendulum.DateTime]
     resource_id: Mapped[str] = mapped_column(sa.Text())
     resource_role: Mapped[str] = mapped_column(sa.Text())
-    resource: Mapped[dict[str, Any]] = mapped_column(sa.JSON())
+    resource: Mapped[dict[str, Any]] = mapped_column(sa_JSON)
     event_id: Mapped[uuid.UUID]

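The new module-level sa_JSON shows the with_variant pattern: declare the PostgreSQL JSON type, which carries extra comparator operators such as astext on indexed elements, and swap in the generic sa.JSON whenever the SQLite dialect is compiled. A self-contained sketch of the same pattern on a hypothetical model (not part of the commit):

from typing import Any, Optional

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

# PostgreSQL JSON where available, generic JSON on SQLite
sa_JSON = postgresql.JSON().with_variant(sa.JSON(), "sqlite")


class Base(DeclarativeBase):
    pass


class Document(Base):  # hypothetical model, for illustration only
    __tablename__ = "document"

    id: Mapped[int] = mapped_column(primary_key=True)
    data: Mapped[Optional[Any]] = mapped_column(sa_JSON)

Because the variant is resolved per dialect at compile time, PostgreSQL queries can use expressions like Document.data["status"].astext while the same model still maps cleanly on SQLite.
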
40 changes: 13 additions & 27 deletions src/prefect/server/database/query_components.py

@@ -27,7 +27,7 @@
 from prefect.server.exceptions import FlowRunGraphTooLarge, ObjectNotFoundError
 from prefect.server.schemas.graph import Edge, Graph, GraphArtifact, GraphState, Node
 from prefect.server.utilities.database import UUID as UUIDTypeDecorator
-from prefect.server.utilities.database import Timestamp, json_has_any_key
+from prefect.server.utilities.database import Timestamp

 if TYPE_CHECKING:
     from prefect.server.database.interface import PrefectDBInterface
@@ -61,14 +61,6 @@ def _unique_key(self) -> Tuple[Hashable, ...]:
     def insert(self, obj) -> Union[postgresql.Insert, sqlite.Insert]:
         """dialect-specific insert statement"""

-    @abstractmethod
-    def greatest(self, *values):
-        """dialect-specific SqlAlchemy binding"""
-
-    @abstractmethod
-    def least(self, *values):
-        """dialect-specific SqlAlchemy binding"""
-
     # --- dialect-specific JSON handling

     @abstractproperty
@@ -121,6 +113,13 @@ async def queue_flow_run_notifications(
         db: "PrefectDBInterface",
     ):
         """Database-specific implementation of queueing notifications for a flow run"""
+
+        def as_array(elems: Sequence[str]) -> sa.ColumnElement[Sequence[str]]:
+            return sa.cast(postgresql.array(elems), type_=postgresql.ARRAY(sa.String()))
+
+        if TYPE_CHECKING:
+            assert flow_run.state_name is not None
+
         # insert a <policy, state> pair into the notification queue
         stmt = db.insert(orm_models.FlowRunNotificationQueue).from_select(
             [
@@ -140,16 +139,15 @@ async def queue_flow_run_notifications(
                 # the policy state names aren't set or match the current state name
                 sa.or_(
                     orm_models.FlowRunNotificationPolicy.state_names == [],
-                    json_has_any_key(
-                        orm_models.FlowRunNotificationPolicy.state_names,
-                        [flow_run.state_name],
+                    orm_models.FlowRunNotificationPolicy.state_names.has_any(
+                        as_array([flow_run.state_name])
                     ),
                 ),
                 # the policy tags aren't set, or the tags match the flow run tags
                 sa.or_(
                     orm_models.FlowRunNotificationPolicy.tags == [],
-                    json_has_any_key(
-                        orm_models.FlowRunNotificationPolicy.tags, flow_run.tags
+                    orm_models.FlowRunNotificationPolicy.tags.has_any(
+                        as_array(flow_run.tags)
                     ),
                 ),
             )
@@ -179,7 +177,7 @@ def get_scheduled_flow_runs_from_work_queues(
         concurrency_queues = (
             sa.select(
                 orm_models.WorkQueue.id,
-                self.greatest(
+                sa.func.greatest(
                     0,
                     orm_models.WorkQueue.concurrency_limit
                     - sa.func.count(orm_models.FlowRun.id),
@@ -628,12 +626,6 @@ class AsyncPostgresQueryComponents(BaseQueryComponents):
     def insert(self, obj) -> postgresql.Insert:
         return postgresql.insert(obj)

-    def greatest(self, *values):
-        return sa.func.greatest(*values)
-
-    def least(self, *values):
-        return sa.func.least(*values)
-
     # --- Postgres-specific JSON handling

     @property
@@ -984,12 +976,6 @@ class AioSqliteQueryComponents(BaseQueryComponents):
     def insert(self, obj) -> sqlite.Insert:
         return sqlite.insert(obj)

-    def greatest(self, *values):
-        return sa.func.max(*values)
-
-    def least(self, *values):
-        return sa.func.min(*values)
-
     # --- Sqlite-specific JSON handling

     @property

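has_any is the comparator behind PostgreSQL's ?| operator ("contains any of these keys"), and the as_array cast pins the right-hand side to a text array so the driver does not have to infer the parameter type. A rough illustration of the SQL this shape produces, using a stand-in JSONB column (output paraphrased, not captured from the commit):

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# stand-in for a JSONB-backed column such as state_names
state_names = sa.column("state_names", postgresql.JSONB())

expr = state_names.has_any(
    sa.cast(postgresql.array(["Failed", "Crashed"]), postgresql.ARRAY(sa.String()))
)
print(expr.compile(dialect=postgresql.dialect()))
# roughly: state_names ?| CAST(ARRAY[%(param_1)s, %(param_2)s] AS VARCHAR[])
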
(diffs for the remaining 7 changed files did not load)