Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more sort and filters to get dags endpoint #42462

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions airflow/api_fastapi/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@

from typing import TYPE_CHECKING

from sqlalchemy import func, select

from airflow.models.dagrun import DagRun
from airflow.utils.session import create_session

if TYPE_CHECKING:
Expand Down Expand Up @@ -52,3 +55,11 @@ def apply_filters_to_select(base_select: Select, filters: list[BaseParam]) -> Se
select = filter.to_orm(select)

return select


latest_dag_run_per_dag_id_cte = (
select(DagRun.dag_id, func.max(DagRun.start_date).label("start_date"))
.where()
.group_by(DagRun.dag_id)
.cte()
)
24 changes: 24 additions & 0 deletions airflow/api_fastapi/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ paths:
- type: boolean
- type: 'null'
title: Paused
- name: last_dag_run_state
in: query
required: false
schema:
anyOf:
- $ref: '#/components/schemas/DagRunState'
- type: 'null'
title: Last Dag Run State
- name: order_by
in: query
required: false
Expand Down Expand Up @@ -347,6 +355,22 @@ components:
- file_token
title: DAGResponse
description: DAG serializer for responses.
DagRunState:
type: string
enum:
- queued
- running
- success
- failed
title: DagRunState
description: 'All possible states that a DagRun can be in.


These are "shared" with TaskInstanceState in some parts of the code,

so please ensure that their values always match the ones with the

same name in TaskInstanceState.'
DagTagPydantic:
properties:
name:
Expand Down
84 changes: 52 additions & 32 deletions airflow/api_fastapi/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Generic, List, TypeVar, Union
from typing import TYPE_CHECKING, Any, Generic, List, TypeVar

from fastapi import Depends, HTTPException, Query
from sqlalchemy import case, or_
from typing_extensions import Annotated, Self

from airflow.models.dag import DagModel, DagTag
from airflow.models.dagrun import DagRun
from airflow.utils.state import DagRunState

if TYPE_CHECKING:
from sqlalchemy.sql import ColumnElement, Select
Expand All @@ -43,14 +45,14 @@ def __init__(self) -> None:
def to_orm(self, select: Select) -> Select:
pass

@abstractmethod
def __call__(self, *args: Any, **kwarg: Any) -> BaseParam:
pass

def set_value(self, value: T) -> Self:
def set_value(self, value: T | None) -> Self:
self.value = value
return self

@abstractmethod
def depends(self, *args: Any, **kwargs: Any) -> Self:
pass


class _LimitFilter(BaseParam[int]):
"""Filter on the limit."""
Expand All @@ -61,7 +63,7 @@ def to_orm(self, select: Select) -> Select:

return select.limit(self.value)

def __call__(self, limit: int = 100) -> _LimitFilter:
def depends(self, limit: int = 100) -> _LimitFilter:
return self.set_value(limit)


Expand All @@ -73,19 +75,19 @@ def to_orm(self, select: Select) -> Select:
return select
return select.offset(self.value)

def __call__(self, offset: int = 0) -> _OffsetFilter:
def depends(self, offset: int = 0) -> _OffsetFilter:
return self.set_value(offset)


class _PausedFilter(BaseParam[Union[bool, None]]):
class _PausedFilter(BaseParam[bool]):
"""Filter on is_paused."""

def to_orm(self, select: Select) -> Select:
if self.value is None:
return select
return select.where(DagModel.is_paused == self.value)

def __call__(self, paused: bool | None = Query(default=None)) -> _PausedFilter:
def depends(self, paused: bool | None = None) -> _PausedFilter:
return self.set_value(paused)


Expand All @@ -97,11 +99,11 @@ def to_orm(self, select: Select) -> Select:
return select.where(DagModel.is_active == self.value)
return select

def __call__(self, only_active: bool = Query(default=True)) -> _OnlyActiveFilter:
def depends(self, only_active: bool = True) -> _OnlyActiveFilter:
return self.set_value(only_active)


class _SearchParam(BaseParam[Union[str, None]]):
class _SearchParam(BaseParam[str]):
"""Search on attribute."""

def __init__(self, attribute: ColumnElement) -> None:
Expand All @@ -120,7 +122,7 @@ class _DagIdPatternSearch(_SearchParam):
def __init__(self) -> None:
super().__init__(DagModel.dag_id)

def __call__(self, dag_id_pattern: str | None = Query(default=None)) -> _DagIdPatternSearch:
def depends(self, dag_id_pattern: str | None = None) -> _DagIdPatternSearch:
return self.set_value(dag_id_pattern)


Expand All @@ -130,15 +132,18 @@ class _DagDisplayNamePatternSearch(_SearchParam):
def __init__(self) -> None:
super().__init__(DagModel.dag_display_name)

def __call__(
self, dag_display_name_pattern: str | None = Query(default=None)
) -> _DagDisplayNamePatternSearch:
def depends(self, dag_display_name_pattern: str | None = None) -> _DagDisplayNamePatternSearch:
return self.set_value(dag_display_name_pattern)


