Skip to content

Commit

Permalink
Make dagStats dag_ids parameter optional with pagination
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeljs-c committed Oct 10, 2024
1 parent e2f2653 commit 41bea17
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 17 deletions.
54 changes: 42 additions & 12 deletions airflow/api_connexion/endpoints/dag_stats_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,54 @@

# NOTE(review): this span is a rendered diff hunk. Removed and added lines are
# interleaved without +/- markers and the original indentation is stripped, so
# it cannot be read as a single runnable function body.
@security.requires_access_dag("GET", DagAccessEntity.RUN)
@provide_session
# Presumably the removed pre-change signature (dag_ids was a required str);
# the multi-line signature that follows makes it optional and adds limit/offset.
def get_dag_stats(*, dag_ids: str, session: Session = NEW_SESSION) -> APIResponse:
def get_dag_stats(
*,
dag_ids: str | None = None,
limit: int | None = None,
offset: int | None = None,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get Dag statistics."""
# DAG ids the current user (g.user) is permitted to GET.
allowed_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user)
# Presumably the removed unconditional split; it would fail on dag_ids=None.
dags_list = set(dag_ids.split(","))
filter_dag_ids = dags_list.intersection(allowed_dag_ids)
# With dag_ids supplied: comma-split it and keep only the permitted ids.
# Without it (None or empty string): fall back to every permitted DAG id.
if dag_ids:
dags_list = set(dag_ids.split(","))
filter_dag_ids = dags_list.intersection(allowed_dag_ids)
else:
filter_dag_ids = allowed_dag_ids

# Count DagRun rows per (dag_id, state), restricted to filter_dag_ids.
# The second select line labels the aggregate "count" so the paginated
# branch below can reference it as a subquery column.
query = (
select(DagRun.dag_id, DagRun.state, func.count(DagRun.state))
select(DagRun.dag_id, DagRun.state, func.count(DagRun.state).label("count"))
.group_by(DagRun.dag_id, DagRun.state)
.where(DagRun.dag_id.in_(filter_dag_ids))
)
dag_state_stats = session.execute(query)
# NOTE(review): truthiness test, so limit=0 / offset=0 behave exactly like
# "not supplied" -- limit=0 returns every DAG rather than none.
if limit or offset:
query = query.subquery()
# dense_rank() ordered by dag_id gives all state rows of one DAG the same
# rank, so the rank filters below page by DAG rather than by result row.
ranked_query = select(
query.c.dag_id,
query.c.state,
query.c.count,
func.dense_rank().over(order_by=query.c.dag_id).label("rank"),
).subquery()
paginated_query = select(ranked_query.c.dag_id, ranked_query.c.state, ranked_query.c.count)
# Keep ranks in the half-open window (offset, offset + limit].
if offset:
paginated_query = paginated_query.where(ranked_query.c.rank > offset)
if limit:
paginated_query = paginated_query.where(
ranked_query.c.rank <= limit + (offset if offset is not None else 0)
)
dag_state_stats = session.execute(paginated_query)
else:
dag_state_stats = session.execute(query)

# Lookup table: (dag_id, state) -> number of runs in that state.
dag_state_data = {(dag_id, state): count for dag_id, state, count in dag_state_stats}
# Presumably the removed response construction: it iterated filter_dag_ids, so
# a permitted/requested DAG with no runs still appeared with all-zero counts.
dag_stats = {
dag_id: [{"state": state, "count": dag_state_data.get((dag_id, state), 0)} for state in DagRunState]
for dag_id in filter_dag_ids
}

dags = [{"dag_id": stat, "stats": dag_stats[stat]} for stat in dag_stats]
return dag_stats_collection_schema.dump({"dags": dags, "total_entries": len(dag_stats)})
# NOTE(review): the replacement below derives the DAG list from the query
# results instead of filter_dag_ids, which has these consequences:
#   * a DAG that has no DagRun rows is absent from the response entirely;
#   * the `dag_ids` parameter name is rebound to a set, whose iteration order
#     is unspecified, so the order of "dags" in a page is not guaranteed;
#   * total_entries is len(dag_stats), i.e. the number of DAGs in the returned
#     page, not the number of DAGs matching before limit/offset were applied.
dag_ids = {dag[0] for dag in dag_state_data.keys()}
# One entry per DAG with a count for every DagRunState member (0 when the
# query returned no row for that state).
dag_stats = [
{
"dag_id": dag_id,
"stats": [
{"state": state, "count": dag_state_data.get((dag_id, state), 0)} for state in DagRunState
],
}
for dag_id in dag_ids
]
return dag_stats_collection_schema.dump({"dags": dag_stats, "total_entries": len(dag_stats)})
4 changes: 3 additions & 1 deletion airflow/api_connexion/openapi/v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2429,11 +2429,13 @@ paths:
operationId: get_dag_stats
tags: [DagStats]
parameters:
- $ref: "#/components/parameters/PageLimit"
- $ref: "#/components/parameters/PageOffset"
- name: dag_ids
in: query
schema:
type: string
required: true
required: false
description: |
One or more DAG IDs separated by commas to filter relevant Dags.
responses:
Expand Down
6 changes: 5 additions & 1 deletion airflow/www/static/js/types/api-generated.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4973,8 +4973,12 @@ export interface operations {
get_dag_stats: {
parameters: {
query: {
/** The numbers of items to return. */
limit?: components["parameters"]["PageLimit"];
/** The number of items to skip before starting to collect the result set. */
offset?: components["parameters"]["PageOffset"];
/** One or more DAG IDs separated by commas to filter relevant Dags. */
dag_ids: string;
dag_ids?: string;
};
};
responses: {
Expand Down
193 changes: 190 additions & 3 deletions tests/api_connexion/endpoints/test_dag_stats_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ def _create_dag(self, dag_id):
self.app.dag_bag.bag_dag(dag)
return dag_instance

def test_should_respond_200(self, session):
def _create_dag_runs(self, session):
self._create_dag("dag_stats_dag")
self._create_dag("dag_stats_dag_2")
self._create_dag("dag_stats_dag_3")
dag_1_run_1 = DagRun(
dag_id="dag_stats_dag",
run_id="test_dag_run_id_1",
Expand Down Expand Up @@ -105,8 +106,20 @@ def test_should_respond_200(self, session):
external_trigger=True,
state="queued",
)
session.add_all((dag_1_run_1, dag_1_run_2, dag_2_run_1))
dag_3_run_1 = DagRun(
dag_id="dag_stats_dag_3",
run_id="test_dag_3_run_id_1",
run_type=DagRunType.MANUAL,
execution_date=timezone.parse(self.default_time),
start_date=timezone.parse(self.default_time),
external_trigger=True,
state="success",
)
session.add_all((dag_1_run_1, dag_1_run_2, dag_2_run_1, dag_3_run_1))
session.commit()

def test_should_respond_200(self, session):
self._create_dag_runs(session)
exp_payload = {
"dags": [
{
Expand Down Expand Up @@ -164,7 +177,181 @@ def test_should_respond_200(self, session):
assert sorted(response.json["dags"], key=lambda d: d["dag_id"]) == sorted(
exp_payload["dags"], key=lambda d: d["dag_id"]
)
response.json["total_entries"] == 2
assert response.json["total_entries"] == 2

# Four request shapes against the same three-DAG fixture, none passing dag_ids:
#   1. no query parameters   -> all three DAGs
#   2. limit=1               -> only "dag_stats_dag"
#   3. offset=2              -> only "dag_stats_dag_3"
#   4. offset=1 and limit=1  -> only "dag_stats_dag_2"
# The expected pages follow ascending dag_id order.
@pytest.mark.parametrize(
"url, exp_payload",
[
# Case 1: no parameters at all.
(
"api/v1/dagStats",
{
"dags": [
{
"dag_id": "dag_stats_dag",
"stats": [
{
"state": DagRunState.QUEUED,
"count": 0,
},
{
"state": DagRunState.RUNNING,
"count": 1,
},
{
"state": DagRunState.SUCCESS,
"count": 0,
},
{
"state": DagRunState.FAILED,
"count": 1,
},
],
},
{
"dag_id": "dag_stats_dag_2",
"stats": [
{
"state": DagRunState.QUEUED,
"count": 1,
},
{
"state": DagRunState.RUNNING,
"count": 0,
},
{
"state": DagRunState.SUCCESS,
"count": 0,
},
{
"state": DagRunState.FAILED,
"count": 0,
},
],
},
{
"dag_id": "dag_stats_dag_3",
"stats": [
{
"state": DagRunState.QUEUED,
"count": 0,
},
{
"state": DagRunState.RUNNING,
"count": 0,
},
{
"state": DagRunState.SUCCESS,
"count": 1,
},
{
"state": DagRunState.FAILED,
"count": 0,
},
],
},
],
"total_entries": 3,
},
),
# Case 2: limit only.
(
"api/v1/dagStats?limit=1",
{
"dags": [
{
"dag_id": "dag_stats_dag",
"stats": [
{
"state": DagRunState.QUEUED,
"count": 0,
},
{
"state": DagRunState.RUNNING,
"count": 1,
},
{
"state": DagRunState.SUCCESS,
"count": 0,
},
{
"state": DagRunState.FAILED,
"count": 1,
},
],
}
],
"total_entries": 1,
},
),
# Case 3: offset only.
(
"api/v1/dagStats?offset=2",
{
"dags": [
{
"dag_id": "dag_stats_dag_3",
"stats": [
{
"state": DagRunState.QUEUED,
"count": 0,
},
{
"state": DagRunState.RUNNING,
"count": 0,
},
{
"state": DagRunState.SUCCESS,
"count": 1,
},
{
"state": DagRunState.FAILED,
"count": 0,
},
],
},
],
"total_entries": 1,
},
),
# Case 4: offset and limit combined.
(
"api/v1/dagStats?offset=1&limit=1",
{
"dags": [
{
"dag_id": "dag_stats_dag_2",
"stats": [
{
"state": DagRunState.QUEUED,
"count": 1,
},
{
"state": DagRunState.RUNNING,
"count": 0,
},
{
"state": DagRunState.SUCCESS,
"count": 0,
},
{
"state": DagRunState.FAILED,
"count": 0,
},
],
},
],
"total_entries": 1,
},
),
],
)
def test_optional_dag_ids_with_limit_offset(self, url, exp_payload, session):
# Seed the shared DAG / DagRun fixture rows.
self._create_dag_runs(session)

response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"})
num_dags = len(exp_payload["dags"])
assert response.status_code == 200
# Both sides are sorted by dag_id, so the comparison ignores the order in
# which the endpoint happens to emit the DAGs.
assert sorted(response.json["dags"], key=lambda d: d["dag_id"]) == sorted(
exp_payload["dags"], key=lambda d: d["dag_id"]
)
# total_entries is checked against the size of the expected page (the
# "total_entries" value inside exp_payload is not itself compared).
assert response.json["total_entries"] == num_dags

def test_should_raises_401_unauthenticated(self):
dag_ids = "dag_stats_dag,dag_stats_dag_2"
Expand Down

0 comments on commit 41bea17

Please sign in to comment.