Support for Airflow Datasets #171

Merged 5 commits on Jul 20, 2023
48 changes: 47 additions & 1 deletion dagfactory/dagbuilder.py
@@ -112,6 +112,12 @@
    XComArg = None
# pylint: disable=ungrouped-imports,invalid-name

# Datasets were introduced in Airflow 2.4.0; fall back to None on older versions
if version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0"):
    from airflow.datasets import Dataset
else:
    Dataset = None


# these are params only used in the DAG factory, not in the tasks
SYSTEM_PARAMS: List[str] = ["operator", "dependencies", "task_group_name"]

@@ -588,6 +594,25 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator:
        ):
            task_params.update(partial_kwargs)

        # Airflow >= 2.4: resolve `outlets` into Dataset objects, either from a
        # {file, datasets} mapping against a datasets YAML or from a list of URIs
        if utils.check_dict_key(task_params, "outlets") and version.parse(
            AIRFLOW_VERSION
        ) >= version.parse("2.4.0"):
            if utils.check_dict_key(
                task_params["outlets"], "file"
            ) and utils.check_dict_key(task_params["outlets"], "datasets"):
                file = task_params["outlets"]["file"]
                datasets_filter = task_params["outlets"]["datasets"]
                datasets_uri = utils.get_datasets_uri_yaml_file(
                    file, datasets_filter
                )

                del task_params["outlets"]["file"]
                del task_params["outlets"]["datasets"]
            else:
                datasets_uri = task_params["outlets"]

            task_params["outlets"] = [Dataset(uri) for uri in datasets_uri]

        task: Union[BaseOperator, MappedOperator] = (
            operator_obj(**task_params)
            if not expand_kwargs
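
To make the new branch concrete, here is a minimal sketch (not part of the diff; it assumes Airflow >= 2.4 and the example datasets YAML shipped with this PR) of what make_task does to an `outlets` config:

    # Sketch: how a {file, datasets} outlets config becomes Dataset outlets.
    from airflow.datasets import Dataset
    from dagfactory import utils

    task_params = {
        "outlets": {
            "file": "examples/datasets/example_config_datasets.yml",
            "datasets": ["dataset_custom_1", "dataset_custom_2"],
        }
    }
    # Resolve the named datasets to URIs, then replace the config with Dataset objects
    datasets_uri = utils.get_datasets_uri_yaml_file(
        task_params["outlets"]["file"], task_params["outlets"]["datasets"]
    )
    task_params["outlets"] = [Dataset(uri) for uri in datasets_uri]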
@@ -696,7 +721,9 @@ def build(self) -> Dict[str, Union[str, DAG]]:

        dag_kwargs["dag_id"] = dag_params["dag_id"]

        if not dag_params.get("timetable") and not utils.check_dict_key(
            dag_params, "schedule"
        ):
            dag_kwargs["schedule_interval"] = dag_params.get(
                "schedule_interval", timedelta(days=1)
            )
@@ -765,6 +792,25 @@ def build(self) -> Dict[str, Union[str, DAG]]:
"is_paused_upon_creation", None
)

if (
utils.check_dict_key(dag_params, "schedule")
and not utils.check_dict_key(dag_params, "schedule_interval")
and version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0")
):
if utils.check_dict_key(
dag_params["schedule"], "file"
) and utils.check_dict_key(dag_params["schedule"], "datasets"):
file = dag_params["schedule"]["file"]
datasets_filter = dag_params["schedule"]["datasets"]
datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter)

del dag_params["schedule"]["file"]
del dag_params["schedule"]["datasets"]
else:
datasets_uri = dag_params["schedule"]

dag_kwargs["schedule"] = [Dataset(uri) for uri in datasets_uri]

dag_kwargs["params"] = dag_params.get("params", None)

dag: DAG = DAG(**dag_kwargs)
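
For orientation, a dataset-driven `schedule` from a YAML config resolves to a plain-Airflow DAG like the following sketch (assumes Airflow >= 2.4; the dag_id and URIs are taken from the examples further down in this PR):

    # Sketch: plain-Airflow equivalent of a dataset-scheduled dag-factory config.
    from datetime import datetime

    from airflow import DAG
    from airflow.datasets import Dataset

    dag = DAG(
        dag_id="example_simple_dataset_consumer_dag",
        start_date=datetime(2023, 7, 14),
        # The DAG runs whenever all of these datasets receive an update
        schedule=[
            Dataset("s3://bucket_example/raw/dataset1.json"),
            Dataset("s3://bucket_example/raw/dataset2.json"),
        ],
    )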
29 changes: 29 additions & 0 deletions dagfactory/utils.py
@@ -8,6 +8,7 @@
from datetime import date, datetime, timedelta
from pathlib import Path
from typing import Any, AnyStr, Dict, Match, Optional, Pattern, Union, List, Tuple
import yaml

import pendulum

@@ -263,3 +264,31 @@ def is_partial_duplicated(
"Duplicated partial kwarg! It's already in task_params."
)
return False


def get_datasets_uri_yaml_file(file_path: str, datasets_filter: List[str]) -> List[str]:
    """
    Retrieves the URIs of datasets from a YAML file based on a given filter.

    :param file_path: The path to the YAML file.
    :type file_path: str
    :param datasets_filter: A list of dataset names to filter the results.
    :type datasets_filter: List[str]
    :return: A list of dataset URIs that match the filter.
    :rtype: List[str]
    """
    try:
        with open(file_path, "r", encoding="UTF-8") as file:
            data = yaml.safe_load(file)

        datasets = data.get("datasets", [])
        datasets_result_uri = [
            dataset["uri"]
            for dataset in datasets
            if dataset["name"] in datasets_filter and "uri" in dataset
        ]
        return datasets_result_uri
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Error: File '{file_path}' not found.") from err
    except yaml.YAMLError as error:
        raise error
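
A quick usage sketch for the helper (it mirrors the test added below; the path and names come from the example config in this PR):

    # Sketch: resolve dataset names to URIs from a datasets YAML file.
    from dagfactory import utils

    uris = utils.get_datasets_uri_yaml_file(
        "examples/datasets/example_config_datasets.yml",
        ["dataset_custom_1", "dataset_custom_2"],
    )
    # uris == ["s3://bucket-cjmm/raw/dataset_custom_1",
    #          "s3://bucket-cjmm/raw/dataset_custom_2"]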
7 changes: 7 additions & 0 deletions examples/datasets/example_config_datasets.yml
@@ -0,0 +1,7 @@
datasets:
  - name: dataset_custom_1
    uri: s3://bucket-cjmm/raw/dataset_custom_1
  - name: dataset_custom_2
    uri: s3://bucket-cjmm/raw/dataset_custom_2
  - name: dataset_custom_3
    uri: s3://bucket-cjmm/raw/dataset_custom_3
10 changes: 10 additions & 0 deletions examples/datasets/example_dag_datasets.py
@@ -0,0 +1,10 @@
from airflow import DAG
import dagfactory


config_file = "/usr/local/airflow/dags/datasets/example_dag_datasets.yml"
example_dag_factory = dagfactory.DagFactory(config_file)

# Clean up stale DAGs, then generate the DAGs defined in the YAML config
example_dag_factory.clean_dags(globals())
example_dag_factory.generate_dags(globals())
54 changes: 54 additions & 0 deletions examples/datasets/example_dag_datasets.yml
@@ -0,0 +1,54 @@
default:
  default_args:
    owner: "default_owner"
    start_date: '2023-07-14'
    retries: 1
    retry_delay_sec: 300
  concurrency: 1
  max_active_runs: 1
  dagrun_timeout_sec: 600
  default_view: "tree"
  orientation: "LR"

example_simple_dataset_producer_dag:
  description: "Example DAG producer simple datasets"
  schedule_interval: "0 5 * * *"
  tasks:
    task_1:
      operator: airflow.operators.bash_operator.BashOperator
      bash_command: "echo 1"
      outlets: ['s3://bucket_example/raw/dataset1.json']
    task_2:
      operator: airflow.operators.bash_operator.BashOperator
      bash_command: "echo 2"
      dependencies: [task_1]
      outlets: ['s3://bucket_example/raw/dataset2.json']

example_simple_dataset_consumer_dag:
  description: "Example DAG consumer simple datasets"
  schedule: ['s3://bucket_example/raw/dataset1.json', 's3://bucket_example/raw/dataset2.json']
  tasks:
    task_1:
      operator: airflow.operators.bash_operator.BashOperator
      bash_command: "echo 'consumer datasets'"

example_custom_config_dataset_producer_dag:
  description: "Example DAG producer custom config datasets"
  schedule_interval: "0 5 * * *"
  tasks:
    task_1:
      operator: airflow.operators.bash_operator.BashOperator
      bash_command: "echo 1"
      outlets:
        file: /usr/local/airflow/dags/datasets/example_config_datasets.yml
        datasets: ['dataset_custom_1', 'dataset_custom_2']

example_custom_config_dataset_consumer_dag:
  description: "Example DAG consumer custom config datasets"
  schedule:
    file: /usr/local/airflow/dags/datasets/example_config_datasets.yml
    datasets: ['dataset_custom_1', 'dataset_custom_2']
  tasks:
    task_1:
      operator: airflow.operators.bash_operator.BashOperator
      bash_command: "echo 'consumer datasets'"
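
For comparison, the simple producer's task_1 corresponds to roughly this plain-Airflow code (a sketch assuming Airflow >= 2.4; it uses the modern airflow.operators.bash import path rather than the deprecated one in the YAML):

    # Sketch: plain-Airflow equivalent of task_1 in the simple producer DAG.
    from datetime import datetime

    from airflow import DAG
    from airflow.datasets import Dataset
    from airflow.operators.bash import BashOperator

    with DAG(
        dag_id="example_simple_dataset_producer_dag",
        start_date=datetime(2023, 7, 14),
        schedule_interval="0 5 * * *",
    ) as dag:
        task_1 = BashOperator(
            task_id="task_1",
            bash_command="echo 1",
            # Marks the dataset as updated whenever this task succeeds
            outlets=[Dataset("s3://bucket_example/raw/dataset1.json")],
        )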
17 changes: 17 additions & 0 deletions tests/test_utils.py
@@ -248,3 +248,20 @@ def test_is_partial_duplicated():
        utils.is_partial_duplicated(partial_kwargs, task_params)
    except Exception as e:
        assert str(e) == "Duplicated partial kwarg! It's already in task_params."

def test_open_and_filter_yaml_config_datasets():
    datasets_names = ['dataset_custom_1', 'dataset_custom_2']
    file_path = 'examples/datasets/example_config_datasets.yml'

    actual = utils.get_datasets_uri_yaml_file(file_path, datasets_names)
    expected = ['s3://bucket-cjmm/raw/dataset_custom_1', 's3://bucket-cjmm/raw/dataset_custom_2']

    assert actual == expected


def test_open_and_filter_yaml_config_datasets_file_notfound():
    datasets_names = ['dataset_custom_1', 'dataset_custom_2']
    file_path = 'examples/datasets/not_found_example_config_datasets.yml'

    with pytest.raises(Exception):
        utils.get_datasets_uri_yaml_file(file_path, datasets_names)
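
One more case that may be worth covering (a hypothetical test, not in the diff): names absent from the config are silently dropped by the helper's list-comprehension filter rather than raising.

    def test_open_and_filter_yaml_config_datasets_unknown_name():
        # Hypothetical test: unknown dataset names are filtered out, not errors.
        uris = utils.get_datasets_uri_yaml_file(
            'examples/datasets/example_config_datasets.yml',
            ['dataset_custom_1', 'no_such_dataset'],
        )
        assert uris == ['s3://bucket-cjmm/raw/dataset_custom_1']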