class SortParam(BaseParam[Union[str]]):
class SortParam(BaseParam[str]):
"""Order result by the attribute."""

attr_mapping = {
"last_run_state": DagRun.state,
"last_run_start_date": DagRun.start_date,
}

def __init__(self, allowed_attrs: list[str]) -> None:
super().__init__()
self.allowed_attrs = allowed_attrs
Expand All @@ -155,17 +160,17 @@ def to_orm(self, select: Select) -> Select:
f"the attribute does not exist on the model",
)

column = getattr(DagModel, lstriped_orderby)
column = self.attr_mapping.get(lstriped_orderby, None) or getattr(DagModel, lstriped_orderby)

# MySQL does not support `nullslast`, and True/False ordering depends on the
# database implementation
# database implementation.
nullscheck = case((column.isnot(None), 0), else_=1)
if self.value[0] == "-":
return select.order_by(nullscheck, column.desc(), DagModel.dag_id)
return select.order_by(nullscheck, column.desc(), DagModel.dag_id.desc())
else:
return select.order_by(nullscheck, column.asc(), DagModel.dag_id)
return select.order_by(nullscheck, column.asc(), DagModel.dag_id.asc())

def __call__(self, order_by: str = Query(default="dag_id")) -> SortParam:
def depends(self, order_by: str = "dag_id") -> SortParam:
return self.set_value(order_by)


Expand All @@ -179,7 +184,7 @@ def to_orm(self, select: Select) -> Select:
conditions = [DagModel.tags.any(DagTag.name == tag) for tag in self.value]
return select.where(or_(*conditions))

def __call__(self, tags: list[str] = Query(default_factory=list)) -> _TagsFilter:
def depends(self, tags: list[str] = Query(default_factory=list)) -> _TagsFilter:
return self.set_value(tags)


Expand All @@ -193,17 +198,32 @@ def to_orm(self, select: Select) -> Select:
conditions = [DagModel.owners.ilike(f"%{owner}%") for owner in self.value]
return select.where(or_(*conditions))

def __call__(self, owners: list[str] = Query(default_factory=list)) -> _OwnersFilter:
def depends(self, owners: list[str] = Query(default_factory=list)) -> _OwnersFilter:
return self.set_value(owners)


QueryLimit = Annotated[_LimitFilter, Depends(_LimitFilter())]
QueryOffset = Annotated[_OffsetFilter, Depends(_OffsetFilter())]
QueryPausedFilter = Annotated[_PausedFilter, Depends(_PausedFilter())]
QueryOnlyActiveFilter = Annotated[_OnlyActiveFilter, Depends(_OnlyActiveFilter())]
QueryDagIdPatternSearch = Annotated[_DagIdPatternSearch, Depends(_DagIdPatternSearch())]
class _LastDagRunStateFilter(BaseParam[DagRunState]):
"""Filter on the state of the latest DagRun."""

def to_orm(self, select: Select) -> Select:
if self.value is None:
return select
return select.where(DagRun.state == self.value)

def depends(self, last_dag_run_state: DagRunState | None = None) -> _LastDagRunStateFilter:
return self.set_value(last_dag_run_state)


# DAG
QueryLimit = Annotated[_LimitFilter, Depends(_LimitFilter().depends)]
QueryOffset = Annotated[_OffsetFilter, Depends(_OffsetFilter().depends)]
QueryPausedFilter = Annotated[_PausedFilter, Depends(_PausedFilter().depends)]
QueryOnlyActiveFilter = Annotated[_OnlyActiveFilter, Depends(_OnlyActiveFilter().depends)]
QueryDagIdPatternSearch = Annotated[_DagIdPatternSearch, Depends(_DagIdPatternSearch().depends)]
QueryDagDisplayNamePatternSearch = Annotated[
_DagDisplayNamePatternSearch, Depends(_DagDisplayNamePatternSearch())
_DagDisplayNamePatternSearch, Depends(_DagDisplayNamePatternSearch().depends)
]
QueryTagsFilter = Annotated[_TagsFilter, Depends(_TagsFilter())]
QueryOwnersFilter = Annotated[_OwnersFilter, Depends(_OwnersFilter())]
QueryTagsFilter = Annotated[_TagsFilter, Depends(_TagsFilter().depends)]
QueryOwnersFilter = Annotated[_OwnersFilter, Depends(_OwnersFilter().depends)]
# DagRun
QueryLastDagRunStateFilter = Annotated[_LastDagRunStateFilter, Depends(_LastDagRunStateFilter().depends)]
32 changes: 28 additions & 4 deletions airflow/api_fastapi/views/public/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
from sqlalchemy.orm import Session
from typing_extensions import Annotated

