Skip to content

Commit

Permalink
Use boto3.client linked to resource meta instead of create new one …
Browse files Browse the repository at this point in the history
…for waiters (#33552)
  • Loading branch information
Taragolis authored Aug 21, 2023
1 parent b04be0f commit 8402e9a
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 30 deletions.
28 changes: 11 additions & 17 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,12 +680,16 @@ def async_conn(self):
return self.get_client_type(region_name=self.region_name, deferrable=True)

@cached_property
def conn_client_meta(self) -> ClientMeta:
"""Get botocore client metadata from Hook connection (cached)."""
def _client(self) -> botocore.client.BaseClient:
conn = self.conn
if isinstance(conn, botocore.client.BaseClient):
return conn.meta
return conn.meta.client.meta
return conn
return conn.meta.client

@property
def conn_client_meta(self) -> ClientMeta:
"""Get botocore client metadata from Hook connection (cached)."""
return self._client.meta

@property
def conn_region_name(self) -> str:
Expand Down Expand Up @@ -862,19 +866,9 @@ def get_waiter(

if deferrable and not client:
raise ValueError("client must be provided for a deferrable waiter.")
client = client or self.conn
# Currently, the custom waiter doesn't work with resource_type, only client_type is supported.
client = client or self._client
if self.waiter_path and (waiter_name in self._list_custom_waiters()):
# Currently, the custom waiter doesn't work with resource_type, only client_type is supported.
if self.resource_type:
credentials = self.get_credentials()
client = boto3.client(
self.resource_type,
region_name=self.region_name,
aws_access_key_id=credentials.access_key,
aws_secret_access_key=credentials.secret_key,
aws_session_token=credentials.token,
)

# Technically if waiter_name is in custom_waiters then self.waiter_path must
# exist but MyPy doesn't like the fact that self.waiter_path could be None.
with open(self.waiter_path) as config_file:
Expand Down Expand Up @@ -909,7 +903,7 @@ def list_waiters(self) -> list[str]:
return [*self._list_official_waiters(), *self._list_custom_waiters()]

def _list_official_waiters(self) -> list[str]:
return self.conn.waiter_names
return self._client.waiter_names

def _list_custom_waiters(self) -> list[str]:
if not self.waiter_path:
Expand Down
46 changes: 33 additions & 13 deletions tests/providers/amazon/aws/waiters/test_custom_waiters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from botocore.waiter import WaiterModel
from moto import mock_eks

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, EcsTaskDefinitionStates
from airflow.providers.amazon.aws.hooks.eks import EksHook
Expand Down Expand Up @@ -73,6 +74,22 @@ def test_init(self):
assert waiter.model.__getattribute__(attr) == expected_model.__getattribute__(attr)
assert waiter.client == client_name

@pytest.mark.parametrize("boto_type", ["client", "resource"])
def test_get_botocore_waiter(self, boto_type, monkeypatch):
kw = {f"{boto_type}_type": "s3"}
if boto_type == "client":
fake_client = boto3.client("s3", region_name="eu-west-3")
elif boto_type == "resource":
fake_client = boto3.resource("s3", region_name="eu-west-3")
else:
raise ValueError(f"Unexpected value {boto_type!r} for `boto_type`.")
monkeypatch.setattr(AwsBaseHook, "conn", fake_client)

hook = AwsBaseHook(**kw)
with mock.patch("botocore.client.BaseClient.get_waiter") as m:
hook.get_waiter(waiter_name="FooBar")
m.assert_called_once_with("FooBar")


class TestCustomEKSServiceWaiters:
def test_service_waiters(self):
Expand Down Expand Up @@ -230,8 +247,9 @@ class TestCustomDynamoDBServiceWaiters:

@pytest.fixture(autouse=True)
def setup_test_cases(self, monkeypatch):
self.client = boto3.client("dynamodb", region_name="eu-west-3")
monkeypatch.setattr(DynamoDBHook, "conn", self.client)
self.resource = boto3.resource("dynamodb", region_name="eu-west-3")
monkeypatch.setattr(DynamoDBHook, "conn", self.resource)
self.client = self.resource.meta.client

@pytest.fixture
def mock_describe_export(self):
Expand All @@ -253,16 +271,15 @@ def describe_export(status: str):

def test_export_table_to_point_in_time_completed(self, mock_describe_export):
"""Test state transition from `in progress` to `completed` during init."""
with mock.patch("boto3.client") as client:
client.return_value = self.client
waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table", client=self.client)
mock_describe_export.side_effect = [
self.describe_export(self.STATUS_IN_PROGRESS),
self.describe_export(self.STATUS_COMPLETED),
]
waiter.wait(
ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry",
)
waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table")
mock_describe_export.side_effect = [
self.describe_export(self.STATUS_IN_PROGRESS),
self.describe_export(self.STATUS_COMPLETED),
]
waiter.wait(
ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry",
WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
)

def test_export_table_to_point_in_time_failed(self, mock_describe_export):
"""Test state transition from `in progress` to `failed` during init."""
Expand All @@ -274,4 +291,7 @@ def test_export_table_to_point_in_time_failed(self, mock_describe_export):
]
waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table", client=self.client)
with pytest.raises(WaiterError, match='we matched expected path: "FAILED"'):
waiter.wait(ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry")
waiter.wait(
ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry",
WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
)

0 comments on commit 8402e9a

Please sign in to comment.