Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def create_cluster(
for the cluster that is being created.
:param params: Remaining AWS Create cluster API params.
"""
response = self.get_conn().create_cluster(
response = self.conn.create_cluster(
ClusterIdentifier=cluster_identifier,
NodeType=node_type,
MasterUsername=master_username,
Expand All @@ -87,9 +87,9 @@ def cluster_status(self, cluster_identifier: str) -> str:
:param cluster_identifier: unique identifier of a cluster
"""
try:
response = self.get_conn().describe_clusters(ClusterIdentifier=cluster_identifier)["Clusters"]
response = self.conn.describe_clusters(ClusterIdentifier=cluster_identifier)["Clusters"]
return response[0]["ClusterStatus"] if response else None
except self.get_conn().exceptions.ClusterNotFoundFault:
except self.conn.exceptions.ClusterNotFoundFault:
return "cluster_not_found"

async def cluster_status_async(self, cluster_identifier: str) -> str:
Expand All @@ -115,7 +115,7 @@ def delete_cluster(
"""
final_cluster_snapshot_identifier = final_cluster_snapshot_identifier or ""

response = self.get_conn().delete_cluster(
response = self.conn.delete_cluster(
ClusterIdentifier=cluster_identifier,
SkipFinalClusterSnapshot=skip_final_cluster_snapshot,
FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier,
Expand All @@ -131,7 +131,7 @@ def describe_cluster_snapshots(self, cluster_identifier: str) -> list[str] | Non

:param cluster_identifier: unique identifier of a cluster
"""
response = self.get_conn().describe_cluster_snapshots(ClusterIdentifier=cluster_identifier)
response = self.conn.describe_cluster_snapshots(ClusterIdentifier=cluster_identifier)
if "Snapshots" not in response:
return None
snapshots = response["Snapshots"]
Expand All @@ -149,7 +149,7 @@ def restore_from_cluster_snapshot(self, cluster_identifier: str, snapshot_identi
:param cluster_identifier: unique identifier of a cluster
:param snapshot_identifier: unique identifier for a snapshot of a cluster
"""
response = self.get_conn().restore_from_cluster_snapshot(
response = self.conn.restore_from_cluster_snapshot(
ClusterIdentifier=cluster_identifier, SnapshotIdentifier=snapshot_identifier
)
return response["Cluster"] if response["Cluster"] else None
Expand All @@ -175,7 +175,7 @@ def create_cluster_snapshot(
"""
if tags is None:
tags = []
response = self.get_conn().create_cluster_snapshot(
response = self.conn.create_cluster_snapshot(
SnapshotIdentifier=snapshot_identifier,
ClusterIdentifier=cluster_identifier,
ManualSnapshotRetentionPeriod=retention_period,
Expand All @@ -192,11 +192,11 @@ def get_cluster_snapshot_status(self, snapshot_identifier: str):
:param snapshot_identifier: A unique identifier for the snapshot that you are requesting
"""
try:
response = self.get_conn().describe_cluster_snapshots(
response = self.conn.describe_cluster_snapshots(
SnapshotIdentifier=snapshot_identifier,
)
snapshot = response.get("Snapshots")[0]
snapshot_status: str = snapshot.get("Status")
return snapshot_status
except self.get_conn().exceptions.ClusterSnapshotNotFoundFault:
except self.conn.exceptions.ClusterSnapshotNotFoundFault:
return None
Original file line number Diff line number Diff line change
Expand Up @@ -755,11 +755,18 @@ def execute(self, context: Context):
final_cluster_snapshot_identifier=self.final_cluster_snapshot_identifier,
)
break
except self.redshift_hook.get_conn().exceptions.InvalidClusterStateFault:
except self.redshift_hook.conn.exceptions.InvalidClusterStateFault:
self._attempts -= 1

if self._attempts:
self.log.error("Unable to delete cluster. %d attempts remaining.", self._attempts)
current_state = self.redshift_hook.conn.describe_clusters(
ClusterIdentifier=self.cluster_identifier
)["Clusters"][0]["ClusterStatus"]
self.log.error(
"Cluster in %s state, unable to delete. %d attempts remaining.",
current_state,
self._attempts,
)
time.sleep(self._attempt_interval)
else:
raise
Expand All @@ -785,7 +792,7 @@ def execute(self, context: Context):
)

elif self.wait_for_completion:
waiter = self.redshift_hook.get_conn().get_waiter("cluster_deleted")
waiter = self.redshift_hook.conn.get_waiter("cluster_deleted")
waiter.wait(
ClusterIdentifier=self.cluster_identifier,
WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": self.max_attempts},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _create_clusters():
def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
self._create_clusters()
hook = AwsBaseHook(aws_conn_id="aws_default", client_type="redshift")
client_from_hook = hook.get_conn()
client_from_hook = hook.conn

clusters = client_from_hook.describe_clusters()["Clusters"]
assert len(clusters) == 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def test_init(self):
assert redshift_operator.master_username == "adminuser"
assert redshift_operator.master_user_password == "Test123$"

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
def test_create_single_node_cluster(self, mock_get_conn):
@mock.patch.object(RedshiftHook, "conn")
def test_create_single_node_cluster(self, mock_conn):
redshift_operator = RedshiftCreateClusterOperator(
task_id="task_test",
cluster_identifier="test-cluster",
Expand All @@ -78,7 +78,7 @@ def test_create_single_node_cluster(self, mock_get_conn):
"PubliclyAccessible": True,
"Port": 5439,
}
mock_get_conn.return_value.create_cluster.assert_called_once_with(
mock_conn.create_cluster.assert_called_once_with(
ClusterIdentifier="test-cluster",
NodeType="dc2.large",
MasterUsername="adminuser",
Expand All @@ -87,12 +87,12 @@ def test_create_single_node_cluster(self, mock_get_conn):
)

# wait_for_completion is True so check waiter is called
mock_get_conn.return_value.get_waiter.return_value.wait.assert_called_once_with(
mock_conn.get_waiter.return_value.wait.assert_called_once_with(
ClusterIdentifier="test-cluster", WaiterConfig={"Delay": 60, "MaxAttempts": 5}
)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
def test_create_multi_node_cluster(self, mock_get_conn):
@mock.patch.object(RedshiftHook, "conn")
def test_create_multi_node_cluster(self, mock_conn):
redshift_operator = RedshiftCreateClusterOperator(
task_id="task_test",
cluster_identifier="test-cluster",
Expand All @@ -113,7 +113,7 @@ def test_create_multi_node_cluster(self, mock_get_conn):
"PubliclyAccessible": True,
"Port": 5439,
}
mock_get_conn.return_value.create_cluster.assert_called_once_with(
mock_conn.create_cluster.assert_called_once_with(
ClusterIdentifier="test-cluster",
NodeType="dc2.large",
MasterUsername="adminuser",
Expand All @@ -122,10 +122,10 @@ def test_create_multi_node_cluster(self, mock_get_conn):
)

# wait_for_completion is False so check waiter is not called
mock_get_conn.return_value.get_waiter.assert_not_called()
mock_conn.get_waiter.assert_not_called()

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
def test_create_cluster_deferrable(self, mock_get_conn):
@mock.patch.object(RedshiftHook, "conn")
def test_create_cluster_deferrable(self, mock_conn):
redshift_operator = RedshiftCreateClusterOperator(
task_id="task_test",
cluster_identifier="test-cluster",
Expand Down Expand Up @@ -242,16 +242,16 @@ class TestRedshiftDeleteClusterSnapshotOperator:
@mock.patch(
"airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_cluster_snapshot_status"
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
def test_delete_cluster_snapshot_wait(self, mock_get_conn, mock_get_cluster_snapshot_status):
@mock.patch.object(RedshiftHook, "conn")
def test_delete_cluster_snapshot_wait(self, mock_conn, mock_get_cluster_snapshot_status):
mock_get_cluster_snapshot_status.return_value = None
delete_snapshot = RedshiftDeleteClusterSnapshotOperator(
task_id="test_snapshot",
cluster_identifier="test_cluster",
snapshot_identifier="test_snapshot",
)
delete_snapshot.execute(None)
mock_get_conn.return_value.delete_cluster_snapshot.assert_called_once_with(
mock_conn.delete_cluster_snapshot.assert_called_once_with(
SnapshotClusterIdentifier="test_cluster",
SnapshotIdentifier="test_snapshot",
)
Expand All @@ -263,16 +263,16 @@ def test_delete_cluster_snapshot_wait(self, mock_get_conn, mock_get_cluster_snap
@mock.patch(
"airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_cluster_snapshot_status"
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
def test_delete_cluster_snapshot(self, mock_get_conn, mock_get_cluster_snapshot_status):
@mock.patch.object(RedshiftHook, "conn")
def test_delete_cluster_snapshot(self, mock_conn, mock_get_cluster_snapshot_status):
delete_snapshot = RedshiftDeleteClusterSnapshotOperator(
task_id="test_snapshot",
cluster_identifier="test_cluster",
snapshot_identifier="test_snapshot",
wait_for_completion=False,
)
delete_snapshot.execute(None)
mock_get_conn.return_value.delete_cluster_snapshot.assert_called_once_with(
mock_conn.delete_cluster_snapshot.assert_called_once_with(
SnapshotClusterIdentifier="test_cluster",
SnapshotIdentifier="test_snapshot",
)
Expand All @@ -298,13 +298,13 @@ def test_init(self):
assert redshift_operator.cluster_identifier == "test_cluster"
assert redshift_operator.aws_conn_id == "aws_conn_test"

@mock.patch.object(RedshiftHook, "get_conn")
def test_resume_cluster_is_called_when_cluster_is_paused(self, mock_get_conn):
@mock.patch.object(RedshiftHook, "conn")
def test_resume_cluster_is_called_when_cluster_is_paused(self, mock_conn):
redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
)
redshift_operator.execute(None)
mock_get_conn.return_value.resume_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")
mock_conn.resume_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")

@mock.patch.object(RedshiftHook, "conn")
@mock.patch("time.sleep", return_value=None)
Expand Down Expand Up @@ -436,15 +436,15 @@ def test_init(self):
assert redshift_operator.cluster_identifier == "test_cluster"
assert redshift_operator.aws_conn_id == "aws_conn_test"

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
def test_pause_cluster_is_called_when_cluster_is_available(self, mock_get_conn):
@mock.patch.object(RedshiftHook, "conn")
def test_pause_cluster_is_called_when_cluster_is_available(self, mock_conn):
redshift_operator = RedshiftPauseClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
)
redshift_operator.execute(None)
mock_get_conn.return_value.pause_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")
mock_conn.pause_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
@mock.patch.object(RedshiftHook, "conn")
@mock.patch("time.sleep", return_value=None)
def test_pause_cluster_multiple_attempts(self, mock_sleep, mock_conn):
exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
Expand All @@ -462,7 +462,7 @@ def test_pause_cluster_multiple_attempts(self, mock_sleep, mock_conn):
redshift_operator.execute(None)
assert mock_conn.pause_cluster.call_count == 3

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
@mock.patch.object(RedshiftHook, "conn")
@mock.patch("time.sleep", return_value=None)
def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
Expand All @@ -481,10 +481,10 @@ def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
assert mock_conn.pause_cluster.call_count == 10

@mock.patch.object(RedshiftHook, "get_waiter")
@mock.patch.object(RedshiftHook, "get_conn")
def test_pause_cluster_wait_for_completion(self, mock_get_conn, mock_get_waiter):
@mock.patch.object(RedshiftHook, "conn")
def test_pause_cluster_wait_for_completion(self, mock_conn, mock_get_waiter):
"""Test Pause cluster operator with defer when deferrable param is true"""
mock_get_conn.return_value.pause_cluster.return_value = True
mock_conn.pause_cluster.return_value = True
waiter = Mock()
mock_get_waiter.return_value = waiter

Expand All @@ -497,10 +497,10 @@ def test_pause_cluster_wait_for_completion(self, mock_get_conn, mock_get_waiter)
waiter.wait.assert_called_once()

@mock.patch.object(RedshiftHook, "cluster_status")
@mock.patch.object(RedshiftHook, "get_conn")
def test_pause_cluster_deferrable_mode(self, mock_get_conn, mock_cluster_status):
@mock.patch.object(RedshiftHook, "conn")
def test_pause_cluster_deferrable_mode(self, mock_conn, mock_cluster_status):
"""Test Pause cluster operator with defer when deferrable param is true"""
mock_get_conn.return_value.pause_cluster.return_value = True
mock_conn.pause_cluster.return_value = True
mock_cluster_status.return_value = "available"

redshift_operator = RedshiftPauseClusterOperator(
Expand All @@ -516,12 +516,12 @@ def test_pause_cluster_deferrable_mode(self, mock_get_conn, mock_cluster_status)

@mock.patch("airflow.providers.amazon.aws.operators.redshift_cluster.RedshiftPauseClusterOperator.defer")
@mock.patch.object(RedshiftHook, "cluster_status")
@mock.patch.object(RedshiftHook, "get_conn")
@mock.patch.object(RedshiftHook, "conn")
def test_pause_cluster_deferrable_mode_in_deleting_status(
self, mock_get_conn, mock_cluster_status, mock_defer
self, mock_conn, mock_cluster_status, mock_defer
):
"""Test Pause cluster operator with defer when deferrable param is true"""
mock_get_conn.return_value.pause_cluster.return_value = True
mock_conn.pause_cluster.return_value = True
mock_cluster_status.return_value = "deleting"

redshift_operator = RedshiftPauseClusterOperator(
Expand Down Expand Up @@ -561,38 +561,38 @@ def test_template_fields(self):

class TestDeleteClusterOperator:
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
def test_delete_cluster_with_wait_for_completion(self, mock_get_conn, mock_cluster_status):
@mock.patch.object(RedshiftHook, "conn")
def test_delete_cluster_with_wait_for_completion(self, mock_conn, mock_cluster_status):
mock_cluster_status.return_value = "cluster_not_found"
redshift_operator = RedshiftDeleteClusterOperator(
task_id="task_test", cluster_identifier="test_cluster", aws_conn_id="aws_conn_test"
)
redshift_operator.execute(None)
mock_get_conn.return_value.delete_cluster.assert_called_once_with(
mock_conn.delete_cluster.assert_called_once_with(
ClusterIdentifier="test_cluster",
SkipFinalClusterSnapshot=True,
FinalClusterSnapshotIdentifier="",
)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
def test_delete_cluster_without_wait_for_completion(self, mock_get_conn):
@mock.patch.object(RedshiftHook, "conn")
def test_delete_cluster_without_wait_for_completion(self, mock_conn):
redshift_operator = RedshiftDeleteClusterOperator(
task_id="task_test",
cluster_identifier="test_cluster",
aws_conn_id="aws_conn_test",
wait_for_completion=False,
)
redshift_operator.execute(None)
mock_get_conn.return_value.delete_cluster.assert_called_once_with(
mock_conn.delete_cluster.assert_called_once_with(
ClusterIdentifier="test_cluster",
SkipFinalClusterSnapshot=True,
FinalClusterSnapshotIdentifier="",
)

mock_get_conn.return_value.cluster_status.assert_not_called()
mock_conn.cluster_status.assert_not_called()

@mock.patch.object(RedshiftHook, "delete_cluster")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
@mock.patch.object(RedshiftHook, "conn")
@mock.patch("time.sleep", return_value=None)
def test_delete_cluster_multiple_attempts(self, _, mock_conn, mock_delete_cluster):
exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
Expand All @@ -611,7 +611,7 @@ def test_delete_cluster_multiple_attempts(self, _, mock_conn, mock_delete_cluste
assert mock_delete_cluster.call_count == 3

@mock.patch.object(RedshiftHook, "delete_cluster")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
@mock.patch.object(RedshiftHook, "conn")
@mock.patch("time.sleep", return_value=None)
def test_delete_cluster_multiple_attempts_fail(self, _, mock_conn, mock_delete_cluster):
exception = boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
Expand Down