From 7a9b42016441b192af0a4de8935b615c80506048 Mon Sep 17 00:00:00 2001 From: Pankaj Singh <98807258+pankajastro@users.noreply.github.com> Date: Wed, 9 Oct 2024 18:30:32 +0530 Subject: [PATCH] Add static check (#231) * Add static check * Update .github/workflows/test.yml --- .github/workflows/main.yml | 6 +- .github/workflows/test.yml | 35 ++++ .pre-commit-config.yaml | 31 ++++ Makefile | 16 +- dagfactory/__init__.py | 5 + dagfactory/dagbuilder.py | 335 ++++++++++++------------------------- dagfactory/dagfactory.py | 33 +--- dagfactory/utils.py | 38 ++--- pyproject.toml | 27 +++ 9 files changed, 223 insertions(+), 303 deletions(-) create mode 100644 .github/workflows/test.yml create mode 100644 .pre-commit-config.yaml diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index ad8f2a42..2bccebb2 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -25,13 +25,9 @@ jobs: venv/bin/python3 -m pip install tox-gh-actions env: SLUGIFY_USES_TEXT_UNIDECODE: yes - - name: Check formatting with black - run: make fmt-check - - name: Lint with pylint - run: make lint - name: Test with tox run: make test - - name: Upload coverage to Codecov + - name: Upload coverage to Codecov uses: codecov/codecov-action@v3 with: token: ${{secrets.CODECOV_TOKEN}} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..1ced9fe5 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,35 @@ +name: test + +on: + push: # Run on pushes to the default branch + branches: [main] + pull_request_target: # Also run on pull requests originated from forks + branches: [main] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + Authorize: + environment: ${{ github.event_name == 'pull_request_target' && + github.event.pull_request.head.repo.full_name != github.repository && + 'external' || 'internal' }} + runs-on: ubuntu-latest + steps: + - run: true + + Static-Check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + ref: ${{ github.event.pull_request.head.sha || github.ref }} + + - uses: actions/setup-python@v3 + with: + python-version: "3.12" + architecture: "x64" + + - run: pip3 install hatch + - run: hatch run tests.py3.12-2.10:static-check diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..7c8b3ecc --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.9 + hooks: + - id: ruff + args: + - --fix + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: check-added-large-files + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + args: + - --unsafe + - id: debug-statements + - id: end-of-file-fixer + - id: mixed-line-ending + - id: pretty-format-json + args: [ "--autofix" ] + - id: trailing-whitespace + + - repo: https://github.com/psf/black + rev: 24.10.0 + hooks: + - id: black + args: [ "--config", "./pyproject.toml" ] diff --git a/Makefile b/Makefile index a55c3d75..6a7106b2 100644 --- a/Makefile +++ b/Makefile @@ -27,20 +27,6 @@ clean: ## Removes build and test artifacts @find . -name '*~' -exec rm -f {} + @find . 
-name '__pycache__' -exec rm -rf {} + -.PHONY: fmt -fmt: venv ## Formats all files with black - @echo "==> Formatting with Black" - @${PYTHON} -m black dagfactory - -.PHONY: fmt-check -fmt-check: venv ## Checks files were formatted with black - @echo "==> Formatting with Black" - @${PYTHON} -m black --check dagfactory - -.PHONY: lint -lint: venv ## Lint code with pylint - @${PYTHON} -m pylint dagfactory - .PHONY: test test: venv ## Runs unit tests @${PYTHON} -m tox @@ -57,4 +43,4 @@ docker-run: docker-build ## Runs local Airflow for testing .PHONY: docker-stop docker-stop: ## Stop Docker container - @docker stop dag_factory; docker rm dag_factory \ No newline at end of file + @docker stop dag_factory; docker rm dag_factory diff --git a/dagfactory/__init__.py b/dagfactory/__init__.py index 4803e708..4f14b38e 100644 --- a/dagfactory/__init__.py +++ b/dagfactory/__init__.py @@ -1,3 +1,8 @@ """Modules and methods to export for easier access""" from .dagfactory import DagFactory, load_yaml_dags + +__all__ = [ + "DagFactory", + "load_yaml_dags", +] diff --git a/dagfactory/dagbuilder.py b/dagfactory/dagbuilder.py index 689b3e8d..ef7babec 100644 --- a/dagfactory/dagbuilder.py +++ b/dagfactory/dagbuilder.py @@ -55,9 +55,9 @@ try: if version.parse(K8S_PROVIDER_VERSION) < version.parse("5.0.0"): from airflow.kubernetes.pod import Port - from airflow.kubernetes.volume_mount import VolumeMount - from airflow.kubernetes.volume import Volume from airflow.kubernetes.pod_runtime_info_env import PodRuntimeInfoEnv + from airflow.kubernetes.volume import Volume + from airflow.kubernetes.volume_mount import VolumeMount else: from kubernetes.client.models import V1ContainerPort as Port from kubernetes.client.models import ( @@ -68,21 +68,19 @@ ) from kubernetes.client.models import V1VolumeMount as VolumeMount from airflow.kubernetes.secret import Secret - from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import ( - KubernetesPodOperator, - ) + from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator except ImportError: - from airflow.contrib.kubernetes.secret import Secret from airflow.contrib.kubernetes.pod import Port - from airflow.contrib.kubernetes.volume_mount import VolumeMount - from airflow.contrib.kubernetes.volume import Volume from airflow.contrib.kubernetes.pod_runtime_info_env import PodRuntimeInfoEnv + from airflow.contrib.kubernetes.secret import Secret + from airflow.contrib.kubernetes.volume import Volume + from airflow.contrib.kubernetes.volume_mount import VolumeMount from airflow.contrib.operators.kubernetes_pod_operator import KubernetesPodOperator from kubernetes.client.models import V1Container, V1Pod -from dagfactory.exceptions import DagFactoryException, DagFactoryConfigException from dagfactory import utils +from dagfactory.exceptions import DagFactoryConfigException, DagFactoryException # pylint: disable=ungrouped-imports,invalid-name # Disabling pylint's ungrouped-imports warning because this is a @@ -133,9 +131,7 @@ class DagBuilder: in the YAML file """ - def __init__( - self, dag_name: str, dag_config: Dict[str, Any], default_config: Dict[str, Any] - ) -> None: + def __init__(self, dag_name: str, dag_config: Dict[str, Any], default_config: Dict[str, Any]) -> None: self.dag_name: str = dag_name self.dag_config: Dict[str, Any] = deepcopy(dag_config) self.default_config: Dict[str, Any] = deepcopy(default_config) @@ -148,33 +144,20 @@ def get_dag_params(self) -> Dict[str, Any]: :returns: dict of dag parameters """ try: - dag_params: 
Dict[str, Any] = utils.merge_configs( - self.dag_config, self.default_config - ) + dag_params: Dict[str, Any] = utils.merge_configs(self.dag_config, self.default_config) except Exception as err: - raise DagFactoryConfigException( - "Failed to merge config with default config" - ) from err + raise DagFactoryConfigException("Failed to merge config with default config") from err dag_params["dag_id"]: str = self.dag_name - if dag_params.get("task_groups") and version.parse( - AIRFLOW_VERSION - ) < version.parse("2.0.0"): - raise DagFactoryConfigException( - "`task_groups` key can only be used with Airflow 2.x.x" - ) + if dag_params.get("task_groups") and version.parse(AIRFLOW_VERSION) < version.parse("2.0.0"): + raise DagFactoryConfigException("`task_groups` key can only be used with Airflow 2.x.x") - if ( - utils.check_dict_key(dag_params, "schedule_interval") - and dag_params["schedule_interval"] == "None" - ): + if utils.check_dict_key(dag_params, "schedule_interval") and dag_params["schedule_interval"] == "None": dag_params["schedule_interval"] = None # Convert from 'dagrun_timeout_sec: int' to 'dagrun_timeout: timedelta' if utils.check_dict_key(dag_params, "dagrun_timeout_sec"): - dag_params["dagrun_timeout"]: timedelta = timedelta( - seconds=dag_params["dagrun_timeout_sec"] - ) + dag_params["dagrun_timeout"]: timedelta = timedelta(seconds=dag_params["dagrun_timeout_sec"]) del dag_params["dagrun_timeout_sec"] # Convert from 'end_date: Union[str, datetime, date]' to 'end_date: datetime' @@ -191,88 +174,74 @@ def get_dag_params(self) -> Dict[str, Any]: del dag_params["default_args"]["retry_delay_sec"] if utils.check_dict_key(dag_params["default_args"], "sla_secs"): - dag_params["default_args"]["sla"]: timedelta = timedelta( - seconds=dag_params["default_args"]["sla_secs"] - ) + dag_params["default_args"]["sla"]: timedelta = timedelta(seconds=dag_params["default_args"]["sla_secs"]) del dag_params["default_args"]["sla_secs"] if utils.check_dict_key(dag_params["default_args"], "sla_miss_callback"): if isinstance(dag_params["default_args"]["sla_miss_callback"], str): - dag_params["default_args"]["sla_miss_callback"]: Callable = ( - import_string(dag_params["default_args"]["sla_miss_callback"]) + dag_params["default_args"]["sla_miss_callback"]: Callable = import_string( + dag_params["default_args"]["sla_miss_callback"] ) if utils.check_dict_key(dag_params["default_args"], "on_success_callback"): if isinstance(dag_params["default_args"]["on_success_callback"], str): - dag_params["default_args"]["on_success_callback"]: Callable = ( - import_string(dag_params["default_args"]["on_success_callback"]) + dag_params["default_args"]["on_success_callback"]: Callable = import_string( + dag_params["default_args"]["on_success_callback"] ) if utils.check_dict_key(dag_params["default_args"], "on_failure_callback"): if isinstance(dag_params["default_args"]["on_failure_callback"], str): - dag_params["default_args"]["on_failure_callback"]: Callable = ( - import_string(dag_params["default_args"]["on_failure_callback"]) + dag_params["default_args"]["on_failure_callback"]: Callable = import_string( + dag_params["default_args"]["on_failure_callback"] ) if utils.check_dict_key(dag_params["default_args"], "on_retry_callback"): if isinstance(dag_params["default_args"]["on_retry_callback"], str): - dag_params["default_args"]["on_retry_callback"]: Callable = ( - import_string(dag_params["default_args"]["on_retry_callback"]) + dag_params["default_args"]["on_retry_callback"]: Callable = import_string( + 
dag_params["default_args"]["on_retry_callback"] ) if utils.check_dict_key(dag_params, "sla_miss_callback"): if isinstance(dag_params["sla_miss_callback"], str): - dag_params["sla_miss_callback"]: Callable = import_string( - dag_params["sla_miss_callback"] - ) + dag_params["sla_miss_callback"]: Callable = import_string(dag_params["sla_miss_callback"]) if utils.check_dict_key(dag_params, "on_success_callback"): if isinstance(dag_params["on_success_callback"], str): - dag_params["on_success_callback"]: Callable = import_string( - dag_params["on_success_callback"] - ) + dag_params["on_success_callback"]: Callable = import_string(dag_params["on_success_callback"]) if utils.check_dict_key(dag_params, "on_failure_callback"): if isinstance(dag_params["on_failure_callback"], str): - dag_params["on_failure_callback"]: Callable = import_string( - dag_params["on_failure_callback"] - ) + dag_params["on_failure_callback"]: Callable = import_string(dag_params["on_failure_callback"]) - if utils.check_dict_key( - dag_params, "on_success_callback_name" - ) and utils.check_dict_key(dag_params, "on_success_callback_file"): + if utils.check_dict_key(dag_params, "on_success_callback_name") and utils.check_dict_key( + dag_params, "on_success_callback_file" + ): dag_params["on_success_callback"]: Callable = utils.get_python_callable( dag_params["on_success_callback_name"], dag_params["on_success_callback_file"], ) - if utils.check_dict_key( - dag_params, "on_failure_callback_name" - ) and utils.check_dict_key(dag_params, "on_failure_callback_file"): + if utils.check_dict_key(dag_params, "on_failure_callback_name") and utils.check_dict_key( + dag_params, "on_failure_callback_file" + ): dag_params["on_failure_callback"]: Callable = utils.get_python_callable( dag_params["on_failure_callback_name"], dag_params["on_failure_callback_file"], ) if utils.check_dict_key(dag_params, "template_searchpath"): - if isinstance( - dag_params["template_searchpath"], (list, str) - ) and utils.check_template_searchpath(dag_params["template_searchpath"]): - dag_params["template_searchpath"]: Union[str, List[str]] = dag_params[ - "template_searchpath" - ] + if isinstance(dag_params["template_searchpath"], (list, str)) and utils.check_template_searchpath( + dag_params["template_searchpath"] + ): + dag_params["template_searchpath"]: Union[str, List[str]] = dag_params["template_searchpath"] else: raise DagFactoryException("template_searchpath is not valid!") if utils.check_dict_key(dag_params, "render_template_as_native_obj"): if isinstance(dag_params["render_template_as_native_obj"], bool): - dag_params["render_template_as_native_obj"]: bool = dag_params[ - "render_template_as_native_obj" - ] + dag_params["render_template_as_native_obj"]: bool = dag_params["render_template_as_native_obj"] else: - raise DagFactoryException( - "render_template_as_native_obj should be bool type!" 
- ) + raise DagFactoryException("render_template_as_native_obj should be bool type!") try: # ensure that default_args dictionary contains key "start_date" @@ -283,9 +252,7 @@ def get_dag_params(self) -> Dict[str, Any]: ) except KeyError as err: # pylint: disable=line-too-long - raise DagFactoryConfigException( - f"{self.dag_name} config is missing start_date" - ) from err + raise DagFactoryConfigException(f"{self.dag_name} config is missing start_date") from err return dag_params @staticmethod @@ -299,15 +266,11 @@ def make_timetable(timetable: str, timetable_params: Dict[str, Any]) -> Timetabl # class is a Callable https://stackoverflow.com/a/34578836/3679900 timetable_obj: Callable[..., Timetable] = import_string(timetable) except Exception as err: - raise DagFactoryException( - f"Failed to import timetable {timetable} due to: {err}" - ) from err + raise DagFactoryException(f"Failed to import timetable {timetable} due to: {err}") from err try: schedule: Timetable = timetable_obj(**timetable_params) except Exception as err: - raise DagFactoryException( - f"Failed to create {timetable_obj} due to: {err}" - ) from err + raise DagFactoryException(f"Failed to create {timetable_obj} due to: {err}") from err return schedule # pylint: disable=too-many-branches @@ -327,9 +290,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: raise DagFactoryException(f"Failed to import operator: {operator}") from err # pylint: disable=too-many-nested-blocks try: - if issubclass( - operator_obj, (PythonOperator, BranchPythonOperator, PythonSensor) - ): + if issubclass(operator_obj, (PythonOperator, BranchPythonOperator, PythonSensor)): if ( not task_params.get("python_callable") and not task_params.get("python_callable_name") @@ -344,11 +305,9 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: " python_callable_file: !!python/name:my_module.my_func" ) if not task_params.get("python_callable"): - task_params["python_callable"]: Callable = ( - utils.get_python_callable( - task_params["python_callable_name"], - task_params["python_callable_file"], - ) + task_params["python_callable"]: Callable = utils.get_python_callable( + task_params["python_callable_name"], + task_params["python_callable_file"], ) # remove dag-factory specific parameters # Airflow 2.0 doesn't allow these to be passed to operator @@ -362,9 +321,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: # take precedence over the lambda function if issubclass(operator_obj, SqlSensor): # Success checks - if task_params.get("success_check_file") and task_params.get( - "success_check_name" - ): + if task_params.get("success_check_file") and task_params.get("success_check_name"): task_params["success"]: Callable = utils.get_python_callable( task_params["success_check_name"], task_params["success_check_file"], @@ -377,9 +334,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: ) del task_params["success_check_lambda"] # Failure checks - if task_params.get("failure_check_file") and task_params.get( - "failure_check_name" - ): + if task_params.get("failure_check_file") and task_params.get("failure_check_name"): task_params["failure"]: Callable = utils.get_python_callable( task_params["failure_check_name"], task_params["failure_check_file"], @@ -394,8 +349,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: if issubclass(operator_obj, HttpSensor): if not ( - task_params.get("response_check_name") - and 
task_params.get("response_check_file") + task_params.get("response_check_name") and task_params.get("response_check_file") ) and not task_params.get("response_check_lambda"): raise DagFactoryException( "Failed to create task. HttpSensor requires \ @@ -412,10 +366,8 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: del task_params["response_check_name"] del task_params["response_check_file"] else: - task_params["response_check"]: Callable = ( - utils.get_python_callable_lambda( - task_params["response_check_lambda"] - ) + task_params["response_check"]: Callable = utils.get_python_callable_lambda( + task_params["response_check_lambda"] ) # remove dag-factory specific parameters # Airflow 2.0 doesn't allow these to be passed to operator @@ -430,9 +382,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: ) task_params["ports"] = ( - [Port(**v) for v in task_params.get("ports")] - if task_params.get("ports") is not None - else None + [Port(**v) for v in task_params.get("ports")] if task_params.get("ports") is not None else None ) task_params["volume_mounts"] = ( [VolumeMount(**v) for v in task_params.get("volume_mounts")] @@ -446,10 +396,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: else None ) task_params["pod_runtime_info_envs"] = ( - [ - PodRuntimeInfoEnv(**v) - for v in task_params.get("pod_runtime_info_envs") - ] + [PodRuntimeInfoEnv(**v) for v in task_params.get("pod_runtime_info_envs")] if task_params.get("pod_runtime_info_envs") is not None else None ) @@ -477,9 +424,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: V1EnvVar( name=v.get("name"), value_from=V1EnvVarSource( - field_ref=V1ObjectFieldSelector( - field_path=v.get("field_path") - ) + field_ref=V1ObjectFieldSelector(field_path=v.get("field_path")) ), ) for v in task_params.get("pod_runtime_info_envs") @@ -488,9 +433,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: else None ) task_params["full_pod_spec"] = ( - V1Pod(**task_params.get("full_pod_spec")) - if task_params.get("full_pod_spec") is not None - else None + V1Pod(**task_params.get("full_pod_spec")) if task_params.get("full_pod_spec") is not None else None ) task_params["init_containers"] = ( [V1Container(**v) for v in task_params.get("init_containers")] @@ -498,26 +441,20 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: else None ) if utils.check_dict_key(task_params, "execution_timeout_secs"): - task_params["execution_timeout"]: timedelta = timedelta( - seconds=task_params["execution_timeout_secs"] - ) + task_params["execution_timeout"]: timedelta = timedelta(seconds=task_params["execution_timeout_secs"]) del task_params["execution_timeout_secs"] if utils.check_dict_key(task_params, "sla_secs"): - task_params["sla"]: timedelta = timedelta( - seconds=task_params["sla_secs"] - ) + task_params["sla"]: timedelta = timedelta(seconds=task_params["sla_secs"]) del task_params["sla_secs"] if utils.check_dict_key(task_params, "execution_delta_secs"): - task_params["execution_delta"]: timedelta = timedelta( - seconds=task_params["execution_delta_secs"] - ) + task_params["execution_delta"]: timedelta = timedelta(seconds=task_params["execution_delta_secs"]) del task_params["execution_delta_secs"] - if utils.check_dict_key( - task_params, "execution_date_fn_name" - ) and utils.check_dict_key(task_params, "execution_date_fn_file"): + if utils.check_dict_key(task_params, "execution_date_fn_name") and 
utils.check_dict_key( + task_params, "execution_date_fn_file" + ): task_params["execution_date_fn"]: Callable = utils.get_python_callable( task_params["execution_date_fn_name"], task_params["execution_date_fn_file"], @@ -526,53 +463,37 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: del task_params["execution_date_fn_file"] # on_execute_callback is an Airflow 2.0 feature - if utils.check_dict_key( - task_params, "on_execute_callback" - ) and version.parse(AIRFLOW_VERSION) >= version.parse("2.0.0"): - task_params["on_execute_callback"]: Callable = import_string( - task_params["on_execute_callback"] - ) + if utils.check_dict_key(task_params, "on_execute_callback") and version.parse( + AIRFLOW_VERSION + ) >= version.parse("2.0.0"): + task_params["on_execute_callback"]: Callable = import_string(task_params["on_execute_callback"]) if utils.check_dict_key(task_params, "on_failure_callback"): - task_params["on_failure_callback"]: Callable = import_string( - task_params["on_failure_callback"] - ) + task_params["on_failure_callback"]: Callable = import_string(task_params["on_failure_callback"]) if utils.check_dict_key(task_params, "on_success_callback"): - task_params["on_success_callback"]: Callable = import_string( - task_params["on_success_callback"] - ) + task_params["on_success_callback"]: Callable = import_string(task_params["on_success_callback"]) if utils.check_dict_key(task_params, "on_retry_callback"): - task_params["on_retry_callback"]: Callable = import_string( - task_params["on_retry_callback"] - ) + task_params["on_retry_callback"]: Callable = import_string(task_params["on_retry_callback"]) # use variables as arguments on operator if utils.check_dict_key(task_params, "variables_as_arguments"): - variables: List[Dict[str, str]] = task_params.get( - "variables_as_arguments" - ) + variables: List[Dict[str, str]] = task_params.get("variables_as_arguments") for variable in variables: if Variable.get(variable["variable"], default_var=None) is not None: - task_params[variable["attribute"]] = Variable.get( - variable["variable"], default_var=None - ) + task_params[variable["attribute"]] = Variable.get(variable["variable"], default_var=None) del task_params["variables_as_arguments"] if ( - utils.check_dict_key(task_params, "expand") - or utils.check_dict_key(task_params, "partial") + utils.check_dict_key(task_params, "expand") or utils.check_dict_key(task_params, "partial") ) and version.parse(AIRFLOW_VERSION) < version.parse("2.3.0"): - raise DagFactoryConfigException( - "Dynamic task mapping available only in Airflow >= 2.3.0" - ) + raise DagFactoryConfigException("Dynamic task mapping available only in Airflow >= 2.3.0") expand_kwargs: Dict[str, Union[Dict[str, Any], Any]] = {} # expand available only in airflow >= 2.3.0 if ( - utils.check_dict_key(task_params, "expand") - or utils.check_dict_key(task_params, "partial") + utils.check_dict_key(task_params, "expand") or utils.check_dict_key(task_params, "partial") ) and version.parse(AIRFLOW_VERSION) >= version.parse("2.3.0"): # Getting expand and partial kwargs from task_params ( @@ -582,22 +503,18 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: ) = utils.get_expand_partial_kwargs(task_params) # If there are partial_kwargs we should merge them with existing task_params - if partial_kwargs and not utils.is_partial_duplicated( - partial_kwargs, task_params - ): + if partial_kwargs and not utils.is_partial_duplicated(partial_kwargs, task_params): task_params.update(partial_kwargs) - if 
utils.check_dict_key(task_params, "outlets") and version.parse( - AIRFLOW_VERSION - ) >= version.parse("2.4.0"): - if utils.check_dict_key( - task_params["outlets"], "file" - ) and utils.check_dict_key(task_params["outlets"], "datasets"): + if utils.check_dict_key(task_params, "outlets") and version.parse(AIRFLOW_VERSION) >= version.parse( + "2.4.0" + ): + if utils.check_dict_key(task_params["outlets"], "file") and utils.check_dict_key( + task_params["outlets"], "datasets" + ): file = task_params["outlets"]["file"] datasets_filter = task_params["outlets"]["datasets"] - datasets_uri = utils.get_datasets_uri_yaml_file( - file, datasets_filter - ) + datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter) del task_params["outlets"]["file"] del task_params["outlets"]["datasets"] @@ -616,9 +533,7 @@ def make_task(operator: str, task_params: Dict[str, Any]) -> BaseOperator: return task @staticmethod - def make_task_groups( - task_groups: Dict[str, Any], dag: DAG - ) -> Dict[str, "TaskGroup"]: + def make_task_groups(task_groups: Dict[str, Any], dag: DAG) -> Dict[str, "TaskGroup"]: """Takes a DAG and task group configurations. Creates TaskGroup instances. :param task_groups: Task group configuration from the YAML configuration file. @@ -629,13 +544,7 @@ def make_task_groups( for task_group_name, task_group_conf in task_groups.items(): task_group_conf["group_id"] = task_group_name task_group_conf["dag"] = dag - task_group = TaskGroup( - **{ - k: v - for k, v in task_group_conf.items() - if k not in SYSTEM_PARAMS - } - ) + task_group = TaskGroup(**{k: v for k, v in task_group_conf.items() if k not in SYSTEM_PARAMS}) task_groups_dict[task_group.group_id] = task_group return task_groups_dict @@ -662,18 +571,12 @@ def set_dependencies( group_id = conf["task_group"].group_id name = f"{group_id}.{name}" if conf.get("dependencies"): - source: Union[BaseOperator, "TaskGroup"] = ( - tasks_and_task_groups_instances[name] - ) + source: Union[BaseOperator, "TaskGroup"] = tasks_and_task_groups_instances[name] for dep in conf["dependencies"]: if tasks_and_task_groups_config[dep].get("task_group"): - group_id = tasks_and_task_groups_config[dep][ - "task_group" - ].group_id + group_id = tasks_and_task_groups_config[dep]["task_group"].group_id dep = f"{group_id}.{dep}" - dep: Union[BaseOperator, "TaskGroup"] = ( - tasks_and_task_groups_instances[dep] - ) + dep: Union[BaseOperator, "TaskGroup"] = tasks_and_task_groups_instances[dep] source.set_upstream(dep) @staticmethod @@ -714,12 +617,8 @@ def build(self) -> Dict[str, Union[str, DAG]]: dag_kwargs["dag_id"] = dag_params["dag_id"] - if not dag_params.get("timetable") and not utils.check_dict_key( - dag_params, "schedule" - ): - dag_kwargs["schedule_interval"] = dag_params.get( - "schedule_interval", timedelta(days=1) - ) + if not dag_params.get("timetable") and not utils.check_dict_key(dag_params, "schedule"): + dag_kwargs["schedule_interval"] = dag_params.get("schedule_interval", timedelta(days=1)) if version.parse(AIRFLOW_VERSION) >= version.parse("1.10.11"): dag_kwargs["description"] = dag_params.get("description", None) @@ -765,9 +664,7 @@ def build(self) -> Dict[str, Union[str, DAG]]: # Jinja NativeEnvironment support has been added in Airflow 2.1.0 if version.parse(AIRFLOW_VERSION) >= version.parse("2.1.0"): - dag_kwargs["render_template_as_native_obj"] = dag_params.get( - "render_template_as_native_obj", False - ) + dag_kwargs["render_template_as_native_obj"] = dag_params.get("render_template_as_native_obj", False) 
dag_kwargs["sla_miss_callback"] = dag_params.get("sla_miss_callback", None) @@ -781,18 +678,16 @@ def build(self) -> Dict[str, Union[str, DAG]]: dag_kwargs["access_control"] = dag_params.get("access_control", None) - dag_kwargs["is_paused_upon_creation"] = dag_params.get( - "is_paused_upon_creation", None - ) + dag_kwargs["is_paused_upon_creation"] = dag_params.get("is_paused_upon_creation", None) if ( utils.check_dict_key(dag_params, "schedule") and not utils.check_dict_key(dag_params, "schedule_interval") and version.parse(AIRFLOW_VERSION) >= version.parse("2.4.0") ): - if utils.check_dict_key( - dag_params["schedule"], "file" - ) and utils.check_dict_key(dag_params["schedule"], "datasets"): + if utils.check_dict_key(dag_params["schedule"], "file") and utils.check_dict_key( + dag_params["schedule"], "datasets" + ): file = dag_params["schedule"]["file"] datasets_filter = dag_params["schedule"]["datasets"] datasets_uri = utils.get_datasets_uri_yaml_file(file, datasets_filter) @@ -812,21 +707,15 @@ def build(self) -> Dict[str, Union[str, DAG]]: if not os.path.isabs(dag_params.get("doc_md_file_path")): raise DagFactoryException("`doc_md_file_path` must be absolute path") - with open( - dag_params.get("doc_md_file_path"), "r", encoding="utf-8" - ) as file: + with open(dag_params.get("doc_md_file_path"), "r", encoding="utf-8") as file: dag.doc_md = file.read() - if dag_params.get("doc_md_python_callable_file") and dag_params.get( - "doc_md_python_callable_name" - ): + if dag_params.get("doc_md_python_callable_file") and dag_params.get("doc_md_python_callable_name"): doc_md_callable = utils.get_python_callable( dag_params.get("doc_md_python_callable_name"), dag_params.get("doc_md_python_callable_file"), ) - dag.doc_md = doc_md_callable( - **dag_params.get("doc_md_python_arguments", {}) - ) + dag.doc_md = doc_md_callable(**dag_params.get("doc_md_python_arguments", {})) # tags parameter introduced in Airflow 1.10.8 if version.parse(AIRFLOW_VERSION) >= version.parse("1.10.8"): @@ -838,9 +727,7 @@ def build(self) -> Dict[str, Union[str, DAG]]: dag.is_dagfactory_auto_generated = True # create dictionary of task groups - task_groups_dict: Dict[str, "TaskGroup"] = self.make_task_groups( - dag_params.get("task_groups", {}), dag - ) + task_groups_dict: Dict[str, "TaskGroup"] = self.make_task_groups(dag_params.get("task_groups", {}), dag) # create dictionary to track tasks and set dependencies tasks_dict: Dict[str, BaseOperator] = {} @@ -850,32 +737,20 @@ def build(self) -> Dict[str, Union[str, DAG]]: task_conf["dag"]: DAG = dag # add task to task_group if task_groups_dict and task_conf.get("task_group_name"): - task_conf["task_group"] = task_groups_dict[ - task_conf.get("task_group_name") - ] + task_conf["task_group"] = task_groups_dict[task_conf.get("task_group_name")] # Dynamic task mapping available only in Airflow >= 2.3.0 - if (task_conf.get("expand") or task_conf.get("partial")) and version.parse( - AIRFLOW_VERSION - ) < version.parse("2.3.0"): - raise DagFactoryConfigException( - "Dynamic task mapping available only in Airflow >= 2.3.0" - ) + if (task_conf.get("expand") or task_conf.get("partial")) and version.parse(AIRFLOW_VERSION) < version.parse( + "2.3.0" + ): + raise DagFactoryConfigException("Dynamic task mapping available only in Airflow >= 2.3.0") # replace 'task_id.output' or 'XComArg(task_id)' with XComArg(task_instance) object - if task_conf.get("expand") and version.parse( - AIRFLOW_VERSION - ) >= version.parse("2.3.0"): + if task_conf.get("expand") and version.parse(AIRFLOW_VERSION) 
>= version.parse("2.3.0"): task_conf = self.replace_expand_values(task_conf, tasks_dict) - params: Dict[str, Any] = { - k: v for k, v in task_conf.items() if k not in SYSTEM_PARAMS - } - task: Union[BaseOperator, MappedOperator] = DagBuilder.make_task( - operator=operator, task_params=params - ) + params: Dict[str, Any] = {k: v for k, v in task_conf.items() if k not in SYSTEM_PARAMS} + task: Union[BaseOperator, MappedOperator] = DagBuilder.make_task(operator=operator, task_params=params) tasks_dict[task.task_id]: BaseOperator = task # set task dependencies after creating tasks - self.set_dependencies( - tasks, tasks_dict, dag_params.get("task_groups", {}), task_groups_dict - ) + self.set_dependencies(tasks, tasks_dict, dag_params.get("task_groups", {}), task_groups_dict) return {"dag_id": dag_params["dag_id"], "dag": dag} diff --git a/dagfactory/dagfactory.py b/dagfactory/dagfactory.py index 9833500b..40b3d53b 100644 --- a/dagfactory/dagfactory.py +++ b/dagfactory/dagfactory.py @@ -11,8 +11,7 @@ from airflow.models import DAG from dagfactory.dagbuilder import DagBuilder -from dagfactory.exceptions import DagFactoryException, DagFactoryConfigException - +from dagfactory.exceptions import DagFactoryConfigException, DagFactoryException # these are params that cannot be a dag name SYSTEM_PARAMS: List[str] = ["default", "task_groups"] @@ -29,17 +28,11 @@ class DagFactory: :type config: dict """ - def __init__( - self, config_filepath: Optional[str] = None, config: Optional[dict] = None - ) -> None: - assert bool(config_filepath) ^ bool( - config - ), "Either `config_filepath` or `config` should be provided" + def __init__(self, config_filepath: Optional[str] = None, config: Optional[dict] = None) -> None: + assert bool(config_filepath) ^ bool(config), "Either `config_filepath` or `config` should be provided" if config_filepath: DagFactory._validate_config_filepath(config_filepath=config_filepath) - self.config: Dict[str, Any] = DagFactory._load_config( - config_filepath=config_filepath - ) + self.config: Dict[str, Any] = DagFactory._load_config(config_filepath=config_filepath) if config: self.config: Dict[str, Any] = config @@ -49,9 +42,7 @@ def _validate_config_filepath(config_filepath: str) -> None: Validates config file path is absolute """ if not os.path.isabs(config_filepath): - raise DagFactoryConfigException( - "DAG Factory `config_filepath` must be absolute path" - ) + raise DagFactoryConfigException("DAG Factory `config_filepath` must be absolute path") @staticmethod def _load_config(config_filepath: str) -> Dict[str, Any]: @@ -83,11 +74,7 @@ def get_dag_configs(self) -> Dict[str, Dict[str, Any]]: :returns: dict with configuration for dags """ - return { - dag: self.config[dag] - for dag in self.config.keys() - if dag not in SYSTEM_PARAMS - } + return {dag: self.config[dag] for dag in self.config.keys() if dag not in SYSTEM_PARAMS} def get_default_config(self) -> Dict[str, Any]: """ @@ -115,9 +102,7 @@ def build_dags(self) -> Dict[str, DAG]: dag: Dict[str, Union[str, DAG]] = dag_builder.build() dags[dag["dag_id"]]: DAG = dag["dag"] except Exception as err: - raise DagFactoryException( - f"Failed to generate dag {dag_name}. verify config is correct" - ) from err + raise DagFactoryException(f"Failed to generate dag {dag_name}. 
verify config is correct") from err return dags @@ -190,9 +175,7 @@ def load_yaml_dags( suffix = [".yaml", ".yml"] candidate_dag_files = [] for suf in suffix: - candidate_dag_files = chain( - candidate_dag_files, Path(dags_folder).rglob(f"*{suf}") - ) + candidate_dag_files = chain(candidate_dag_files, Path(dags_folder).rglob(f"*{suf}")) for config_file_path in candidate_dag_files: config_file_abs_path = str(config_file_path.absolute()) diff --git a/dagfactory/utils.py b/dagfactory/utils.py index 29695fd7..94cb1259 100644 --- a/dagfactory/utils.py +++ b/dagfactory/utils.py @@ -8,17 +8,15 @@ import types from datetime import date, datetime, timedelta from pathlib import Path -from typing import Any, AnyStr, Dict, Match, Optional, Pattern, Union, List, Tuple -import yaml +from typing import Any, AnyStr, Dict, List, Match, Optional, Pattern, Tuple, Union import pendulum +import yaml from dagfactory.exceptions import DagFactoryException -def get_datetime( - date_value: Union[str, datetime, date], timezone: str = "UTC" -) -> datetime: +def get_datetime(date_value: Union[str, datetime, date], timezone: str = "UTC") -> datetime: """ Takes value from DAG config and generates valid datetime. Defaults to today, if not a valid date or relative time (1 hours, 1 days, etc.) @@ -37,20 +35,14 @@ def get_datetime( if isinstance(date_value, datetime): return date_value.replace(tzinfo=local_tz) if isinstance(date_value, date): - return datetime.combine(date=date_value, time=datetime.min.time()).replace( - tzinfo=local_tz - ) + return datetime.combine(date=date_value, time=datetime.min.time()).replace(tzinfo=local_tz) # Try parsing as date string try: return pendulum.parse(date_value).replace(tzinfo=local_tz) except pendulum.parsing.exceptions.ParserError: # Try parsing as relative time string rel_delta: timedelta = get_time_delta(date_value) - now: datetime = ( - datetime.today() - .replace(hour=0, minute=0, second=0, microsecond=0) - .replace(tzinfo=local_tz) - ) + now: datetime = datetime.today().replace(hour=0, minute=0, second=0, microsecond=0).replace(tzinfo=local_tz) if not rel_delta: return now return now - rel_delta @@ -86,9 +78,7 @@ def get_time_delta(time_string: str) -> timedelta: return timedelta(**time_params) -def merge_configs( - config: Dict[str, Any], default_config: Dict[str, Any] -) -> Dict[str, Any]: +def merge_configs(config: Dict[str, Any], default_config: Dict[str, Any]) -> Dict[str, Any]: """ Merges a `default` config with DAG config. Used to set default values for a group of DAGs. 
@@ -184,9 +174,7 @@ def convert_to_snake_case(input_string: str) -> str: """ # pylint: disable=line-too-long # source: https://www.geeksforgeeks.org/python-program-to-convert-camel-case-string-to-snake-case/ - return "".join("_" + i.lower() if i.isupper() else i for i in input_string).lstrip( - "_" - ) + return "".join("_" + i.lower() if i.isupper() else i for i in input_string).lstrip("_") def check_template_searchpath(template_searchpath: Union[str, List[str]]) -> bool: @@ -243,9 +231,7 @@ def get_expand_partial_kwargs(task_params: Dict[str, Any]) -> Tuple[ return task_params, expand_kwargs, partial_kwargs -def is_partial_duplicated( - partial_kwargs: Dict[str, Any], task_params: Dict[str, Any] -) -> bool: +def is_partial_duplicated(partial_kwargs: Dict[str, Any], task_params: Dict[str, Any]) -> bool: """ Check if there are duplicated keys in partial_kwargs and task_params :param partial_kwargs: a partial kwargs to check duplicates in @@ -259,9 +245,7 @@ def is_partial_duplicated( for key in partial_kwargs: task_duplicated_kwarg = task_params.get(key, None) if task_duplicated_kwarg is not None: - raise DagFactoryException( - "Duplicated partial kwarg! It's already in task_params." - ) + raise DagFactoryException("Duplicated partial kwarg! It's already in task_params.") return False @@ -282,9 +266,7 @@ def get_datasets_uri_yaml_file(file_path: str, datasets_filter: str) -> List[str datasets = data.get("datasets", []) datasets_result_uri = [ - dataset["uri"] - for dataset in datasets - if dataset["name"] in datasets_filter and "uri" in dataset + dataset["uri"] for dataset in datasets if dataset["name"] in datasets_filter and "uri" in dataset ] return datasets_result_uri except FileNotFoundError as err: diff --git a/pyproject.toml b/pyproject.toml index 21ce1045..af8b4a87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,14 @@ dev = [ "pytest-cov", "tox", ] +tests = [ + "pre-commit" +] + +[tool.hatch.envs.tests] +dependencies = [ + "dag-factory[tests]" + ] [project.urls] Source = "https://github.com/astronomer/dag-factory" @@ -55,3 +63,22 @@ include = ["dagfactory"] [tool.hatch.build.targets.wheel] packages = ["dagfactory"] + +[[tool.hatch.envs.tests.matrix]] +python = ["3.9", "3.10", "3.11", "3.12"] +airflow = ["2.8", "2.9", "2.10"] + +[tool.hatch.envs.tests.scripts] +static-check = " pre-commit run --files dagfactory/*" + +[tool.black] +line-length = 120 +target-version = ['py39', 'py310', 'py311', 'py312'] + +[tool.ruff] +line-length = 120 +[tool.ruff.lint] +select = ["C901", "D300", "I", "F"] +ignore = ["F541", "C901"] +[tool.ruff.lint.mccabe] +max-complexity = 10
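
A rough sketch of how the checks introduced by this patch can be exercised locally, mirroring the Static-Check job added in .github/workflows/test.yml. Only the final hatch command is taken verbatim from that workflow; the pip and pre-commit invocations are assumptions based on the tools the patch configures (pre-commit is declared as a hatch test dependency rather than installed directly in CI).

    # install the tooling this patch relies on (CI installs hatch the same way; pre-commit drives the hooks)
    pip3 install hatch pre-commit

    # optional: wire the hooks from .pre-commit-config.yaml into git so they run on every commit
    pre-commit install

    # run ruff, black, and the housekeeping hooks across the repository
    pre-commit run --all-files

    # run exactly what the Static-Check job runs in CI
    hatch run tests.py3.12-2.10:static-check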