Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions airflow/providers/amazon/aws/example_dags/example_rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
This is an example dag for using `RedshiftSQLOperator` to authenticate with Amazon Redshift
then execute a simple select statement
"""

from datetime import datetime

Expand Down Expand Up @@ -81,6 +77,7 @@
export_task_identifier='export-auth-db-snap-{{ ds }}',
source_arn='arn:aws:rds:<region>:<account number>:snapshot:auth-db-snap',
s3_bucket_name='my_s3_bucket',
s3_prefix='some/prefix',
iam_role_arn='arn:aws:iam:<region>:<account number>:role/MyRole',
kms_key_id='arn:aws:kms:<region>:<account number>:key/*****-****-****-****-********',
aws_conn_id='aws_default',
Expand All @@ -105,7 +102,7 @@
# [START howto_guide_rds_create_subscription]
create_subscription = RdsCreateEventSubscriptionOperator(
task_id='create_subscription',
subscription_name='my_topic_subscription',
subscription_name='my-topic-subscription',
sns_topic_arn='arn:aws:sns:<region>:<account number>:MyTopic',
source_type='db-instance',
source_ids=['auth-db'],
Expand All @@ -118,7 +115,7 @@
# [START howto_guide_rds_delete_subscription]
delete_subscription = RdsDeleteEventSubscriptionOperator(
task_id='delete_subscription',
subscription_name='my_topic_subscription',
subscription_name='my-topic-subscription',
aws_conn_id='aws_default',
hook_params={'region_name': 'us-east-1'},
)
Expand All @@ -144,6 +141,7 @@
export_sensor = RdsExportTaskExistenceSensor(
task_id='export_sensor',
export_task_identifier='export-auth-db-snap-{{ ds }}',
target_statuses=['starting', 'in_progress', 'complete', 'canceling', 'canceled'],
aws_conn_id='aws_default',
hook_params={'region_name': 'us-east-1'},
)
Expand Down
20 changes: 11 additions & 9 deletions airflow/providers/amazon/aws/operators/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,15 @@ def _await_status(
if len(items) > 1:
raise AirflowException(f"There are {len(items)} {item_type} with identifier {item_name}")

if wait_statuses and items[0]['Status'] in wait_statuses:
if wait_statuses and items[0]['Status'].lower() in wait_statuses:
time.sleep(self._await_interval)
continue
elif ok_statuses and items[0]['Status'] in ok_statuses:
elif ok_statuses and items[0]['Status'].lower() in ok_statuses:
break
elif error_statuses and items[0]['Status'] in error_statuses:
elif error_statuses and items[0]['Status'].lower() in error_statuses:
raise AirflowException(f"Item has error status ({error_statuses}): {items[0]}")

time.sleep(self._await_interval)
else:
raise AirflowException(f"Item has uncertain status: {items[0]}")

return None

Expand Down Expand Up @@ -118,7 +119,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
`USER Tagging <https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_Tagging.html>`__
"""

template_fields = ("db_snapshot_identifier", "db_instance_identifier", "tags")
template_fields = ("db_snapshot_identifier", "db_identifier", "tags")

def __init__(
self,
Expand Down Expand Up @@ -257,7 +258,7 @@ def execute(self, context: 'Context') -> str:
self._await_status(
'instance_snapshot',
self.target_db_snapshot_identifier,
wait_statuses=['copying'],
wait_statuses=['creating'],
ok_statuses=['available'],
)
else:
Expand Down Expand Up @@ -392,7 +393,8 @@ def execute(self, context: 'Context') -> str:
'export_task',
self.export_task_identifier,
wait_statuses=['starting', 'in_progress'],
ok_statuses=['available', 'complete'],
ok_statuses=['complete'],
error_statuses=['canceling', 'canceled'],
)

return json.dumps(start_export, default=str)
Expand Down Expand Up @@ -506,7 +508,7 @@ def execute(self, context: 'Context') -> str:
'event_subscription',
self.subscription_name,
wait_statuses=['creating'],
ok_statuses=['created', 'available'],
ok_statuses=['active'],
)

return json.dumps(create_subscription, default=str)
Expand Down
14 changes: 10 additions & 4 deletions airflow/providers/amazon/aws/sensors/rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _check_item(self, item_type: str, item_name: str) -> bool:
except ClientError:
return False
else:
return bool(items) and any(map(lambda s: items[0]['Status'] == s, self.target_statuses))
return bool(items) and any(map(lambda s: items[0]['Status'].lower() == s, self.target_statuses))


class RdsSnapshotExistenceSensor(RdsBaseSensor):
Expand All @@ -80,7 +80,7 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor):

template_fields: Sequence[str] = (
'db_snapshot_identifier',
'target_status',
'target_statuses',
)

def __init__(
Expand Down Expand Up @@ -121,7 +121,7 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor):

template_fields: Sequence[str] = (
'export_task_identifier',
'target_status',
'target_statuses',
)