from airflow.api_fastapi.db import apply_filters_to_select, get_session
from airflow.api_fastapi.db import apply_filters_to_select, get_session, latest_dag_run_per_dag_id_cte
from airflow.api_fastapi.parameters import (
QueryDagDisplayNamePatternSearch,
QueryDagIdPatternSearch,
QueryLastDagRunStateFilter,
QueryLimit,
QueryOffset,
QueryOnlyActiveFilter,
Expand All @@ -36,6 +37,7 @@
)
from airflow.api_fastapi.serializers.dags import DAGCollectionResponse, DAGPatchBody, DAGResponse
from airflow.models import DagModel
from airflow.models.dagrun import DagRun
from airflow.utils.db import get_query_count

dags_router = APIRouter(tags=["DAG"])
Expand All @@ -51,14 +53,36 @@ async def get_dags(
dag_display_name_pattern: QueryDagDisplayNamePatternSearch,
only_active: QueryOnlyActiveFilter,
paused: QueryPausedFilter,
order_by: Annotated[SortParam, Depends(SortParam(["dag_id", "dag_display_name", "next_dagrun"]))],
last_dag_run_state: QueryLastDagRunStateFilter,
order_by: Annotated[
SortParam,
Depends(
SortParam(
["dag_id", "dag_display_name", "next_dagrun", "last_run_state", "last_run_start_date"]
).depends
),
],
session: Annotated[Session, Depends(get_session)],
) -> DAGCollectionResponse:
"""Get all DAGs."""
dags_query = select(DagModel)
dags_query = (
select(DagModel)
.join(
latest_dag_run_per_dag_id_cte,
DagModel.dag_id == latest_dag_run_per_dag_id_cte.c.dag_id,
isouter=True,
)
.join(
DagRun,
DagRun.start_date == latest_dag_run_per_dag_id_cte.c.start_date
and DagRun.dag_id == latest_dag_run_per_dag_id_cte.c.dag_id,
isouter=True,
)
)

dags_query = apply_filters_to_select(
dags_query, [only_active, paused, dag_id_pattern, dag_display_name_pattern, tags, owners]
dags_query,
[only_active, paused, dag_id_pattern, dag_display_name_pattern, tags, owners, last_dag_run_state],
)

# TODO: Re-enable when permissions are handled.
Expand Down
4 changes: 4 additions & 0 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import { UseQueryResult } from "@tanstack/react-query";

import { DagService, DatasetService } from "../requests/services.gen";
import { DagRunState } from "../requests/types.gen";

export type DatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGetDefaultResponse =
Awaited<
Expand Down Expand Up @@ -37,6 +38,7 @@ export const UseDagServiceGetDagsPublicDagsGetKeyFn = (
{
dagDisplayNamePattern,
dagIdPattern,
lastDagRunState,
limit,
offset,
onlyActive,
Expand All @@ -47,6 +49,7 @@ export const UseDagServiceGetDagsPublicDagsGetKeyFn = (
}: {
dagDisplayNamePattern?: string;
dagIdPattern?: string;
lastDagRunState?: DagRunState;
limit?: number;
offset?: number;
onlyActive?: boolean;
Expand All @@ -62,6 +65,7 @@ export const UseDagServiceGetDagsPublicDagsGetKeyFn = (
{
dagDisplayNamePattern,
dagIdPattern,
lastDagRunState,
limit,
offset,
onlyActive,
Expand Down
6 changes: 6 additions & 0 deletions airflow/ui/openapi-gen/queries/prefetch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import { type QueryClient } from "@tanstack/react-query";

import { DagService, DatasetService } from "../requests/services.gen";
import { DagRunState } from "../requests/types.gen";
import * as Common from "./common";

/**
Expand Down Expand Up @@ -40,6 +41,7 @@ export const prefetchUseDatasetServiceNextRunDatasetsUiNextRunDatasetsDagIdGet =
* @param data.dagDisplayNamePattern
* @param data.onlyActive
* @param data.paused
* @param data.lastDagRunState
* @param data.orderBy
* @returns DAGCollectionResponse Successful Response
* @throws ApiError
Expand All @@ -49,6 +51,7 @@ export const prefetchUseDagServiceGetDagsPublicDagsGet = (
{
dagDisplayNamePattern,
dagIdPattern,
lastDagRunState,
limit,
offset,
onlyActive,
Expand All @@ -59,6 +62,7 @@ export const prefetchUseDagServiceGetDagsPublicDagsGet = (
}: {
dagDisplayNamePattern?: string;
dagIdPattern?: string;
lastDagRunState?: DagRunState;
limit?: number;
offset?: number;
onlyActive?: boolean;
Expand All @@ -72,6 +76,7 @@ export const prefetchUseDagServiceGetDagsPublicDagsGet = (
queryKey: Common.UseDagServiceGetDagsPublicDagsGetKeyFn({
dagDisplayNamePattern,
dagIdPattern,
lastDagRunState,
limit,
offset,
onlyActive,
Expand All @@ -84,6 +89,7 @@ export const prefetchUseDagServiceGetDagsPublicDagsGet = (
DagService.getDagsPublicDagsGet({
dagDisplayNamePattern,
dagIdPattern,
lastDagRunState,
limit,
offset,
onlyActive,
Expand Down
Loading