diff --git a/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py b/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py
new file mode 100644
index 000000000000..ebcd49bdc411
--- /dev/null
+++ b/airflow/providers/amazon/aws/transfers/azure_blob_to_s3.py
@@ -0,0 +1,167 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+import tempfile
+from typing import TYPE_CHECKING, Sequence
+
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class AzureBlobStorageToS3Operator(BaseOperator):
+    """
+    Operator to transfer data from Azure Blob Storage to a specified bucket in Amazon S3.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:AzureBlobStorageToS3Operator`
+
+    :param wasb_conn_id: Reference to the wasb connection.
+    :param container_name: Name of the container
+    :param prefix: Prefix string which filters objects whose names begin with
+        this prefix. (templated)
+    :param delimiter: The delimiter by which you want to filter the blobs. (templated)
+        For example, to copy only the CSV files from a directory in Azure Blob Storage you would use
+        delimiter='.csv'.
+    :param aws_conn_id: Connection id of the S3 connection to use
+    :param dest_s3_key: The base S3 key to be used to store the files. (templated)
+    :param dest_verify: Whether or not to verify SSL certificates for the S3 connection.
+        By default SSL certificates are verified.
+        You can provide the following values:
+
+        - ``False``: do not validate SSL certificates. SSL will still be used
+            (unless use_ssl is False), but SSL certificates will not be
+            verified.
+        - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
+            You can specify this argument if you want to use a different
+            CA cert bundle than the one used by botocore.
+    :param dest_s3_extra_args: Extra arguments that may be passed to the download/upload operations.
+    :param replace: Whether or not to replace existing files in the
+        destination bucket.
+        By default it is set to False.
+        If set to True, all files are uploaded, replacing any existing ones in
+        the destination bucket.
+        If set to False, only the files that are in the origin but not
+        in the destination bucket are uploaded.
+    :param s3_acl_policy: Optional string specifying the canned ACL policy for the
+        objects to be uploaded to S3
+    :param wasb_extra_args: kwargs to pass to WasbHook
+    :param s3_extra_args: kwargs to pass to S3Hook
+    """
+
+    template_fields: Sequence[str] = (
+        "container_name",
+        "prefix",
+        "delimiter",
+        "dest_s3_key",
+    )
+
+    def __init__(
+        self,
+        *,
+        wasb_conn_id: str = "wasb_default",
+        container_name: str,
+        prefix: str | None = None,
+        delimiter: str = "",
+        aws_conn_id: str = "aws_default",
+        dest_s3_key: str,
+        dest_verify: str | bool | None = None,
+        dest_s3_extra_args: dict | None = None,
+        replace: bool = False,
+        s3_acl_policy: str | None = None,
+        wasb_extra_args: dict | None = None,
+        s3_extra_args: dict | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.wasb_conn_id = wasb_conn_id
+        self.container_name = container_name
+        self.prefix = prefix
+        self.delimiter = delimiter
+        self.aws_conn_id = aws_conn_id
+        self.dest_s3_key = dest_s3_key
+        self.dest_verify = dest_verify
+        self.dest_s3_extra_args = dest_s3_extra_args or {}
+        self.replace = replace
+        self.s3_acl_policy = s3_acl_policy
+        self.wasb_extra_args = wasb_extra_args or {}
+        self.s3_extra_args = s3_extra_args or {}
+
+    def execute(self, context: Context) -> list[str]:
+        # list all files in the Azure Blob Storage container
+        wasb_hook = WasbHook(wasb_conn_id=self.wasb_conn_id, **self.wasb_extra_args)
+        s3_hook = S3Hook(
+            aws_conn_id=self.aws_conn_id,
+            verify=self.dest_verify,
+            extra_args=self.dest_s3_extra_args,
+            **self.s3_extra_args,
+        )
+
+        self.log.info(
+            f"Getting list of the files in Container: {self.container_name}; "
+            f"Prefix: {self.prefix}; Delimiter: {self.delimiter};"
+        )
+
+        files = wasb_hook.get_blobs_list_recursive(
+            container_name=self.container_name, prefix=self.prefix, endswith=self.delimiter
+        )
+
+        if not self.replace:
+            # if we are not replacing -> list all files in the S3 bucket
+            # and only keep those files which are present in
+            # Azure Blob Storage and not in S3
+            bucket_name, prefix = S3Hook.parse_s3_url(self.dest_s3_key)
+            # look for the bucket and the prefix to avoid looking into
+            # parent directories/keys
+            existing_files = s3_hook.list_keys(bucket_name, prefix=prefix)
+            # if no files exist, fall back to an empty list to avoid errors
+            existing_files = existing_files if existing_files is not None else []
+            # remove the prefix from the existing files to allow the match
+            existing_files = [file.replace(f"{prefix}/", "", 1) for file in existing_files]
+            files = list(set(files) - set(existing_files))
+
+        if files:
+            for file in files:
+                with tempfile.NamedTemporaryFile() as temp_file:
+
+                    dest_key = os.path.join(self.dest_s3_key, file)
+                    self.log.info("Downloading data from blob: %s", file)
+                    wasb_hook.get_file(
+                        file_path=temp_file.name,
+                        container_name=self.container_name,
+                        blob_name=file,
+                    )
+
+                    self.log.info("Uploading data to s3: %s", dest_key)
+                    s3_hook.load_file(
+                        filename=temp_file.name,
+                        key=dest_key,
+                        replace=self.replace,
+                        acl_policy=self.s3_acl_policy,
+                    )
+            self.log.info("All done, uploaded %d files to S3", len(files))
+        else:
+            self.log.info("All files are already in sync!")
+        return files
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index e4f16ce39846..77d1c544a8ce 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -622,6 +622,10 @@ transfers:
   - source-integration-name: Amazon Web Services
     target-integration-name: Amazon Web Services
     python-module: airflow.providers.amazon.aws.transfers.base
+  - source-integration-name: Microsoft Azure Blob Storage
+    target-integration-name: Amazon Simple Storage Service (S3)
+    how-to-guide: /docs/apache-airflow-providers-amazon/transfer/azure_blob_to_s3.rst
+    python-module: airflow.providers.amazon.aws.transfers.azure_blob_to_s3
 
 extra-links:
   - airflow.providers.amazon.aws.links.batch.BatchJobDefinitionLink
diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py
index c1d615810b96..ad9c4754c525 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -292,6 +292,33 @@ def get_blobs_list(
             blob_list.append(blob.name)
         return blob_list
 
+    def get_blobs_list_recursive(
+        self,
+        container_name: str,
+        prefix: str | None = None,
+        include: list[str] | None = None,
+        endswith: str = "",
+        **kwargs,
+    ) -> list:
+        """
+        List blobs in a given container recursively.
+
+        :param container_name: The name of the container
+        :param prefix: Filters the results to return only blobs whose names
+            begin with the specified prefix.
+        :param include: Specifies one or more additional datasets to include in the
+            response. Options include: ``snapshots``, ``metadata``, ``uncommittedblobs``,
+            ``copy``, ``deleted``.
+        :param endswith: Filters the results to return only blobs whose names end with the given suffix (e.g. '.csv')
+        """
+        container = self._get_container_client(container_name)
+        blob_list = []
+        blobs = container.list_blobs(name_starts_with=prefix, include=include, **kwargs)
+        for blob in blobs:
+            if blob.name.endswith(endswith):
+                blob_list.append(blob.name)
+        return blob_list
+
     def load_file(
         self,
         file_path: str,
diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index 0081b6a2a795..14dd7576eb18 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -311,7 +311,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             ("airflow/providers/amazon/__init__.py",),
             {
                 "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
-                "common.sql exasol ftp google http imap "
+                "common.sql exasol ftp google http imap microsoft.azure "
                 "mongo mysql postgres salesforce ssh",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
@@ -325,7 +325,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "upgrade-to-newer-dependencies": "false",
                 "run-amazon-tests": "true",
                 "parallel-test-types-list-as-string": "Providers[amazon] Always "
-                "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,imap,"
+                "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,imap,microsoft.azure,"
                 "mongo,mysql,postgres,salesforce,ssh] Providers[google]",
             },
             id="Providers tests run including amazon tests if amazon provider files changed",
@@ -353,7 +353,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             ("airflow/providers/amazon/file.py",),
             {
                 "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
-                "common.sql exasol ftp google http imap "
+                "common.sql exasol ftp google http imap microsoft.azure "
                 "mongo mysql postgres salesforce ssh",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
@@ -368,7 +368,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "upgrade-to-newer-dependencies": "false",
                 "parallel-test-types-list-as-string": "Providers[amazon] Always "
"Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp," - "http,imap,mongo,mysql,postgres,salesforce,ssh] Providers[google]", + "http,imap,microsoft.azure,mongo,mysql,postgres,salesforce,ssh] Providers[google]", }, id="Providers tests run including amazon tests if amazon provider files changed", ), diff --git a/docs/apache-airflow-providers-amazon/transfer/azure_blob_to_s3.rst b/docs/apache-airflow-providers-amazon/transfer/azure_blob_to_s3.rst new file mode 100644 index 000000000000..a3b9df5eb1ee --- /dev/null +++ b/docs/apache-airflow-providers-amazon/transfer/azure_blob_to_s3.rst @@ -0,0 +1,52 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +=================================================== +Azure Blob Storage to Amazon S3 transfer operator +=================================================== + +Use the ``AzureBlobStorageToS3Operator`` transfer to copy the data from Azure Blob Storage to Amazon Simple Storage Service (S3). + +Prerequisite Tasks +------------------ + +.. include:: ../_partials/prerequisite_tasks.rst + +Operators +--------- + +.. _howto/operator:AzureBlobStorageToS3Operator: + +Azure Blob Storage to Amazon S3 +================================= + +To copy data from an Azure Blob Storage container to an Amazon S3 bucket you can use +:class:`~airflow.providers.amazon.aws.transfers.azure_blob_to_s3.AzureBlobStorageToS3Operator` + +Example usage: + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_azure_blob_to_s3.py + :language: python + :dedent: 4 + :start-after: [START howto_transfer_azure_blob_to_s3] + :end-before: [END howto_transfer_azure_blob_to_s3] + +Reference +--------- + +* `Azure Blob Storage client library `__ +* `AWS boto3 library documentation for Amazon S3 `__ diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 5ce7ccd295e4..eb0729aa5f71 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -45,6 +45,7 @@ "google", "http", "imap", + "microsoft.azure", "mongo", "salesforce", "ssh" diff --git a/tests/providers/amazon/aws/transfers/test_azure_blob_to_s3.py b/tests/providers/amazon/aws/transfers/test_azure_blob_to_s3.py new file mode 100644 index 000000000000..f34877e18722 --- /dev/null +++ b/tests/providers/amazon/aws/transfers/test_azure_blob_to_s3.py @@ -0,0 +1,218 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from io import RawIOBase +from unittest import mock + +from moto import mock_s3 + +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.amazon.aws.transfers.azure_blob_to_s3 import AzureBlobStorageToS3Operator + +TASK_ID = "test-gcs-list-operator" +CONTAINER_NAME = "test-container" +DELIMITER = ".csv" +PREFIX = "TEST" +S3_BUCKET = "s3://bucket/" +MOCK_FILES = ["TEST1.csv", "TEST2.csv", "TEST3.csv"] +S3_ACL_POLICY = "private-read" + + +def _create_test_bucket(): + hook = S3Hook(aws_conn_id="airflow_gcs_test") + # We're mocking all actual AWS calls and don't need a connection. + # This avoids an Airflow warning about connection cannot be found. + hook.get_connection = lambda _: None + bucket = hook.get_bucket("bucket") + bucket.create() + return hook, bucket + + +@mock_s3 +class TestAzureBlobToS3Operator: + @mock.patch("airflow.providers.amazon.aws.transfers.azure_blob_to_s3.WasbHook") + def test_operator_all_file_upload(self, mock_hook): + """ + Destination bucket has no file (of interest) common with origin bucket i.e + Azure - ["TEST1.csv", "TEST2.csv", "TEST3.csv"] + S3 - [] + """ + mock_hook.return_value.get_blobs_list_recursive.return_value = MOCK_FILES + + operator = AzureBlobStorageToS3Operator( + task_id=TASK_ID, + container_name=CONTAINER_NAME, + dest_s3_key=S3_BUCKET, + replace=False, + ) + + hook, _ = _create_test_bucket() + uploaded_files = operator.execute(None) + + assert sorted(MOCK_FILES) == sorted(uploaded_files) + assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/")) + + @mock.patch("airflow.providers.amazon.aws.transfers.azure_blob_to_s3.WasbHook") + def test_operator_incremental_file_upload_without_replace(self, mock_hook): + """ + Destination bucket has subset of files common with origin bucket i.e + Azure - ["TEST1.csv", "TEST2.csv", "TEST3.csv"] + S3 - ["TEST1.csv"] + """ + mock_hook.return_value.get_blobs_list_recursive.return_value = MOCK_FILES + get_file = mock_hook.return_value.get_file + + operator = AzureBlobStorageToS3Operator( + task_id=TASK_ID, + container_name=CONTAINER_NAME, + dest_s3_key=S3_BUCKET, + # without replace + replace=False, + ) + + hook, bucket = _create_test_bucket() + # uploading only first file + bucket.put_object(Key=MOCK_FILES[0], Body=b"testing") + + uploaded_files = operator.execute(None) + + assert sorted(MOCK_FILES[1:]) == sorted(uploaded_files) + assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/")) + assert get_file.call_count == len(MOCK_FILES[1:]) + + @mock.patch("airflow.providers.amazon.aws.transfers.azure_blob_to_s3.WasbHook") + def test_operator_incremental_file_upload_with_replace(self, mock_hook): + """ + Destination bucket has subset of files common with origin bucket i.e + Azure - ["TEST1.csv", "TEST2.csv", "TEST3.csv"] + S3 - ["TEST1.csv"] + """ + mock_hook.return_value.get_blobs_list_recursive.return_value = MOCK_FILES + get_file = mock_hook.return_value.get_file + + operator = AzureBlobStorageToS3Operator( + task_id=TASK_ID, + container_name=CONTAINER_NAME, + dest_s3_key=S3_BUCKET, + # with replace + 
replace=True, + ) + + hook, bucket = _create_test_bucket() + # uploading only first file + bucket.put_object(Key=MOCK_FILES[0], Body=b"testing") + + uploaded_files = operator.execute(None) + + assert sorted(MOCK_FILES) == sorted(uploaded_files) + assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/")) + assert get_file.call_count == len(MOCK_FILES) + + @mock.patch("airflow.providers.amazon.aws.transfers.azure_blob_to_s3.WasbHook") + def test_operator_no_file_upload_without_replace(self, mock_hook): + """ + Destination bucket has all the files common with origin bucket i.e + Azure - ["TEST1.csv", "TEST2.csv", "TEST3.csv"] + S3 - ["TEST1.csv", "TEST2.csv", "TEST3.csv"] + """ + mock_hook.return_value.get_blobs_list_recursive.return_value = MOCK_FILES + get_file = mock_hook.return_value.get_file + + operator = AzureBlobStorageToS3Operator( + task_id=TASK_ID, + container_name=CONTAINER_NAME, + dest_s3_key=S3_BUCKET, + replace=False, + ) + + hook, bucket = _create_test_bucket() + # uploading all the files + for mock_file in MOCK_FILES: + bucket.put_object(Key=mock_file, Body=b"testing") + + uploaded_files = operator.execute(None) + + assert sorted([]) == sorted(uploaded_files) + assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/")) + assert get_file.call_count == 0 + + @mock.patch("airflow.providers.amazon.aws.transfers.azure_blob_to_s3.WasbHook") + def test_operator_no_file_upload_with_replace(self, mock_hook): + """ + Destination bucket has all the files common with origin bucket i.e + Azure - ["TEST1.csv", "TEST2.csv", "TEST3.csv"] + S3 - ["TEST1.csv", "TEST2.csv", "TEST3.csv"] + """ + mock_hook.return_value.get_blobs_list_recursive.return_value = MOCK_FILES + get_file = mock_hook.return_value.get_file + + operator = AzureBlobStorageToS3Operator( + task_id=TASK_ID, + container_name=CONTAINER_NAME, + dest_s3_key=S3_BUCKET, + replace=True, + ) + + hook, bucket = _create_test_bucket() + # uploading all the files + for mock_file in MOCK_FILES: + bucket.put_object(Key=mock_file, Body=b"testing") + + uploaded_files = operator.execute(None) + + assert sorted(MOCK_FILES) == sorted(uploaded_files) + assert sorted(MOCK_FILES) == sorted(hook.list_keys("bucket", delimiter="/")) + # this ensures that upload happened + assert get_file.call_count == len(MOCK_FILES) + + @mock.patch("tempfile.NamedTemporaryFile") + @mock.patch("airflow.providers.amazon.aws.transfers.azure_blob_to_s3.WasbHook") + @mock.patch("airflow.providers.amazon.aws.transfers.azure_blob_to_s3.S3Hook") + def test_operator_should_pass_dest_s3_extra_args_and_s3_acl_policy( + self, s3_hook_mock, wasb_hook_mock, mock_tempfile + ): + wasb_blob_name = "test_file" + s3_acl_policy = "test policy" + s3_extra_args = {"ContentLanguage": "value"} + + wasb_hook_mock.return_value.get_blobs_list_recursive.return_value = [wasb_blob_name] + wasb_hook_mock.return_value.download.return_value = RawIOBase(b"testing") + mock_tempfile.return_value.__enter__.return_value.name = "test_temp_file" + + # with current S3_BUCKET url, parse_s3_url would complain + s3_hook_mock.parse_s3_url.return_value = ("bucket", wasb_blob_name) + mock_load_files = s3_hook_mock.return_value.load_file + + operator = AzureBlobStorageToS3Operator( + task_id=TASK_ID, + container_name=CONTAINER_NAME, + dest_s3_key=S3_BUCKET, + replace=False, + dest_s3_extra_args=s3_extra_args, + s3_acl_policy=s3_acl_policy, + ) + + operator.execute(None) + s3_hook_mock.assert_called_once_with(aws_conn_id="aws_default", extra_args=s3_extra_args, verify=None) + 
mock_load_files.assert_called_once_with( + filename="test_temp_file", + key=f"{S3_BUCKET}{wasb_blob_name}", + replace=False, + acl_policy=s3_acl_policy, + ) diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py index 464db0f39f43..672d20d3a65e 100644 --- a/tests/providers/microsoft/azure/hooks/test_wasb.py +++ b/tests/providers/microsoft/azure/hooks/test_wasb.py @@ -23,6 +23,7 @@ import pytest from azure.identity import ClientSecretCredential, DefaultAzureCredential from azure.storage.blob import BlobServiceClient +from azure.storage.blob._models import BlobProperties from airflow.exceptions import AirflowException from airflow.models import Connection @@ -299,6 +300,30 @@ def test_get_blobs_list(self, mock_service): name_starts_with="my", include=None, delimiter="/" ) + @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") + def test_get_blobs_list_recursive(self, mock_service): + hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + hook.get_blobs_list_recursive( + container_name="mycontainer", prefix="test", include=None, endswith="file_extension" + ) + mock_service.return_value.get_container_client.assert_called_once_with("mycontainer") + mock_service.return_value.get_container_client.return_value.list_blobs.assert_called_once_with( + name_starts_with="test", include=None + ) + + @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient") + def test_get_blobs_list_recursive_endswith(self, mock_service): + hook = WasbHook(wasb_conn_id=self.shared_key_conn_id) + mock_service.return_value.get_container_client.return_value.list_blobs.return_value = [ + BlobProperties(name="test/abc.py"), + BlobProperties(name="test/inside_test/abc.py"), + BlobProperties(name="test/abc.csv"), + ] + blob_list_output = hook.get_blobs_list_recursive( + container_name="mycontainer", prefix="test", include=None, endswith=".py" + ) + assert blob_list_output == ["test/abc.py", "test/inside_test/abc.py"] + @pytest.mark.parametrize(argnames="create_container", argvalues=[True, False]) @mock.patch.object(WasbHook, "upload") def test_load_file(self, mock_upload, create_container): diff --git a/tests/system/providers/amazon/aws/example_azure_blob_to_s3.py b/tests/system/providers/amazon/aws/example_azure_blob_to_s3.py new file mode 100644 index 000000000000..79b49966ca54 --- /dev/null +++ b/tests/system/providers/amazon/aws/example_azure_blob_to_s3.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from datetime import datetime + +from airflow import DAG +from airflow.models.baseoperator import chain +from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator +from airflow.providers.amazon.aws.transfers.azure_blob_to_s3 import AzureBlobStorageToS3Operator +from airflow.utils.trigger_rule import TriggerRule +from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder + +sys_test_context_task = SystemTestContextBuilder().build() + +DAG_ID = "example_azure_blob_to_s3" + +with DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + tags=["example"], + catchup=False, +) as dag: + test_context = sys_test_context_task() + env_id = test_context["ENV_ID"] + + s3_bucket = f"{env_id}-azure_blob-to-s3-bucket" + s3_key = f"{env_id}-azure_blob-to-s3-key" + s3_key_url = f"s3://{s3_bucket}/{s3_key}" + azure_container_name = f"{env_id}-azure_blob-to-s3-container" + + create_s3_bucket = S3CreateBucketOperator(task_id="create_s3_bucket", bucket_name=s3_bucket) + + # [START howto_transfer_azure_blob_to_s3] + azure_blob_to_s3 = AzureBlobStorageToS3Operator( + task_id="azure_blob_to_s3", + container_name=azure_container_name, + dest_s3_key=s3_key_url, + ) + # [END howto_transfer_azure_blob_to_s3] + + delete_s3_bucket = S3DeleteBucketOperator( + task_id="delete_s3_bucket", + bucket_name=s3_bucket, + force_delete=True, + trigger_rule=TriggerRule.ALL_DONE, + ) + + chain( + # TEST SETUP + test_context, + create_s3_bucket, + # TEST BODY + azure_blob_to_s3, + # TEST TEARDOWN + delete_s3_bucket, + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)
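Beyond the system test above, which uses the operator with only its required arguments, the following sketch shows how the new operator might be configured with the optional prefix/delimiter filters, incremental sync, and a canned ACL policy. It is a minimal illustration based on the parameters introduced in this change; the connection IDs, container, bucket, and prefix values are hypothetical placeholders, not anything defined in this patch.

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.azure_blob_to_s3 import AzureBlobStorageToS3Operator

with DAG(
    dag_id="azure_blob_to_s3_incremental_sync",  # hypothetical DAG name
    schedule="@daily",
    start_date=datetime(2023, 1, 1),
    catchup=False,
):
    sync_csv_reports = AzureBlobStorageToS3Operator(
        task_id="sync_csv_reports",
        wasb_conn_id="wasb_reports",  # hypothetical Azure Blob Storage connection
        aws_conn_id="aws_reports",  # hypothetical AWS connection
        container_name="reports",  # source container to list blobs from
        prefix="2023/",  # copy only blobs whose names start with this prefix
        delimiter=".csv",  # copy only blobs whose names end with this suffix
        dest_s3_key="s3://my-report-bucket/reports",  # base S3 key for the uploaded files
        replace=False,  # skip blobs that already exist under the destination key
        s3_acl_policy="bucket-owner-full-control",  # canned ACL applied on upload
    )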