From 4effd6f48b5b0fabde7e8bc731844a1cd258dc0e Mon Sep 17 00:00:00 2001
From: eladkal <45845474+eladkal@users.noreply.github.com>
Date: Tue, 14 Mar 2023 09:27:10 +0200
Subject: [PATCH] Add `AwsToAwsBaseOperator` (#30044)

* Add `AwsToAwsBaseOperator`

Follow-up on https://github.com/apache/airflow/pull/29452#issuecomment-1462741248

This PR preserves all current behavior but adds the interface needed by other
transfer operators.
---
 .../providers/amazon/aws/transfers/base.py    | 69 +++++++++++++++++++
 .../amazon/aws/transfers/dynamodb_to_s3.py    | 35 ++--------
 airflow/providers/amazon/provider.yaml        |  4 +-
 tests/always/test_project_structure.py        |  1 +
 .../amazon/aws/transfers/test_base.py         | 58 ++++++++++++++++
 .../aws/transfers/test_dynamodb_to_s3.py      | 29 +++++++-
 6 files changed, 163 insertions(+), 33 deletions(-)
 create mode 100644 airflow/providers/amazon/aws/transfers/base.py
 create mode 100644 tests/providers/amazon/aws/transfers/test_base.py

diff --git a/airflow/providers/amazon/aws/transfers/base.py b/airflow/providers/amazon/aws/transfers/base.py
new file mode 100644
index 000000000000..41e9a3a474cf
--- /dev/null
+++ b/airflow/providers/amazon/aws/transfers/base.py
@@ -0,0 +1,69 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""This module contains the base AWS to AWS transfer operator."""
+from __future__ import annotations
+
+import warnings
+from typing import Sequence
+
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.utils.types import NOTSET, ArgNotSet
+
+_DEPRECATION_MSG = (
+    "The aws_conn_id parameter has been deprecated. Use the source_aws_conn_id parameter instead."
+)
+
+
+class AwsToAwsBaseOperator(BaseOperator):
+    """
+    Base class for AWS to AWS transfer operators.
+
+    :param source_aws_conn_id: The Airflow connection used for AWS credentials
+        to access DynamoDB. If this is None or empty then the default boto3
+        behaviour is used. If running Airflow in a distributed manner and
+        source_aws_conn_id is None or empty, then default boto3 configuration
+        would be used (and must be maintained on each worker node).
+    :param dest_aws_conn_id: The Airflow connection used for AWS credentials
+        to access S3. If this is not set then the source_aws_conn_id connection is used.
+    :param aws_conn_id: The Airflow connection used for AWS credentials (deprecated; use source_aws_conn_id).
+ + """ + + template_fields: Sequence[str] = ( + "source_aws_conn_id", + "dest_aws_conn_id", + ) + + def __init__( + self, + *, + source_aws_conn_id: str | None = AwsBaseHook.default_conn_name, + dest_aws_conn_id: str | None | ArgNotSet = NOTSET, + aws_conn_id: str | None | ArgNotSet = NOTSET, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if not isinstance(aws_conn_id, ArgNotSet): + warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3) + self.source_aws_conn_id = aws_conn_id + else: + self.source_aws_conn_id = source_aws_conn_id + self.dest_aws_conn_id = ( + self.source_aws_conn_id if isinstance(dest_aws_conn_id, ArgNotSet) else dest_aws_conn_id + ) diff --git a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py index d9ea01cac519..4067217d97d4 100644 --- a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py @@ -22,7 +22,6 @@ from __future__ import annotations import json -import warnings from copy import copy from decimal import Decimal from os.path import getsize @@ -30,21 +29,15 @@ from typing import IO, TYPE_CHECKING, Any, Callable, Sequence from uuid import uuid4 -from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook from airflow.providers.amazon.aws.hooks.s3 import S3Hook -from airflow.utils.types import NOTSET, ArgNotSet +from airflow.providers.amazon.aws.transfers.base import AwsToAwsBaseOperator if TYPE_CHECKING: from airflow.utils.context import Context -_DEPRECATION_MSG = ( - "The aws_conn_id parameter has been deprecated. Use the source_aws_conn_id parameter instead." -) - - class JSONEncoder(json.JSONEncoder): """Custom json encoder implementation""" @@ -74,7 +67,7 @@ def _upload_file_to_s3( ) -class DynamoDBToS3Operator(BaseOperator): +class DynamoDBToS3Operator(AwsToAwsBaseOperator): """ Replicates records from a DynamoDB table to S3. It scans a DynamoDB table and writes the received records to a file @@ -89,29 +82,20 @@ class DynamoDBToS3Operator(BaseOperator): :ref:`howto/transfer:DynamoDBToS3Operator` :param dynamodb_table_name: Dynamodb table to replicate data from - :param source_aws_conn_id: The Airflow connection used for AWS credentials - to access DynamoDB. If this is None or empty then the default boto3 - behaviour is used. If running Airflow in a distributed manner and - source_aws_conn_id is None or empty, then default boto3 configuration - would be used (and must be maintained on each worker node). :param s3_bucket_name: S3 bucket to replicate data to :param file_size: Flush file to s3 if file size >= file_size :param dynamodb_scan_kwargs: kwargs pass to :param s3_key_prefix: Prefix of s3 object key :param process_func: How we transforms a dynamodb item to bytes. By default we dump the json - :param dest_aws_conn_id: The Airflow connection used for AWS credentials - to access S3. If this is not set then the source_aws_conn_id connection is used. - :param aws_conn_id: The Airflow connection used for AWS credentials (deprecated; use source_aws_conn_id). 
- """ # noqa: E501 template_fields: Sequence[str] = ( - "source_aws_conn_id", - "dest_aws_conn_id", + *AwsToAwsBaseOperator.template_fields, "s3_bucket_name", "s3_key_prefix", "dynamodb_table_name", ) + template_fields_renderers = { "dynamodb_scan_kwargs": "json", } @@ -120,14 +104,11 @@ def __init__( self, *, dynamodb_table_name: str, - source_aws_conn_id: str | None = AwsBaseHook.default_conn_name, s3_bucket_name: str, file_size: int, dynamodb_scan_kwargs: dict[str, Any] | None = None, s3_key_prefix: str = "", process_func: Callable[[dict[str, Any]], bytes] = _convert_item_to_json_bytes, - dest_aws_conn_id: str | None | ArgNotSet = NOTSET, - aws_conn_id: str | None | ArgNotSet = NOTSET, **kwargs, ) -> None: super().__init__(**kwargs) @@ -137,14 +118,6 @@ def __init__( self.dynamodb_scan_kwargs = dynamodb_scan_kwargs self.s3_bucket_name = s3_bucket_name self.s3_key_prefix = s3_key_prefix - if not isinstance(aws_conn_id, ArgNotSet): - warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3) - self.source_aws_conn_id = aws_conn_id - else: - self.source_aws_conn_id = source_aws_conn_id - self.dest_aws_conn_id = ( - self.source_aws_conn_id if isinstance(dest_aws_conn_id, ArgNotSet) else dest_aws_conn_id - ) def execute(self, context: Context) -> None: hook = DynamoDBHook(aws_conn_id=self.source_aws_conn_id) diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 431efca824e1..d7b7578dee10 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -567,7 +567,9 @@ transfers: target-integration-name: Common SQL how-to-guide: /docs/apache-airflow-providers-amazon/operators/transfer/s3_to_sql.rst python-module: airflow.providers.amazon.aws.transfers.s3_to_sql - + - source-integration-name: Amazon Web Services + target-integration-name: Amazon Web Services + python-module: airflow.providers.amazon.aws.transfers.base extra-links: - airflow.providers.amazon.aws.links.batch.BatchJobDefinitionLink diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index 8c64cb57c985..d0aa829f38b5 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -393,6 +393,7 @@ class TestAmazonProviderProjectStructure(ExampleCoverageTest): "airflow.providers.amazon.aws.operators.ecs.EcsBaseOperator", "airflow.providers.amazon.aws.sensors.ecs.EcsBaseSensor", "airflow.providers.amazon.aws.sensors.eks.EksBaseSensor", + "airflow.providers.amazon.aws.transfers.base.AwsToAwsBaseOperator", } MISSING_EXAMPLES_FOR_CLASSES = { diff --git a/tests/providers/amazon/aws/transfers/test_base.py b/tests/providers/amazon/aws/transfers/test_base.py new file mode 100644 index 000000000000..0233b91d7c2b --- /dev/null +++ b/tests/providers/amazon/aws/transfers/test_base.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow import DAG
+from airflow.models import DagRun, TaskInstance
+from airflow.providers.amazon.aws.transfers.base import AwsToAwsBaseOperator
+from airflow.utils import timezone
+
+DEFAULT_DATE = timezone.datetime(2020, 1, 1)
+
+
+class TestAwsToAwsBaseOperator:
+    def setup_method(self):
+        args = {"owner": "airflow", "start_date": DEFAULT_DATE}
+        self.dag = DAG("test_dag_id", default_args=args)
+
+    def test_render_template(self):
+        operator = AwsToAwsBaseOperator(
+            task_id="dynamodb_to_s3_test_render",
+            dag=self.dag,
+            source_aws_conn_id="{{ ds }}",
+            dest_aws_conn_id="{{ ds }}",
+        )
+        ti = TaskInstance(operator, run_id="something")
+        ti.dag_run = DagRun(run_id="something", execution_date=timezone.datetime(2020, 1, 1))
+        ti.render_templates()
+        assert "2020-01-01" == getattr(operator, "source_aws_conn_id")
+        assert "2020-01-01" == getattr(operator, "dest_aws_conn_id")
+
+    def test_deprecation(self):
+        with pytest.warns(
+            DeprecationWarning,
+            match="The aws_conn_id parameter has been deprecated."
+            " Use the source_aws_conn_id parameter instead.",
+        ):
+            AwsToAwsBaseOperator(
+                task_id="transfer",
+                dag=self.dag,
+                aws_conn_id="my_conn",
+            )
diff --git a/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py b/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py
index 22e9a1902041..17a240c15886 100644
--- a/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py
@@ -18,16 +18,20 @@
 from __future__ import annotations
 
 import json
+from datetime import datetime
 from decimal import Decimal
 from unittest.mock import MagicMock, patch
 
 import pytest
 
+from airflow import DAG
+from airflow.models import DagRun, TaskInstance
+from airflow.providers.amazon.aws.transfers.base import _DEPRECATION_MSG
 from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import (
-    _DEPRECATION_MSG,
     DynamoDBToS3Operator,
     JSONEncoder,
 )
+from airflow.utils import timezone
 
 
 class TestJSONEncoder:
@@ -288,3 +292,26 @@ def test_dynamodb_to_s3_with_just_dest_aws_conn_id(self, mock_aws_dynamodb_hook,
 
         mock_aws_dynamodb_hook.assert_called_with(aws_conn_id="aws_default")
         mock_s3_hook.assert_called_with(aws_conn_id=s3_aws_conn_id)
+
+    def test_render_template(self):
+        dag = DAG("test_render_template_dag_id", start_date=datetime(2020, 1, 1))
+        operator = DynamoDBToS3Operator(
+            task_id="dynamodb_to_s3_test_render",
+            dag=dag,
+            dynamodb_table_name="{{ ds }}",
+            s3_key_prefix="{{ ds }}",
+            s3_bucket_name="{{ ds }}",
+            file_size=4000,
+            source_aws_conn_id="{{ ds }}",
+            dest_aws_conn_id="{{ ds }}",
+        )
+        ti = TaskInstance(operator, run_id="something")
+        ti.dag_run = DagRun(
+            dag_id=dag.dag_id, run_id="something", execution_date=timezone.datetime(2020, 1, 1)
+        )
+        ti.render_templates()
+        assert "2020-01-01" == getattr(operator, "source_aws_conn_id")
+        assert "2020-01-01" == getattr(operator, "dest_aws_conn_id")
+        assert "2020-01-01" == getattr(operator, "s3_bucket_name")
+        assert "2020-01-01" == getattr(operator, "dynamodb_table_name")
+        assert "2020-01-01" == getattr(operator, "s3_key_prefix")
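
Editor's note: the patch itself never shows a consumer of the new base class, so here is a minimal sketch of how another AWS-to-AWS transfer operator could build on `AwsToAwsBaseOperator`, mirroring the `DynamoDBToS3Operator` changes above. The operator name `ExampleS3ToS3Operator` and its `source_key`/`dest_key` parameters are invented for illustration only; the pieces taken from this diff are `AwsToAwsBaseOperator`, its `template_fields`, and the resolved `source_aws_conn_id`/`dest_aws_conn_id` attributes.

from __future__ import annotations

from typing import Sequence

from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.transfers.base import AwsToAwsBaseOperator


class ExampleS3ToS3Operator(AwsToAwsBaseOperator):
    """Hypothetical operator (not part of this PR) showing the intended subclassing pattern."""

    template_fields: Sequence[str] = (
        *AwsToAwsBaseOperator.template_fields,  # keeps source/dest conn ids templatable
        "source_key",
        "dest_key",
    )

    def __init__(self, *, source_key: str, dest_key: str, **kwargs) -> None:
        # The base __init__ resolves source_aws_conn_id / dest_aws_conn_id and emits
        # the DeprecationWarning if the old aws_conn_id parameter is passed.
        super().__init__(**kwargs)
        self.source_key = source_key
        self.dest_key = dest_key

    def execute(self, context):
        # Separate hooks are built from the two resolved connection ids, just as
        # DynamoDBToS3Operator uses them for DynamoDB (source) and S3 (destination).
        source_hook = S3Hook(aws_conn_id=self.source_aws_conn_id)
        dest_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id)
        # Actual copy logic would go here.
        ...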