Add support of a different AWS connection for DynamoDB #29452

Merged · 8 commits · Mar 9, 2023
46 changes: 35 additions & 11 deletions airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -22,6 +22,7 @@
from __future__ import annotations

import json
import warnings
from copy import copy
from decimal import Decimal
from os.path import getsize
@@ -30,13 +31,20 @@
from uuid import uuid4

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.types import NOTSET, ArgNotSet

if TYPE_CHECKING:
from airflow.utils.context import Context


_DEPRECATION_MSG = (
"The aws_conn_id parameter has been deprecated. Use the source_aws_conn_id parameter instead."
)


class JSONEncoder(json.JSONEncoder):
"""Custom json encoder implementation"""

@@ -52,7 +60,10 @@ def _convert_item_to_json_bytes(item: dict[str, Any]) -> bytes:


def _upload_file_to_s3(
file_obj: IO, bucket_name: str, s3_key_prefix: str, aws_conn_id: str = "aws_default"
file_obj: IO,
bucket_name: str,
s3_key_prefix: str,
aws_conn_id: str | None = AwsBaseHook.default_conn_name,
) -> None:
s3_client = S3Hook(aws_conn_id=aws_conn_id).get_conn()
file_obj.seek(0)
@@ -78,16 +89,20 @@ class DynamoDBToS3Operator(BaseOperator):
:ref:`howto/transfer:DynamoDBToS3Operator`

:param dynamodb_table_name: Dynamodb table to replicate data from
:param source_aws_conn_id: The Airflow connection used for AWS credentials
to access DynamoDB. If this is None or empty then the default boto3
behaviour is used. If running Airflow in a distributed manner and
source_aws_conn_id is None or empty, then default boto3 configuration
would be used (and must be maintained on each worker node).
:param s3_bucket_name: S3 bucket to replicate data to
:param file_size: Flush file to s3 if file size >= file_size
:param dynamodb_scan_kwargs: kwargs pass to <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Table.scan>
:param s3_key_prefix: Prefix of s3 object key
:param process_func: How we transform a DynamoDB item to bytes. By default we dump the JSON
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param dest_aws_conn_id: The Airflow connection used for AWS credentials
to access S3. If this is not set then the source_aws_conn_id connection is used.
:param aws_conn_id: The Airflow connection used for AWS credentials (deprecated).

""" # noqa: E501

template_fields: Sequence[str] = (
@@ -103,12 +118,14 @@ def __init__(
self,
*,
dynamodb_table_name: str,
source_aws_conn_id: str | None = AwsBaseHook.default_conn_name,
s3_bucket_name: str,
file_size: int,
dynamodb_scan_kwargs: dict[str, Any] | None = None,
s3_key_prefix: str = "",
process_func: Callable[[dict[str, Any]], bytes] = _convert_item_to_json_bytes,
aws_conn_id: str = "aws_default",
dest_aws_conn_id: str | None | ArgNotSet = NOTSET,
aws_conn_id: str | None | ArgNotSet = NOTSET,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -118,10 +135,17 @@ def __init__(
self.dynamodb_scan_kwargs = dynamodb_scan_kwargs
self.s3_bucket_name = s3_bucket_name
self.s3_key_prefix = s3_key_prefix
self.aws_conn_id = aws_conn_id
if not isinstance(aws_conn_id, ArgNotSet):
warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
self.source_aws_conn_id = aws_conn_id
else:
self.source_aws_conn_id = source_aws_conn_id
self.dest_aws_conn_id = (
self.source_aws_conn_id if isinstance(dest_aws_conn_id, ArgNotSet) else dest_aws_conn_id
)

def execute(self, context: Context) -> None:
hook = DynamoDBHook(aws_conn_id=self.aws_conn_id)
hook = DynamoDBHook(aws_conn_id=self.source_aws_conn_id)
table = hook.get_conn().Table(self.dynamodb_table_name)

scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}
@@ -135,7 +159,7 @@ def execute(self, context: Context) -> None:
raise e
finally:
if err is None:
_upload_file_to_s3(f, self.s3_bucket_name, self.s3_key_prefix, self.aws_conn_id)
_upload_file_to_s3(f, self.s3_bucket_name, self.s3_key_prefix, self.dest_aws_conn_id)

def _scan_dynamodb_and_upload_to_s3(self, temp_file: IO, scan_kwargs: dict, table: Any) -> IO:
while True:
@@ -153,7 +177,7 @@ def _scan_dynamodb_and_upload_to_s3(self, temp_file: IO, scan_kwargs: dict, tabl

# Upload the file to S3 if reach file size limit
if getsize(temp_file.name) >= self.file_size:
_upload_file_to_s3(temp_file, self.s3_bucket_name, self.s3_key_prefix, self.aws_conn_id)
_upload_file_to_s3(temp_file, self.s3_bucket_name, self.s3_key_prefix, self.dest_aws_conn_id)
temp_file.close()

temp_file = NamedTemporaryFile()
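
For context, here is a minimal usage sketch of the two-connection form this change enables. The DAG id, task id, table, bucket, and connection ids below are placeholders, and the schedule argument assumes Airflow 2.4+; treat it as an illustration of the new parameters, not project documentation.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator

with DAG(dag_id="dynamodb_to_s3_backup", start_date=datetime(2023, 3, 1), schedule=None) as dag:
    DynamoDBToS3Operator(
        task_id="backup_table",
        dynamodb_table_name="my_table",       # DynamoDB table to scan (placeholder)
        s3_bucket_name="my-backup-bucket",    # S3 bucket to write the dumped items to (placeholder)
        file_size=4000,                       # flush the temp file to S3 once it reaches this many bytes
        source_aws_conn_id="aws_dynamodb",    # connection used by DynamoDBHook (placeholder)
        dest_aws_conn_id="aws_s3",            # connection used by S3Hook; falls back to source_aws_conn_id if omitted
    )
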
150 changes: 148 additions & 2 deletions tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py
@@ -23,7 +23,11 @@

import pytest

from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator, JSONEncoder
from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import (
_DEPRECATION_MSG,
DynamoDBToS3Operator,
JSONEncoder,
)


class TestJSONEncoder:
@@ -107,6 +111,74 @@ def test_dynamodb_to_s3_success_with_decimal(self, mock_aws_dynamodb_hook, mock_

assert [{"a": float(a)}, {"b": float(b)}] == self.output_queue

@patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.S3Hook")
@patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.DynamoDBHook")
def test_dynamodb_to_s3_default_connection(self, mock_aws_dynamodb_hook, mock_s3_hook):
responses = [
{
"Items": [{"a": 1}, {"b": 2}],
"LastEvaluatedKey": "123",
},
{
"Items": [{"c": 3}],
},
]
table = MagicMock()
table.return_value.scan.side_effect = responses
mock_aws_dynamodb_hook.return_value.get_conn.return_value.Table = table

s3_client = MagicMock()
s3_client.return_value.upload_file = self.mock_upload_file
mock_s3_hook.return_value.get_conn = s3_client

dynamodb_to_s3_operator = DynamoDBToS3Operator(
task_id="dynamodb_to_s3",
dynamodb_table_name="airflow_rocks",
s3_bucket_name="airflow-bucket",
file_size=4000,
)

dynamodb_to_s3_operator.execute(context={})
aws_conn_id = "aws_default"

mock_s3_hook.assert_called_with(aws_conn_id=aws_conn_id)
mock_aws_dynamodb_hook.assert_called_with(aws_conn_id=aws_conn_id)

@patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.S3Hook")
@patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.DynamoDBHook")
def test_dynamodb_to_s3_with_aws_conn_id(self, mock_aws_dynamodb_hook, mock_s3_hook):
responses = [
{
"Items": [{"a": 1}, {"b": 2}],
"LastEvaluatedKey": "123",
},
{
"Items": [{"c": 3}],
},
]
table = MagicMock()
table.return_value.scan.side_effect = responses
mock_aws_dynamodb_hook.return_value.get_conn.return_value.Table = table

s3_client = MagicMock()
s3_client.return_value.upload_file = self.mock_upload_file
mock_s3_hook.return_value.get_conn = s3_client

aws_conn_id = "test-conn-id"
with pytest.warns(DeprecationWarning, match=_DEPRECATION_MSG):
dynamodb_to_s3_operator = DynamoDBToS3Operator(
task_id="dynamodb_to_s3",
dynamodb_table_name="airflow_rocks",
s3_bucket_name="airflow-bucket",
file_size=4000,
aws_conn_id=aws_conn_id,
)

dynamodb_to_s3_operator.execute(context={})

mock_s3_hook.assert_called_with(aws_conn_id=aws_conn_id)
mock_aws_dynamodb_hook.assert_called_with(aws_conn_id=aws_conn_id)

@patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.S3Hook")
@patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.DynamoDBHook")
def test_dynamodb_to_s3_with_different_aws_conn_id(self, mock_aws_dynamodb_hook, mock_s3_hook):
@@ -133,7 +205,7 @@ def test_dynamodb_to_s3_with_different_aws_conn_id(self, mock_aws_dynamodb_hook,
dynamodb_table_name="airflow_rocks",
s3_bucket_name="airflow-bucket",
file_size=4000,
aws_conn_id=aws_conn_id,
source_aws_conn_id=aws_conn_id,
)

dynamodb_to_s3_operator.execute(context={})
@@ -142,3 +214,77 @@ def test_dynamodb_to_s3_with_different_aws_conn_id(self, mock_aws_dynamodb_hook,

mock_s3_hook.assert_called_with(aws_conn_id=aws_conn_id)
mock_aws_dynamodb_hook.assert_called_with(aws_conn_id=aws_conn_id)

@patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.S3Hook")
@patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.DynamoDBHook")
def test_dynamodb_to_s3_with_two_different_connections(self, mock_aws_dynamodb_hook, mock_s3_hook):
responses = [
{
"Items": [{"a": 1}, {"b": 2}],
"LastEvaluatedKey": "123",
},
{
"Items": [{"c": 3}],
},
]
table = MagicMock()
table.return_value.scan.side_effect = responses
mock_aws_dynamodb_hook.return_value.get_conn.return_value.Table = table

s3_client = MagicMock()
s3_client.return_value.upload_file = self.mock_upload_file
mock_s3_hook.return_value.get_conn = s3_client

s3_aws_conn_id = "test-conn-id"
dynamodb_conn_id = "test-dynamodb-conn-id"
dynamodb_to_s3_operator = DynamoDBToS3Operator(
task_id="dynamodb_to_s3",
dynamodb_table_name="airflow_rocks",
source_aws_conn_id=dynamodb_conn_id,
s3_bucket_name="airflow-bucket",
file_size=4000,
dest_aws_conn_id=s3_aws_conn_id,
)

dynamodb_to_s3_operator.execute(context={})

assert [{"a": 1}, {"b": 2}, {"c": 3}] == self.output_queue

mock_s3_hook.assert_called_with(aws_conn_id=s3_aws_conn_id)
mock_aws_dynamodb_hook.assert_called_with(aws_conn_id=dynamodb_conn_id)

@patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.S3Hook")
@patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.DynamoDBHook")
def test_dynamodb_to_s3_with_just_dest_aws_conn_id(self, mock_aws_dynamodb_hook, mock_s3_hook):
responses = [
{
"Items": [{"a": 1}, {"b": 2}],
"LastEvaluatedKey": "123",
},
{
"Items": [{"c": 3}],
},
]
table = MagicMock()
table.return_value.scan.side_effect = responses
mock_aws_dynamodb_hook.return_value.get_conn.return_value.Table = table

s3_client = MagicMock()
s3_client.return_value.upload_file = self.mock_upload_file
mock_s3_hook.return_value.get_conn = s3_client

s3_aws_conn_id = "test-conn-id"
dynamodb_to_s3_operator = DynamoDBToS3Operator(
task_id="dynamodb_to_s3",
dynamodb_table_name="airflow_rocks",
s3_bucket_name="airflow-bucket",
file_size=4000,
dest_aws_conn_id=s3_aws_conn_id,
)

dynamodb_to_s3_operator.execute(context={})

assert [{"a": 1}, {"b": 2}, {"c": 3}] == self.output_queue

mock_aws_dynamodb_hook.assert_called_with(aws_conn_id="aws_default")
mock_s3_hook.assert_called_with(aws_conn_id=s3_aws_conn_id)
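
Finally, a short migration sketch under the deprecation introduced above, with placeholder names and the DAG wiring omitted: aws_conn_id still works, but it now emits a DeprecationWarning and is treated as source_aws_conn_id.

from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator

# Old style: still accepted, but warns and is mapped onto source_aws_conn_id.
DynamoDBToS3Operator(
    task_id="backup_table_old",
    dynamodb_table_name="my_table",
    s3_bucket_name="my-backup-bucket",
    file_size=4000,
    aws_conn_id="my_aws_conn",
)

# New style: source_aws_conn_id for DynamoDB, optional dest_aws_conn_id for S3.
DynamoDBToS3Operator(
    task_id="backup_table_new",
    dynamodb_table_name="my_table",
    s3_bucket_name="my-backup-bucket",
    file_size=4000,
    source_aws_conn_id="my_aws_conn",
    dest_aws_conn_id="my_s3_conn",
)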