def __init__(
Expand All @@ -135,7 +135,13 @@ def __init__(
super().__init__(aws_conn_id=aws_conn_id, **kwargs)

self.export_task_identifier = export_task_identifier
self.target_statuses = target_statuses or ['available']
self.target_statuses = target_statuses or [
'starting',
'in_progress',
'complete',
'canceling',
'canceled',
]

def poke(self, context: 'Context'):
self.log.info(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ def write_version(filename: str = os.path.join(*[my_dir, "airflow", "git_version
'jira',
'jsondiff',
'mongomock',
'moto>=3.0.7',
'moto>=3.1.0',
'parameterized',
'paramiko',
'pipdeptree',
Expand Down
42 changes: 21 additions & 21 deletions tests/providers/amazon/aws/operators/test_rds.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
from airflow.utils import timezone

try:
from moto import mock_rds2
from moto import mock_rds
except ImportError:
mock_rds2 = None
mock_rds = None


DEFAULT_DATE = timezone.datetime(2019, 1, 1)
Expand Down Expand Up @@ -165,7 +165,7 @@ def test_await_status_ok(self):
)


@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present')
@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
class TestRdsCreateDbSnapshotOperator:
@classmethod
def setup_class(cls):
Expand All @@ -177,7 +177,7 @@ def teardown_class(cls):
del cls.dag
del cls.hook

@mock_rds2
@mock_rds
def test_create_db_instance_snapshot(self):
_create_db_instance(self.hook)
instance_snapshot_operator = RdsCreateDbSnapshotOperator(
Expand All @@ -196,7 +196,7 @@ def test_create_db_instance_snapshot(self):
assert instance_snapshots
assert len(instance_snapshots) == 1

@mock_rds2
@mock_rds
def test_create_db_cluster_snapshot(self):
_create_db_cluster(self.hook)
cluster_snapshot_operator = RdsCreateDbSnapshotOperator(
Expand All @@ -216,7 +216,7 @@ def test_create_db_cluster_snapshot(self):
assert len(cluster_snapshots) == 1


@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present')
@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
class TestRdsCopyDbSnapshotOperator:
@classmethod
def setup_class(cls):
Expand All @@ -228,7 +228,7 @@ def teardown_class(cls):
del cls.dag
del cls.hook

@mock_rds2
@mock_rds
def test_copy_db_instance_snapshot(self):
_create_db_instance(self.hook)
_create_db_instance_snapshot(self.hook)
Expand All @@ -248,7 +248,7 @@ def test_copy_db_instance_snapshot(self):
assert instance_snapshots
assert len(instance_snapshots) == 1

@mock_rds2
@mock_rds
def test_copy_db_cluster_snapshot(self):
_create_db_cluster(self.hook)
_create_db_cluster_snapshot(self.hook)
Expand All @@ -271,7 +271,7 @@ def test_copy_db_cluster_snapshot(self):
assert len(cluster_snapshots) == 1


@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present')
@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
class TestRdsDeleteDbSnapshotOperator:
@classmethod
def setup_class(cls):
Expand All @@ -283,7 +283,7 @@ def teardown_class(cls):
del cls.dag
del cls.hook

@mock_rds2
@mock_rds
def test_delete_db_instance_snapshot(self):
_create_db_instance(self.hook)
_create_db_instance_snapshot(self.hook)
Expand All @@ -300,7 +300,7 @@ def test_delete_db_instance_snapshot(self):
with pytest.raises(self.hook.conn.exceptions.ClientError):
self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=DB_CLUSTER_SNAPSHOT)

@mock_rds2
@mock_rds
def test_delete_db_cluster_snapshot(self):
_create_db_cluster(self.hook)
_create_db_cluster_snapshot(self.hook)
Expand All @@ -318,7 +318,7 @@ def test_delete_db_cluster_snapshot(self):
self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=DB_CLUSTER_SNAPSHOT)


@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present')
@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
class TestRdsStartExportTaskOperator:
@classmethod
def setup_class(cls):
Expand All @@ -330,7 +330,7 @@ def teardown_class(cls):
del cls.dag
del cls.hook

@mock_rds2
@mock_rds
def test_start_export_task(self):
_create_db_instance(self.hook)
_create_db_instance_snapshot(self.hook)
Expand All @@ -352,10 +352,10 @@ def test_start_export_task(self):

assert export_tasks
assert len(export_tasks) == 1
assert export_tasks[0]['Status'] == 'available'
assert export_tasks[0]['Status'] == 'complete'


@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present')
@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
class TestRdsCancelExportTaskOperator:
@classmethod
def setup_class(cls):
Expand All @@ -367,7 +367,7 @@ def teardown_class(cls):
del cls.dag
del cls.hook

@mock_rds2
@mock_rds
def test_cancel_export_task(self):
_create_db_instance(self.hook)
_create_db_instance_snapshot(self.hook)
Expand All @@ -389,7 +389,7 @@ def test_cancel_export_task(self):
assert export_tasks[0]['Status'] == 'canceled'


@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present')
@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
class TestRdsCreateEventSubscriptionOperator:
@classmethod
def setup_class(cls):
Expand All @@ -401,7 +401,7 @@ def teardown_class(cls):
del cls.dag
del cls.hook

@mock_rds2
@mock_rds
def test_create_event_subscription(self):
_create_db_instance(self.hook)

Expand All @@ -421,10 +421,10 @@ def test_create_event_subscription(self):

assert subscriptions
assert len(subscriptions) == 1
assert subscriptions[0]['Status'] == 'available'
assert subscriptions[0]['Status'] == 'active'


@pytest.mark.skipif(mock_rds2 is None, reason='mock_rds2 package not present')
@pytest.mark.skipif(mock_rds is None, reason='mock_rds package not present')
class TestRdsDeleteEventSubscriptionOperator:
@classmethod
def setup_class(cls):
Expand All @@ -436,7 +436,7 @@ def teardown_class(cls):
del cls.dag
del cls.hook

@mock_rds2
@mock_rds
def test_delete_event_subscription(self):
_create_event_subscription(self.hook)

Expand Down
Loading