Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 11 additions & 17 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,12 +680,16 @@ def async_conn(self):
return self.get_client_type(region_name=self.region_name, deferrable=True)

@cached_property
def conn_client_meta(self) -> ClientMeta:
"""Get botocore client metadata from Hook connection (cached)."""
def _client(self) -> botocore.client.BaseClient:
conn = self.conn
if isinstance(conn, botocore.client.BaseClient):
return conn.meta
return conn.meta.client.meta
return conn
return conn.meta.client

@property
def conn_client_meta(self) -> ClientMeta:
"""Get botocore client metadata from Hook connection (cached)."""
return self._client.meta

@property
def conn_region_name(self) -> str:
Expand Down Expand Up @@ -862,19 +866,9 @@ def get_waiter(

if deferrable and not client:
raise ValueError("client must be provided for a deferrable waiter.")
client = client or self.conn
# Currently, the custom waiter doesn't work with resource_type, only client_type is supported.
client = client or self._client
if self.waiter_path and (waiter_name in self._list_custom_waiters()):
# Currently, the custom waiter doesn't work with resource_type, only client_type is supported.
if self.resource_type:
credentials = self.get_credentials()
client = boto3.client(
self.resource_type,
region_name=self.region_name,
aws_access_key_id=credentials.access_key,
aws_secret_access_key=credentials.secret_key,
aws_session_token=credentials.token,
)

# Technically if waiter_name is in custom_waiters then self.waiter_path must
# exist but MyPy doesn't like the fact that self.waiter_path could be None.
with open(self.waiter_path) as config_file:
Expand Down Expand Up @@ -909,7 +903,7 @@ def list_waiters(self) -> list[str]:
return [*self._list_official_waiters(), *self._list_custom_waiters()]

def _list_official_waiters(self) -> list[str]:
return self.conn.waiter_names
return self._client.waiter_names

def _list_custom_waiters(self) -> list[str]:
if not self.waiter_path:
Expand Down
46 changes: 33 additions & 13 deletions tests/providers/amazon/aws/waiters/test_custom_waiters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from botocore.waiter import WaiterModel
from moto import mock_eks

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, EcsTaskDefinitionStates
from airflow.providers.amazon.aws.hooks.eks import EksHook
Expand Down Expand Up @@ -73,6 +74,22 @@ def test_init(self):
assert waiter.model.__getattribute__(attr) == expected_model.__getattribute__(attr)
assert waiter.client == client_name

@pytest.mark.parametrize("boto_type", ["client", "resource"])
def test_get_botocore_waiter(self, boto_type, monkeypatch):
kw = {f"{boto_type}_type": "s3"}
if boto_type == "client":
fake_client = boto3.client("s3", region_name="eu-west-3")
elif boto_type == "resource":
fake_client = boto3.resource("s3", region_name="eu-west-3")
else:
raise ValueError(f"Unexpected value {boto_type!r} for `boto_type`.")
monkeypatch.setattr(AwsBaseHook, "conn", fake_client)

hook = AwsBaseHook(**kw)
with mock.patch("botocore.client.BaseClient.get_waiter") as m:
hook.get_waiter(waiter_name="FooBar")
m.assert_called_once_with("FooBar")


class TestCustomEKSServiceWaiters:
def test_service_waiters(self):
Expand Down Expand Up @@ -230,8 +247,9 @@ class TestCustomDynamoDBServiceWaiters:

@pytest.fixture(autouse=True)
def setup_test_cases(self, monkeypatch):
self.client = boto3.client("dynamodb", region_name="eu-west-3")
monkeypatch.setattr(DynamoDBHook, "conn", self.client)
self.resource = boto3.resource("dynamodb", region_name="eu-west-3")
monkeypatch.setattr(DynamoDBHook, "conn", self.resource)
self.client = self.resource.meta.client

@pytest.fixture
def mock_describe_export(self):
Expand All @@ -253,16 +271,15 @@ def describe_export(status: str):

def test_export_table_to_point_in_time_completed(self, mock_describe_export):
"""Test state transition from `in progress` to `completed` during init."""
with mock.patch("boto3.client") as client:
client.return_value = self.client
waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table", client=self.client)
mock_describe_export.side_effect = [
self.describe_export(self.STATUS_IN_PROGRESS),
self.describe_export(self.STATUS_COMPLETED),
]
waiter.wait(
ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry",
)
waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table")
mock_describe_export.side_effect = [
self.describe_export(self.STATUS_IN_PROGRESS),
self.describe_export(self.STATUS_COMPLETED),
]
waiter.wait(
ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry",
WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
)

def test_export_table_to_point_in_time_failed(self, mock_describe_export):
"""Test state transition from `in progress` to `failed` during init."""
Expand All @@ -274,4 +291,7 @@ def test_export_table_to_point_in_time_failed(self, mock_describe_export):
]
waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table", client=self.client)
with pytest.raises(WaiterError, match='we matched expected path: "FAILED"'):
waiter.wait(ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry")
waiter.wait(
ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry",
WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
)