Skip to content

Commit

Permalink
Add API Endpoint - DagRuns Batch (#9556)
Browse files Browse the repository at this point in the history
Co-authored-by: Ephraim Anierobi <splendidzigy24@gmail.com>
  • Loading branch information
takunnithan and ephraimbuddy authored Jul 13, 2020
1 parent d344048 commit 5ddbbf1
Show file tree
Hide file tree
Showing 3 changed files with 363 additions and 136 deletions.
72 changes: 49 additions & 23 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
# under the License.
from connexion import NoContent
from flask import request
from marshmallow import ValidationError
from sqlalchemy import and_, func

from airflow.api_connexion.exceptions import AlreadyExists, NotFound
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
from airflow.api_connexion.parameters import check_limit, format_datetime, format_parameters
from airflow.api_connexion.schemas.dag_run_schema import (
DAGRunCollection, dagrun_collection_schema, dagrun_schema,
DAGRunCollection, dagrun_collection_schema, dagrun_schema, dagruns_batch_form_schema,
)
from airflow.models import DagModel, DagRun
from airflow.utils.session import provide_session
Expand All @@ -35,9 +36,8 @@ def delete_dag_run(dag_id, dag_run_id, session):
"""
if (
session.query(DagRun)
.filter(and_(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id))
.delete()
== 0
.filter(and_(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id))
.delete() == 0
):
raise NotFound(detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found")
return NoContent, 204
Expand Down Expand Up @@ -87,41 +87,68 @@ def get_dag_runs(
if dag_id != "~":
query = query.filter(DagRun.dag_id == dag_id)

dag_run, total_entries = _fetch_dag_runs(query, session, end_date_gte, end_date_lte, execution_date_gte,
execution_date_lte, start_date_gte, start_date_lte,
limit, offset)

return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run,
total_entries=total_entries))


def _fetch_dag_runs(query, session, end_date_gte, end_date_lte,
                    execution_date_gte, execution_date_lte,
                    start_date_gte, start_date_lte, limit, offset):
    """
    Apply date filters and pagination to *query* and return one page of DagRuns
    together with the total number of matching rows.

    :param query: base ``DagRun`` query (callers may have pre-filtered it, e.g. by dag_id)
    :param session: SQLAlchemy session; kept for interface compatibility with callers
    :param end_date_gte, end_date_lte: inclusive bounds on ``DagRun.end_date`` (or None)
    :param execution_date_gte, execution_date_lte: inclusive bounds on ``DagRun.execution_date`` (or None)
    :param start_date_gte, start_date_lte: inclusive bounds on ``DagRun.start_date`` (or None)
    :param limit: page size
    :param offset: number of rows to skip
    :return: tuple ``(dag_runs, total_entries)`` where ``total_entries`` counts
        rows matching the filters BEFORE pagination is applied.
    """
    query = _apply_date_filters_to_query(query, end_date_gte, end_date_lte, execution_date_gte,
                                         execution_date_lte, start_date_gte, start_date_lte)
    # Count after filtering but before offset/limit so clients can paginate
    # correctly. The previous implementation counted every DagRun in the table
    # (session.query(func.count(DagRun.id))), which made total_entries wrong
    # whenever any dag_id/date filter was in effect.
    total_entries = query.count()
    # apply offset and limit
    dag_run = query.order_by(DagRun.id).offset(offset).limit(limit).all()
    return dag_run, total_entries


def _apply_date_filters_to_query(query, end_date_gte, end_date_lte, execution_date_gte,
                                 execution_date_lte, start_date_gte, start_date_lte):
    """
    Narrow *query* with inclusive lower/upper bounds on the DagRun date columns.

    Each bound is applied only when it is truthy; a None (or otherwise falsy)
    bound leaves that side of the range open. Returns the filtered query.
    """
    # (column, lower bound, upper bound) — applied in this order: start date,
    # execution date, end date.
    date_bounds = (
        (DagRun.start_date, start_date_gte, start_date_lte),
        (DagRun.execution_date, execution_date_gte, execution_date_lte),
        (DagRun.end_date, end_date_gte, end_date_lte),
    )
    for column, lower, upper in date_bounds:
        if lower:
            query = query.filter(column >= lower)
        if upper:
            query = query.filter(column <= upper)
    return query

# apply offset and limit
dag_run = query.order_by(DagRun.id).offset(offset).limit(limit).all()
total_entries = session.query(func.count(DagRun.id)).scalar()

return dagrun_collection_schema.dump(
DAGRunCollection(dag_runs=dag_run, total_entries=total_entries)
)


def get_dag_runs_batch():
@provide_session
def get_dag_runs_batch(session):
    """
    Get a list of DAG Runs matching the dag_ids/date filters supplied in the
    JSON request body, serialized as a DAGRunCollection.

    :param session: SQLAlchemy session injected by ``@provide_session``
    :raises BadRequest: when the request payload fails schema validation
    """
    body = request.get_json()
    try:
        data = dagruns_batch_form_schema.load(body)
    except ValidationError as err:
        # Chain the original marshmallow error so it stays visible in
        # tracebacks/logs instead of being masked by the BadRequest.
        raise BadRequest(detail=str(err.messages)) from err

    query = session.query(DagRun)

    # The schema defaults dag_ids to None, so a missing or empty list means
    # "all dags"; only an explicit non-empty list narrows the query.
    if data["dag_ids"]:
        query = query.filter(DagRun.dag_id.in_(data["dag_ids"]))

    dag_runs, total_entries = _fetch_dag_runs(
        query,
        session,
        data["end_date_gte"],
        data["end_date_lte"],
        data["execution_date_gte"],
        data["execution_date_lte"],
        data["start_date_gte"],
        data["start_date_lte"],
        data["page_limit"],
        data["page_offset"],
    )

    return dagrun_collection_schema.dump(
        DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries)
    )


@provide_session
Expand All @@ -134,9 +161,8 @@ def post_dag_run(dag_id, session):

post_body = dagrun_schema.load(request.json, session=session)
dagrun_instance = (
session.query(DagRun)
.filter(and_(DagRun.dag_id == dag_id, DagRun.run_id == post_body["run_id"]))
.first()
session.query(DagRun).filter(
and_(DagRun.dag_id == dag_id, DagRun.run_id == post_body["run_id"])).first()
)
if not dagrun_instance:
dag_run = DagRun(dag_id=dag_id, run_type=DagRunType.MANUAL.value, **post_body)
Expand Down
20 changes: 20 additions & 0 deletions airflow/api_connexion/schemas/dag_run_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,25 @@ class DAGRunCollectionSchema(Schema):
total_entries = fields.Int()


class DagRunsBatchFormSchema(Schema):
    """Schema to validate and deserialize the form (request payload) submitted to the DagRun batch endpoint."""

    class Meta:
        """Meta"""
        datetimeformat = 'iso'
        strict = True

    # NOTE(review): ``min``/``max`` are not recognised keyword arguments of
    # marshmallow fields — the previous ``min=0`` / ``min=1`` were silently
    # treated as metadata, so negative offsets and non-positive limits passed
    # validation. Enforce the bounds with explicit validators instead.
    page_offset = fields.Int(missing=0, validate=lambda value: value >= 0)
    page_limit = fields.Int(missing=100, validate=lambda value: value >= 1)
    # None means "no filter" for each of the optional keys below.
    dag_ids = fields.List(fields.Str(), missing=None)
    execution_date_gte = fields.DateTime(missing=None)
    execution_date_lte = fields.DateTime(missing=None)
    start_date_gte = fields.DateTime(missing=None)
    start_date_lte = fields.DateTime(missing=None)
    end_date_gte = fields.DateTime(missing=None)
    end_date_lte = fields.DateTime(missing=None)


# Module-level singleton schema instances; endpoint modules import these
# rather than instantiating a schema per request.
dagrun_schema = DAGRunSchema()
dagrun_collection_schema = DAGRunCollectionSchema()
dagruns_batch_form_schema = DagRunsBatchFormSchema()
Loading

0 comments on commit 5ddbbf1

Please sign in to comment.