From ad54501f3fccde549c7699f8ce46c711e38accc0 Mon Sep 17 00:00:00 2001 From: michaeljs-c <72604759+michaeljs-c@users.noreply.github.com> Date: Sun, 13 Oct 2024 13:57:51 +0100 Subject: [PATCH] Make dagStats endpoint dag_ids parameter optional (#42955) * Make dagStats dag_ids parameter optional with pagination * move pagination out of sql query * tidy --------- Co-authored-by: Michael Smith-Chandler --- .../endpoints/dag_stats_endpoint.py | 40 ++-- airflow/api_connexion/openapi/v1.yaml | 4 +- airflow/www/static/js/types/api-generated.ts | 6 +- .../endpoints/test_dag_stats_endpoint.py | 194 +++++++++++++++++- 4 files changed, 227 insertions(+), 17 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_stats_endpoint.py b/airflow/api_connexion/endpoints/dag_stats_endpoint.py index 705af10d41d9..3b6c6ab8e0df 100644 --- a/airflow/api_connexion/endpoints/dag_stats_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_stats_endpoint.py @@ -39,24 +39,40 @@ @security.requires_access_dag("GET", DagAccessEntity.RUN) @provide_session -def get_dag_stats(*, dag_ids: str, session: Session = NEW_SESSION) -> APIResponse: +def get_dag_stats( + *, + dag_ids: str | None = None, + limit: int | None = None, + offset: int | None = None, + session: Session = NEW_SESSION, +) -> APIResponse: """Get Dag statistics.""" allowed_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user) - dags_list = set(dag_ids.split(",")) - filter_dag_ids = dags_list.intersection(allowed_dag_ids) + if dag_ids: + dags_list = set(dag_ids.split(",")) + filter_dag_ids = dags_list.intersection(allowed_dag_ids) + else: + filter_dag_ids = allowed_dag_ids + query_dag_ids = sorted(list(filter_dag_ids)) + if offset is not None: + query_dag_ids = query_dag_ids[offset:] + if limit is not None: + query_dag_ids = query_dag_ids[:limit] query = ( select(DagRun.dag_id, DagRun.state, func.count(DagRun.state)) .group_by(DagRun.dag_id, DagRun.state) - .where(DagRun.dag_id.in_(filter_dag_ids)) + 
.where(DagRun.dag_id.in_(query_dag_ids)) ) dag_state_stats = session.execute(query) - dag_state_data = {(dag_id, state): count for dag_id, state, count in dag_state_stats} - dag_stats = { - dag_id: [{"state": state, "count": dag_state_data.get((dag_id, state), 0)} for state in DagRunState] - for dag_id in filter_dag_ids - } - - dags = [{"dag_id": stat, "stats": dag_stats[stat]} for stat in dag_stats] - return dag_stats_collection_schema.dump({"dags": dags, "total_entries": len(dag_stats)}) + dags = [ + { + "dag_id": dag_id, + "stats": [ + {"state": state, "count": dag_state_data.get((dag_id, state), 0)} for state in DagRunState + ], + } + for dag_id in query_dag_ids + ] + return dag_stats_collection_schema.dump({"dags": dags, "total_entries": len(dags)}) diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml index b39d1cd955dd..e99f91639c49 100644 --- a/airflow/api_connexion/openapi/v1.yaml +++ b/airflow/api_connexion/openapi/v1.yaml @@ -2384,11 +2384,13 @@ paths: operationId: get_dag_stats tags: [DagStats] parameters: + - $ref: "#/components/parameters/PageLimit" + - $ref: "#/components/parameters/PageOffset" - name: dag_ids in: query schema: type: string - required: true + required: false description: | One or more DAG IDs separated by commas to filter relevant Dags. responses: diff --git a/airflow/www/static/js/types/api-generated.ts b/airflow/www/static/js/types/api-generated.ts index 15391c294243..ef45dbd3b57b 100644 --- a/airflow/www/static/js/types/api-generated.ts +++ b/airflow/www/static/js/types/api-generated.ts @@ -4962,8 +4962,12 @@ export interface operations { get_dag_stats: { parameters: { query: { + /** The numbers of items to return. */ + limit?: components["parameters"]["PageLimit"]; + /** The number of items to skip before starting to collect the result set. */ + offset?: components["parameters"]["PageOffset"]; /** One or more DAG IDs separated by commas to filter relevant Dags. 
*/ - dag_ids: string; + dag_ids?: string; }; }; responses: { diff --git a/tests/api_connexion/endpoints/test_dag_stats_endpoint.py b/tests/api_connexion/endpoints/test_dag_stats_endpoint.py index a447e2a6a4b2..fe563b944403 100644 --- a/tests/api_connexion/endpoints/test_dag_stats_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_stats_endpoint.py @@ -76,9 +76,10 @@ def _create_dag(self, dag_id): self.app.dag_bag.bag_dag(dag) return dag_instance - def test_should_respond_200(self, session): + def _create_dag_runs(self, session): self._create_dag("dag_stats_dag") self._create_dag("dag_stats_dag_2") + self._create_dag("dag_stats_dag_3") dag_1_run_1 = DagRun( dag_id="dag_stats_dag", run_id="test_dag_run_id_1", @@ -106,8 +107,20 @@ def test_should_respond_200(self, session): external_trigger=True, state="queued", ) - session.add_all((dag_1_run_1, dag_1_run_2, dag_2_run_1)) + dag_3_run_1 = DagRun( + dag_id="dag_stats_dag_3", + run_id="test_dag_3_run_id_1", + run_type=DagRunType.MANUAL, + execution_date=timezone.parse(self.default_time), + start_date=timezone.parse(self.default_time), + external_trigger=True, + state="success", + ) + session.add_all((dag_1_run_1, dag_1_run_2, dag_2_run_1, dag_3_run_1)) session.commit() + + def test_should_respond_200(self, session): + self._create_dag_runs(session) exp_payload = { "dags": [ { @@ -165,7 +178,182 @@ def test_should_respond_200(self, session): assert sorted(response.json["dags"], key=lambda d: d["dag_id"]) == sorted( exp_payload["dags"], key=lambda d: d["dag_id"] ) - response.json["total_entries"] == 2 + assert response.json["total_entries"] == 2 + + @pytest.mark.parametrize( + "url, exp_payload", + [ + ( + "api/v1/dagStats", + { + "dags": [ + { + "dag_id": "dag_stats_dag", + "stats": [ + { + "state": DagRunState.QUEUED, + "count": 0, + }, + { + "state": DagRunState.RUNNING, + "count": 1, + }, + { + "state": DagRunState.SUCCESS, + "count": 0, + }, + { + "state": DagRunState.FAILED, + "count": 1, + }, + ], + }, + { + 
"dag_id": "dag_stats_dag_2", + "stats": [ + { + "state": DagRunState.QUEUED, + "count": 1, + }, + { + "state": DagRunState.RUNNING, + "count": 0, + }, + { + "state": DagRunState.SUCCESS, + "count": 0, + }, + { + "state": DagRunState.FAILED, + "count": 0, + }, + ], + }, + { + "dag_id": "dag_stats_dag_3", + "stats": [ + { + "state": DagRunState.QUEUED, + "count": 0, + }, + { + "state": DagRunState.RUNNING, + "count": 0, + }, + { + "state": DagRunState.SUCCESS, + "count": 1, + }, + { + "state": DagRunState.FAILED, + "count": 0, + }, + ], + }, + ], + "total_entries": 3, + }, + ), + ( + "api/v1/dagStats?limit=1", + { + "dags": [ + { + "dag_id": "dag_stats_dag", + "stats": [ + { + "state": DagRunState.QUEUED, + "count": 0, + }, + { + "state": DagRunState.RUNNING, + "count": 1, + }, + { + "state": DagRunState.SUCCESS, + "count": 0, + }, + { + "state": DagRunState.FAILED, + "count": 1, + }, + ], + } + ], + "total_entries": 1, + }, + ), + ( + "api/v1/dagStats?offset=2", + { + "dags": [ + { + "dag_id": "dag_stats_dag_3", + "stats": [ + { + "state": DagRunState.QUEUED, + "count": 0, + }, + { + "state": DagRunState.RUNNING, + "count": 0, + }, + { + "state": DagRunState.SUCCESS, + "count": 1, + }, + { + "state": DagRunState.FAILED, + "count": 0, + }, + ], + }, + ], + "total_entries": 1, + }, + ), + ( + "api/v1/dagStats?offset=1&limit=1", + { + "dags": [ + { + "dag_id": "dag_stats_dag_2", + "stats": [ + { + "state": DagRunState.QUEUED, + "count": 1, + }, + { + "state": DagRunState.RUNNING, + "count": 0, + }, + { + "state": DagRunState.SUCCESS, + "count": 0, + }, + { + "state": DagRunState.FAILED, + "count": 0, + }, + ], + }, + ], + "total_entries": 1, + }, + ), + ("api/v1/dagStats?offset=10&limit=1", {"dags": [], "total_entries": 0}), + ], + ) + def test_optional_dag_ids_with_limit_offset(self, url, exp_payload, session): + self._create_dag_runs(session) + + response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) + num_dags = len(exp_payload["dags"]) + assert 
response.status_code == 200 + assert sorted(response.json["dags"], key=lambda d: d["dag_id"]) == sorted( + exp_payload["dags"], key=lambda d: d["dag_id"] + ) + assert response.json["total_entries"] == num_dags def test_should_raises_401_unauthenticated(self): dag_ids = "dag_stats_dag,dag_stats_dag_2"