Skip to content

Commit

Permalink
Make datStats endpoint dag_ids parameter optional (#42955)
Browse files Browse the repository at this point in the history
* Make datStats dag_id parameter optional with pagination

* move pagination out of sql query

* tidy

---------

Co-authored-by: Michael Smith-Chandler <mjsmithandler@gmail.com>
  • Loading branch information
michaeljs-c and Michael Smith-Chandler authored Oct 13, 2024
1 parent b92c66d commit ad54501
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 17 deletions.
40 changes: 28 additions & 12 deletions airflow/api_connexion/endpoints/dag_stats_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,40 @@

@security.requires_access_dag("GET", DagAccessEntity.RUN)
@provide_session
def get_dag_stats(*, dag_ids: str, session: Session = NEW_SESSION) -> APIResponse:
def get_dag_stats(
*,
dag_ids: str | None = None,
limit: int | None = None,
offset: int | None = None,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get Dag statistics."""
allowed_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user)
dags_list = set(dag_ids.split(","))
filter_dag_ids = dags_list.intersection(allowed_dag_ids)
if dag_ids:
dags_list = set(dag_ids.split(","))
filter_dag_ids = dags_list.intersection(allowed_dag_ids)
else:
filter_dag_ids = allowed_dag_ids
query_dag_ids = sorted(list(filter_dag_ids))
if offset is not None:
query_dag_ids = query_dag_ids[offset:]
if limit is not None:
query_dag_ids = query_dag_ids[:limit]

query = (
select(DagRun.dag_id, DagRun.state, func.count(DagRun.state))
.group_by(DagRun.dag_id, DagRun.state)
.where(DagRun.dag_id.in_(filter_dag_ids))
.where(DagRun.dag_id.in_(query_dag_ids))
)
dag_state_stats = session.execute(query)

dag_state_data = {(dag_id, state): count for dag_id, state, count in dag_state_stats}
dag_stats = {
dag_id: [{"state": state, "count": dag_state_data.get((dag_id, state), 0)} for state in DagRunState]
for dag_id in filter_dag_ids
}

dags = [{"dag_id": stat, "stats": dag_stats[stat]} for stat in dag_stats]
return dag_stats_collection_schema.dump({"dags": dags, "total_entries": len(dag_stats)})
dags = [
{
"dag_id": dag_id,
"stats": [
{"state": state, "count": dag_state_data.get((dag_id, state), 0)} for state in DagRunState
],
}
for dag_id in query_dag_ids
]
return dag_stats_collection_schema.dump({"dags": dags, "total_entries": len(dags)})
4 changes: 3 additions & 1 deletion airflow/api_connexion/openapi/v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2384,11 +2384,13 @@ paths:
operationId: get_dag_stats
tags: [DagStats]
parameters:
- $ref: "#/components/parameters/PageLimit"
- $ref: "#/components/parameters/PageOffset"
- name: dag_ids
in: query
schema:
type: string
required: true
required: false
description: |
One or more DAG IDs separated by commas to filter relevant Dags.
responses:
Expand Down
6 changes: 5 additions & 1 deletion airflow/www/static/js/types/api-generated.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4962,8 +4962,12 @@ export interface operations {
get_dag_stats: {
parameters: {
query: {
/** The numbers of items to return. */
limit?: components["parameters"]["PageLimit"];
/** The number of items to skip before starting to collect the result set. */
offset?: components["parameters"]["PageOffset"];
/** One or more DAG IDs separated by commas to filter relevant Dags. */
dag_ids: string;
dag_ids?: string;
};
};
responses: {
Expand Down
194 changes: 191 additions & 3 deletions tests/api_connexion/endpoints/test_dag_stats_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,10 @@ def _create_dag(self, dag_id):
self.app.dag_bag.bag_dag(dag)
return dag_instance

def test_should_respond_200(self, session):
def _create_dag_runs(self, session):
self._create_dag("dag_stats_dag")
self._create_dag("dag_stats_dag_2")
self._create_dag("dag_stats_dag_3")
dag_1_run_1 = DagRun(
dag_id="dag_stats_dag",
run_id="test_dag_run_id_1",
Expand Down Expand Up @@ -106,8 +107,20 @@ def test_should_respond_200(self, session):
external_trigger=True,
state="queued",
)
session.add_all((dag_1_run_1, dag_1_run_2, dag_2_run_1))
dag_3_run_1 = DagRun(
dag_id="dag_stats_dag_3",
run_id="test_dag_3_run_id_1",
run_type=DagRunType.MANUAL,
execution_date=timezone.parse(self.default_time),
start_date=timezone.parse(self.default_time),
external_trigger=True,
state="success",
)
session.add_all((dag_1_run_1, dag_1_run_2, dag_2_run_1, dag_3_run_1))
session.commit()

def test_should_respond_200(self, session):
self._create_dag_runs(session)
exp_payload = {
"dags": [
{
Expand Down Expand Up @@ -165,7 +178,182 @@ def test_should_respond_200(self, session):
assert sorted(response.json["dags"], key=lambda d: d["dag_id"]) == sorted(
exp_payload["dags"], key=lambda d: d["dag_id"]
)
response.json["total_entries"] == 2
assert response.json["total_entries"] == 2

@pytest.mark.parametrize(
"url, exp_payload",
[
(
"api/v1/dagStats",
{
"dags": [
{
"dag_id": "dag_stats_dag",
"stats": [
{
"state": DagRunState.QUEUED,
"count": 0,
},
{
"state": DagRunState.RUNNING,
"count": 1,
},
{
"state": DagRunState.SUCCESS,
"count": 0,
},
{
"state": DagRunState.FAILED,
"count": 1,
},
],
},
{
"dag_id": "dag_stats_dag_2",
"stats": [
{
"state": DagRunState.QUEUED,
"count": 1,
},
{
"state": DagRunState.RUNNING,
"count": 0,
},
{
"state": DagRunState.SUCCESS,
"count": 0,
},
{
"state": DagRunState.FAILED,
"count": 0,
},
],
},
{
"dag_id": "dag_stats_dag_3",
"stats": [
{
"state": DagRunState.QUEUED,
"count": 0,
},
{
"state": DagRunState.RUNNING,
"count": 0,
},
{
"state": DagRunState.SUCCESS,
"count": 1,
},
{
"state": DagRunState.FAILED,
"count": 0,
},
],
},
],
"total_entries": 3,
},
),
(
"api/v1/dagStats?limit=1",
{
"dags": [
{
"dag_id": "dag_stats_dag",
"stats": [
{
"state": DagRunState.QUEUED,
"count": 0,
},
{
"state": DagRunState.RUNNING,
"count": 1,
},
{
"state": DagRunState.SUCCESS,
"count": 0,
},
{
"state": DagRunState.FAILED,
"count": 1,
},
],
}
],
"total_entries": 1,
},
),
(
"api/v1/dagStats?offset=2",
{
"dags": [
{
"dag_id": "dag_stats_dag_3",
"stats": [
{
"state": DagRunState.QUEUED,
"count": 0,
},
{
"state": DagRunState.RUNNING,
"count": 0,
},
{
"state": DagRunState.SUCCESS,
"count": 1,
},
{
"state": DagRunState.FAILED,
"count": 0,
},
],
},
],
"total_entries": 1,
},
),
(
"api/v1/dagStats?offset=1&limit=1",
{
"dags": [
{
"dag_id": "dag_stats_dag_2",
"stats": [
{
"state": DagRunState.QUEUED,
"count": 1,
},
{
"state": DagRunState.RUNNING,
"count": 0,
},
{
"state": DagRunState.SUCCESS,
"count": 0,
},
{
"state": DagRunState.FAILED,
"count": 0,
},
],
},
],
"total_entries": 1,
},
),
("api/v1/dagStats?offset=10&limit=1", {"dags": [], "total_entries": 0}),
],
)
def test_optional_dag_ids_with_limit_offset(self, url, exp_payload, session):
self._create_dag_runs(session)

response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"})
num_dags = len(exp_payload["dags"])
assert response.status_code == 200
assert sorted(response.json["dags"], key=lambda d: d["dag_id"]) == sorted(
exp_payload["dags"], key=lambda d: d["dag_id"]
)
assert response.json["total_entries"] == num_dags

def test_should_raises_401_unauthenticated(self):
dag_ids = "dag_stats_dag,dag_stats_dag_2"
Expand Down

0 comments on commit ad54501

Please sign in to comment.