diff --git a/airflow/providers/amazon/aws/example_dags/example_rds.py b/airflow/providers/amazon/aws/example_dags/example_rds.py index 5cf6d85e42646..f30404b9d6867 100644 --- a/airflow/providers/amazon/aws/example_dags/example_rds.py +++ b/airflow/providers/amazon/aws/example_dags/example_rds.py @@ -15,10 +15,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -This is an example dag for using `RedshiftSQLOperator` to authenticate with Amazon Redshift -then execute a simple select statement -""" from datetime import datetime @@ -81,6 +77,7 @@ export_task_identifier='export-auth-db-snap-{{ ds }}', source_arn='arn:aws:rds:::snapshot:auth-db-snap', s3_bucket_name='my_s3_bucket', + s3_prefix='some/prefix', iam_role_arn='arn:aws:iam:::role/MyRole', kms_key_id='arn:aws:kms:::key/*****-****-****-****-********', aws_conn_id='aws_default', @@ -105,7 +102,7 @@ # [START howto_guide_rds_create_subscription] create_subscription = RdsCreateEventSubscriptionOperator( task_id='create_subscription', - subscription_name='my_topic_subscription', + subscription_name='my-topic-subscription', sns_topic_arn='arn:aws:sns:::MyTopic', source_type='db-instance', source_ids=['auth-db'], @@ -118,7 +115,7 @@ # [START howto_guide_rds_delete_subscription] delete_subscription = RdsDeleteEventSubscriptionOperator( task_id='delete_subscription', - subscription_name='my_topic_subscription', + subscription_name='my-topic-subscription', aws_conn_id='aws_default', hook_params={'region_name': 'us-east-1'}, ) @@ -144,6 +141,7 @@ export_sensor = RdsExportTaskExistenceSensor( task_id='export_sensor', export_task_identifier='export-auth-db-snap-{{ ds }}', + target_statuses=['starting', 'in_progress', 'complete', 'canceling', 'canceled'], aws_conn_id='aws_default', hook_params={'region_name': 'us-east-1'}, ) diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py index e14df928dcbad..a527107e80a46 100644 --- a/airflow/providers/amazon/aws/operators/rds.py +++ b/airflow/providers/amazon/aws/operators/rds.py @@ -82,14 +82,15 @@ def _await_status( if len(items) > 1: raise AirflowException(f"There are {len(items)} {item_type} with identifier {item_name}") - if wait_statuses and items[0]['Status'] in wait_statuses: + if wait_statuses and items[0]['Status'].lower() in wait_statuses: + time.sleep(self._await_interval) continue - elif ok_statuses and items[0]['Status'] in ok_statuses: + elif ok_statuses and items[0]['Status'].lower() in ok_statuses: break - elif error_statuses and items[0]['Status'] in error_statuses: + elif error_statuses and items[0]['Status'].lower() in error_statuses: raise AirflowException(f"Item has error status ({error_statuses}): {items[0]}") - - time.sleep(self._await_interval) + else: + raise AirflowException(f"Item has uncertain status: {items[0]}") return None @@ -118,7 +119,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator): `USER Tagging `__ """ - template_fields = ("db_snapshot_identifier", "db_instance_identifier", "tags") + template_fields = ("db_snapshot_identifier", "db_identifier", "tags") def __init__( self, @@ -257,7 +258,7 @@ def execute(self, context: 'Context') -> str: self._await_status( 'instance_snapshot', self.target_db_snapshot_identifier, - wait_statuses=['copying'], + wait_statuses=['creating'], ok_statuses=['available'], ) else: @@ -392,7 +393,8 @@ def execute(self, context: 'Context') -> str: 'export_task', self.export_task_identifier, wait_statuses=['starting', 'in_progress'], - ok_statuses=['available', 'complete'], + ok_statuses=['complete'], + error_statuses=['canceling', 'canceled'], ) return json.dumps(start_export, default=str) @@ -506,7 +508,7 @@ def execute(self, context: 'Context') -> str: 'event_subscription', self.subscription_name, wait_statuses=['creating'], - ok_statuses=['created', 'available'], + ok_statuses=['active'], ) return json.dumps(create_subscription, default=str) diff --git a/airflow/providers/amazon/aws/sensors/rds.py b/airflow/providers/amazon/aws/sensors/rds.py index da57e7030a29a..1c74d5ae8fc14 100644 --- a/airflow/providers/amazon/aws/sensors/rds.py +++ b/airflow/providers/amazon/aws/sensors/rds.py @@ -62,7 +62,7 @@ def _check_item(self, item_type: str, item_name: str) -> bool: except ClientError: return False else: - return bool(items) and any(map(lambda s: items[0]['Status'] == s, self.target_statuses)) + return bool(items) and any(map(lambda s: items[0]['Status'].lower() == s, self.target_statuses)) class RdsSnapshotExistenceSensor(RdsBaseSensor): @@ -80,7 +80,7 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor): template_fields: Sequence[str] = ( 'db_snapshot_identifier', - 'target_status', + 'target_statuses', ) def __init__( @@ -121,7 +121,7 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor): template_fields: Sequence[str] = ( 'export_task_identifier', - 'target_status', + 'target_statuses', ) def __init__( @@ -135,7 +135,13 @@ def __init__( super().__init__(aws_conn_id=aws_conn_id, **kwargs) self.export_task_identifier = export_task_identifier - self.target_statuses = target_statuses or ['available'] + self.target_statuses = target_statuses or [ + 'starting', + 'in_progress', + 'complete', + 'canceling', + 'canceled', + ] def poke(self, context: 'Context'): self.log.info( diff --git a/setup.py b/setup.py index 2e4f3ef72fd69..667819e1fc521 100644 --- a/setup.py +++ b/setup.py @@ -612,7 +612,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version 'jira', 'jsondiff', 'mongomock', - 'moto>=3.0.7', + 'moto>=3.1.0', 'parameterized', 'paramiko', 'pipdeptree', diff --git a/tests/providers/amazon/aws/operators/test_rds.py b/tests/providers/amazon/aws/operators/test_rds.py index 0989736ff7981..d952fbc11a93c 100644 --- a/tests/providers/amazon/aws/operators/test_rds.py +++ b/tests/providers/amazon/aws/operators/test_rds.py @@ -35,9 +35,9 @@ from airflow.utils import timezone try: - from moto import mock_rds2 + from moto import mock_rds except ImportError: - mock_rds2 = None + mock_rds = None DEFAULT_DATE = timezone.datetime(2019, 1, 1) @@ -165,7 +165,7 @@ def test_await_status_ok(self): ) -@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present') +@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present') class TestRdsCreateDbSnapshotOperator: @classmethod def setup_class(cls): @@ -177,7 +177,7 @@ def teardown_class(cls): del cls.dag del cls.hook - @mock_rds2 + @mock_rds def test_create_db_instance_snapshot(self): _create_db_instance(self.hook) instance_snapshot_operator = RdsCreateDbSnapshotOperator( @@ -196,7 +196,7 @@ def test_create_db_instance_snapshot(self): assert instance_snapshots assert len(instance_snapshots) == 1 - @mock_rds2 + @mock_rds def test_create_db_cluster_snapshot(self): _create_db_cluster(self.hook) cluster_snapshot_operator = RdsCreateDbSnapshotOperator( @@ -216,7 +216,7 @@ def test_create_db_cluster_snapshot(self): assert len(cluster_snapshots) == 1 -@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present') +@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present') class TestRdsCopyDbSnapshotOperator: @classmethod def setup_class(cls): @@ -228,7 +228,7 @@ def teardown_class(cls): del cls.dag del cls.hook - @mock_rds2 + @mock_rds def test_copy_db_instance_snapshot(self): _create_db_instance(self.hook) _create_db_instance_snapshot(self.hook) @@ -248,7 +248,7 @@ def test_copy_db_instance_snapshot(self): assert instance_snapshots assert len(instance_snapshots) == 1 - @mock_rds2 + @mock_rds def test_copy_db_cluster_snapshot(self): _create_db_cluster(self.hook) _create_db_cluster_snapshot(self.hook) @@ -271,7 +271,7 @@ def test_copy_db_cluster_snapshot(self): assert len(cluster_snapshots) == 1 -@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present') +@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present') class TestRdsDeleteDbSnapshotOperator: @classmethod def setup_class(cls): @@ -283,7 +283,7 @@ def teardown_class(cls): del cls.dag del cls.hook - @mock_rds2 + @mock_rds def test_delete_db_instance_snapshot(self): _create_db_instance(self.hook) _create_db_instance_snapshot(self.hook) @@ -300,7 +300,7 @@ def test_delete_db_instance_snapshot(self): with pytest.raises(self.hook.conn.exceptions.ClientError): self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_CLUSTER_SNAPSHOT) - @mock_rds2 + @mock_rds def test_delete_db_cluster_snapshot(self): _create_db_cluster(self.hook) _create_db_cluster_snapshot(self.hook) @@ -318,7 +318,7 @@ def test_delete_db_cluster_snapshot(self): self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT) -@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present') +@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present') class TestRdsStartExportTaskOperator: @classmethod def setup_class(cls): @@ -330,7 +330,7 @@ def teardown_class(cls): del cls.dag del cls.hook - @mock_rds2 + @mock_rds def test_start_export_task(self): _create_db_instance(self.hook) _create_db_instance_snapshot(self.hook) @@ -352,10 +352,10 @@ def test_start_export_task(self): assert export_tasks assert len(export_tasks) == 1 - assert export_tasks[0]['Status'] == 'available' + assert export_tasks[0]['Status'] == 'complete' -@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present') +@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present') class TestRdsCancelExportTaskOperator: @classmethod def setup_class(cls): @@ -367,7 +367,7 @@ def teardown_class(cls): del cls.dag del cls.hook - @mock_rds2 + @mock_rds def test_cancel_export_task(self): _create_db_instance(self.hook) _create_db_instance_snapshot(self.hook) @@ -389,7 +389,7 @@ def test_cancel_export_task(self): assert export_tasks[0]['Status'] == 'canceled' -@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present') +@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present') class TestRdsCreateEventSubscriptionOperator: @classmethod def setup_class(cls): @@ -401,7 +401,7 @@ def teardown_class(cls): del cls.dag del cls.hook - @mock_rds2 + @mock_rds def test_create_event_subscription(self): _create_db_instance(self.hook) @@ -421,10 +421,10 @@ def test_create_event_subscription(self): assert subscriptions assert len(subscriptions) == 1 - assert subscriptions[0]['Status'] == 'available' + assert subscriptions[0]['Status'] == 'active' -@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present') +@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present') class TestRdsDeleteEventSubscriptionOperator: @classmethod def setup_class(cls): @@ -436,7 +436,7 @@ def teardown_class(cls): del cls.dag del cls.hook - @mock_rds2 + @mock_rds def test_delete_event_subscription(self): _create_event_subscription(self.hook) diff --git a/tests/providers/amazon/aws/sensors/test_rds.py b/tests/providers/amazon/aws/sensors/test_rds.py index 1dbeeaea85eec..f93b93174eded 100644 --- a/tests/providers/amazon/aws/sensors/test_rds.py +++ b/tests/providers/amazon/aws/sensors/test_rds.py @@ -28,9 +28,9 @@ from airflow.utils import timezone try: - from moto import mock_rds2 + from moto import mock_rds except ImportError: - mock_rds2 = None + mock_rds = None DEFAULT_DATE = timezone.datetime(2019, 1, 1) @@ -132,7 +132,7 @@ def test_check_item_false(self): assert not self.base_sensor._check_item(item_type='instance_snapshot', item_name='') -@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present') +@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present') class TestRdsSnapshotExistenceSensor: @classmethod def setup_class(cls): @@ -144,7 +144,7 @@ def teardown_class(cls): del cls.dag del cls.hook - @mock_rds2 + @mock_rds def test_db_instance_snapshot_poke_true(self): _create_db_instance_snapshot(self.hook) op = RdsSnapshotExistenceSensor( @@ -156,7 +156,7 @@ def test_db_instance_snapshot_poke_true(self): ) assert op.poke(None) - @mock_rds2 + @mock_rds def test_db_instance_snapshot_poke_false(self): op = RdsSnapshotExistenceSensor( task_id='test_instance_snap_false', @@ -167,7 +167,7 @@ def test_db_instance_snapshot_poke_false(self): ) assert not op.poke(None) - @mock_rds2 + @mock_rds def test_db_instance_cluster_poke_true(self): _create_db_cluster_snapshot(self.hook) op = RdsSnapshotExistenceSensor( @@ -179,7 +179,7 @@ def test_db_instance_cluster_poke_true(self): ) assert op.poke(None) - @mock_rds2 + @mock_rds def test_db_instance_cluster_poke_false(self): op = RdsSnapshotExistenceSensor( task_id='test_cluster_snap_false', @@ -191,7 +191,7 @@ def test_db_instance_cluster_poke_false(self): assert not op.poke(None) -@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present') +@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present') class TestRdsExportTaskExistenceSensor: @classmethod def setup_class(cls): @@ -203,7 +203,7 @@ def teardown_class(cls): del cls.dag del cls.hook - @mock_rds2 + @mock_rds def test_export_task_poke_true(self): _create_db_instance_snapshot(self.hook) _start_export_task(self.hook) @@ -215,7 +215,7 @@ def test_export_task_poke_true(self): ) assert op.poke(None) - @mock_rds2 + @mock_rds def test_export_task_poke_false(self): _create_db_instance_snapshot(self.hook) op = RdsExportTaskExistenceSensor(