diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py index 904ed0a385170..660d6fa92e753 100644 --- a/airflow/providers/amazon/aws/hooks/base_aws.py +++ b/airflow/providers/amazon/aws/hooks/base_aws.py @@ -680,12 +680,16 @@ def async_conn(self): return self.get_client_type(region_name=self.region_name, deferrable=True) @cached_property - def conn_client_meta(self) -> ClientMeta: - """Get botocore client metadata from Hook connection (cached).""" + def _client(self) -> botocore.client.BaseClient: conn = self.conn if isinstance(conn, botocore.client.BaseClient): - return conn.meta - return conn.meta.client.meta + return conn + return conn.meta.client + + @property + def conn_client_meta(self) -> ClientMeta: + """Get botocore client metadata from Hook connection (cached).""" + return self._client.meta @property def conn_region_name(self) -> str: @@ -862,19 +866,9 @@ def get_waiter( if deferrable and not client: raise ValueError("client must be provided for a deferrable waiter.") - client = client or self.conn + # Currently, the custom waiter doesn't work with resource_type, only client_type is supported. + client = client or self._client if self.waiter_path and (waiter_name in self._list_custom_waiters()): - # Currently, the custom waiter doesn't work with resource_type, only client_type is supported. - if self.resource_type: - credentials = self.get_credentials() - client = boto3.client( - self.resource_type, - region_name=self.region_name, - aws_access_key_id=credentials.access_key, - aws_secret_access_key=credentials.secret_key, - aws_session_token=credentials.token, - ) - # Technically if waiter_name is in custom_waiters then self.waiter_path must # exist but MyPy doesn't like the fact that self.waiter_path could be None. with open(self.waiter_path) as config_file: @@ -909,7 +903,7 @@ def list_waiters(self) -> list[str]: return [*self._list_official_waiters(), *self._list_custom_waiters()] def _list_official_waiters(self) -> list[str]: - return self.conn.waiter_names + return self._client.waiter_names def _list_custom_waiters(self) -> list[str]: if not self.waiter_path: diff --git a/tests/providers/amazon/aws/waiters/test_custom_waiters.py b/tests/providers/amazon/aws/waiters/test_custom_waiters.py index d02c9c49e8fd0..21c051f3b4901 100644 --- a/tests/providers/amazon/aws/waiters/test_custom_waiters.py +++ b/tests/providers/amazon/aws/waiters/test_custom_waiters.py @@ -26,6 +26,7 @@ from botocore.waiter import WaiterModel from moto import mock_eks +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, EcsTaskDefinitionStates from airflow.providers.amazon.aws.hooks.eks import EksHook @@ -73,6 +74,22 @@ def test_init(self): assert waiter.model.__getattribute__(attr) == expected_model.__getattribute__(attr) assert waiter.client == client_name + @pytest.mark.parametrize("boto_type", ["client", "resource"]) + def test_get_botocore_waiter(self, boto_type, monkeypatch): + kw = {f"{boto_type}_type": "s3"} + if boto_type == "client": + fake_client = boto3.client("s3", region_name="eu-west-3") + elif boto_type == "resource": + fake_client = boto3.resource("s3", region_name="eu-west-3") + else: + raise ValueError(f"Unexpected value {boto_type!r} for `boto_type`.") + monkeypatch.setattr(AwsBaseHook, "conn", fake_client) + + hook = AwsBaseHook(**kw) + with mock.patch("botocore.client.BaseClient.get_waiter") as m: + hook.get_waiter(waiter_name="FooBar") + m.assert_called_once_with("FooBar") + class TestCustomEKSServiceWaiters: def test_service_waiters(self): @@ -230,8 +247,9 @@ class TestCustomDynamoDBServiceWaiters: @pytest.fixture(autouse=True) def setup_test_cases(self, monkeypatch): - self.client = boto3.client("dynamodb", region_name="eu-west-3") - monkeypatch.setattr(DynamoDBHook, "conn", self.client) + self.resource = boto3.resource("dynamodb", region_name="eu-west-3") + monkeypatch.setattr(DynamoDBHook, "conn", self.resource) + self.client = self.resource.meta.client @pytest.fixture def mock_describe_export(self): @@ -253,16 +271,15 @@ def describe_export(status: str): def test_export_table_to_point_in_time_completed(self, mock_describe_export): """Test state transition from `in progress` to `completed` during init.""" - with mock.patch("boto3.client") as client: - client.return_value = self.client - waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table", client=self.client) - mock_describe_export.side_effect = [ - self.describe_export(self.STATUS_IN_PROGRESS), - self.describe_export(self.STATUS_COMPLETED), - ] - waiter.wait( - ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry", - ) + waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table") + mock_describe_export.side_effect = [ + self.describe_export(self.STATUS_IN_PROGRESS), + self.describe_export(self.STATUS_COMPLETED), + ] + waiter.wait( + ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry", + WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, + ) def test_export_table_to_point_in_time_failed(self, mock_describe_export): """Test state transition from `in progress` to `failed` during init.""" @@ -274,4 +291,7 @@ def test_export_table_to_point_in_time_failed(self, mock_describe_export): ] waiter = DynamoDBHook(aws_conn_id=None).get_waiter("export_table", client=self.client) with pytest.raises(WaiterError, match='we matched expected path: "FAILED"'): - waiter.wait(ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry") + waiter.wait( + ExportArn="LoremIpsumissimplydummytextoftheprintingandtypesettingindustry", + WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, + )