Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions tests/providers/amazon/aws/hooks/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock

from airflow.providers.amazon.aws.hooks.athena import AthenaHook
Expand Down Expand Up @@ -48,8 +47,8 @@
}


class TestAthenaHook(unittest.TestCase):
def setUp(self):
class TestAthenaHook:
def setup_method(self):
self.athena = AthenaHook(sleep_time=0)

def test_init(self):
Expand Down Expand Up @@ -175,7 +174,3 @@ def test_hook_get_output_location(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT
result = self.athena.get_output_location(query_execution_id=MOCK_DATA["query_execution_id"])
assert result == "s3://test_bucket/test.csv"


if __name__ == "__main__":
unittest.main()
10 changes: 3 additions & 7 deletions tests/providers/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import json
import os
import unittest
from base64 import b64encode
from datetime import datetime, timedelta, timezone
from unittest import mock
Expand Down Expand Up @@ -606,7 +605,7 @@ def mock_refresh_credentials():
def test_connection_region_name(
self, conn_type, connection_uri, region_name, env_region, expected_region_name
):
with unittest.mock.patch.dict(
with mock.patch.dict(
"os.environ", AIRFLOW_CONN_TEST_CONN=connection_uri, AWS_DEFAULT_REGION=env_region
):
if conn_type == "client":
Expand All @@ -629,10 +628,7 @@ def test_connection_region_name(
],
)
def test_connection_aws_partition(self, conn_type, connection_uri, expected_partition):
with unittest.mock.patch.dict(
"os.environ",
AIRFLOW_CONN_TEST_CONN=connection_uri,
):
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=connection_uri):
if conn_type == "client":
hook = AwsBaseHook(aws_conn_id="test_conn", client_type="dynamodb")
elif conn_type == "resource":
Expand Down Expand Up @@ -772,7 +768,7 @@ def test_resolve_verify(self, verify, conn_verify):
extra={"verify": conn_verify} if conn_verify is not None else {},
)

with unittest.mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=mock_conn.get_uri()):
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=mock_conn.get_uri()):
hook = AwsBaseHook(aws_conn_id="test_conn", verify=verify)
expected = verify if verify is not None else conn_verify
assert hook.verify == expected
Expand Down
11 changes: 5 additions & 6 deletions tests/providers/amazon/aws/hooks/test_batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,11 +284,11 @@ def test_job_no_awslogs_stream(self, caplog):
}
]
}
with caplog.at_level(level=logging.getLevelName("WARNING")):

with caplog.at_level(level=logging.WARNING):
assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
assert len(caplog.records) == 1
log_record = caplog.records[0]
assert "doesn't create AWS CloudWatch Stream" in log_record.message
assert "doesn't create AWS CloudWatch Stream" in caplog.messages[0]

def test_job_splunk_logs(self, caplog):
self.client_mock.describe_jobs.return_value = {
Expand All @@ -304,11 +304,10 @@ def test_job_splunk_logs(self, caplog):
}
]
}
with caplog.at_level(level=logging.getLevelName("WARNING")):
with caplog.at_level(level=logging.WARNING):
assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
assert len(caplog.records) == 1
log_record = caplog.records[0]
assert "uses logDriver (splunk). AWS CloudWatch logging disabled." in log_record.message
assert "uses logDriver (splunk). AWS CloudWatch logging disabled." in caplog.messages[0]


class TestBatchClientDelays:
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/amazon/aws/hooks/test_batch_waiters.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from __future__ import annotations

import inspect
import unittest
from typing import NamedTuple
from unittest import mock

Expand Down Expand Up @@ -317,12 +316,12 @@ def test_batch_job_waiting(aws_clients, aws_region, job_queue_name, job_definiti
assert job_status == "SUCCEEDED"


class TestBatchWaiters(unittest.TestCase):
class TestBatchWaiters:
@mock.patch.dict("os.environ", AWS_DEFAULT_REGION=AWS_REGION)
@mock.patch.dict("os.environ", AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID)
@mock.patch.dict("os.environ", AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY)
@mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type")
def setUp(self, get_client_type_mock):
def setup_method(self, method, get_client_type_mock):
Comment on lines -325 to +324
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original @vandonr-amz

and nit: I like using _ as name for the arguments that are unused, it's a clear way to mark them as "only here for the compiler to be happy"

This mostly as reminder what the actual argument here.

self.job_id = "8ba9d676-4108-4474-9dca-8bbac1da9b19"
self.region_name = AWS_REGION

Expand Down
25 changes: 8 additions & 17 deletions tests/providers/amazon/aws/hooks/test_datasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock

import boto3
Expand All @@ -29,7 +28,7 @@


@mock_datasync
class TestDataSyncHook(unittest.TestCase):
class TestDataSyncHook:
def test_get_conn(self):
hook = DataSyncHook(aws_conn_id="aws_default")
assert hook.get_conn() is not None
Expand All @@ -50,21 +49,13 @@ def test_get_conn(self):

@mock_datasync
@mock.patch.object(DataSyncHook, "get_conn")
class TestDataSyncHookMocked(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.source_server_hostname = "host"
self.source_subdirectory = "somewhere"
self.destination_bucket_name = "my_bucket"
self.destination_bucket_dir = "dir"
class TestDataSyncHookMocked:
source_server_hostname = "host"
source_subdirectory = "somewhere"
destination_bucket_name = "my_bucket"
destination_bucket_dir = "dir"

self.client = None
self.hook = None
self.source_location_arn = None
self.destination_location_arn = None
self.task_arn = None

def setUp(self):
def setup_method(self, method):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original @vandonr-amz

Ah yes I imagine because of the order of arguments, but I imagine in test_datasync it's not needed ?

Actually this arg needed because we decorated entire class

test setup failed
args = (<tests.providers.amazon.aws.hooks.test_datasync.TestDataSyncHookMocked object at 0x109fdf4c0>, <bound method TestData...HookMocked.test_init of <tests.providers.amazon.aws.hooks.test_datasync.TestDataSyncHookMocked object at 0x109fdf4c0>>)
kwargs = {}

    def wrapper(*args, **kwargs):
        self.start(reset=reset)
        try:
>           result = func(*args, **kwargs)
E           TypeError: setup_method() takes 1 positional argument but 2 were given

self.client = boto3.client("datasync", region_name="us-east-1")
self.hook = DataSyncHook(aws_conn_id="aws_default", wait_interval_seconds=0)

Expand All @@ -86,7 +77,7 @@ def setUp(self):
DestinationLocationArn=self.destination_location_arn,
)["TaskArn"]

def tearDown(self):
def teardown_method(self, method):
# Delete all tasks:
tasks = self.client.list_tasks()
for task in tasks["Tasks"]:
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/amazon/aws/hooks/test_dms_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

import json
import unittest
from typing import Any
from unittest import mock

Expand Down Expand Up @@ -68,8 +67,8 @@
MOCK_DELETE_RESPONSE: dict[str, Any] = {"ReplicationTask": {**MOCK_TASK_RESPONSE_DATA, "Status": "deleting"}}


class TestDmsHook(unittest.TestCase):
def setUp(self):
class TestDmsHook:
def setup_method(self):
self.dms = DmsHook()

def test_init(self):
Expand Down
5 changes: 2 additions & 3 deletions tests/providers/amazon/aws/hooks/test_dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@
# under the License.
from __future__ import annotations

import unittest
import uuid

from moto import mock_dynamodb

from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook


class TestDynamoDBHook(unittest.TestCase):
class TestDynamoDBHook:
@mock_dynamodb
def test_get_conn_returns_a_boto3_connection(self):
hook = DynamoDBHook(aws_conn_id="aws_default")
Expand All @@ -39,7 +38,7 @@ def test_insert_batch_items_dynamodb_table(self):
)

# this table needs to be created in production
table = hook.get_conn().create_table(
hook.get_conn().create_table(
TableName="test_airflow",
KeySchema=[
{"AttributeName": "id", "KeyType": "HASH"},
Expand Down
35 changes: 15 additions & 20 deletions tests/providers/amazon/aws/hooks/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations

import unittest
from datetime import timedelta
from unittest import mock

Expand Down Expand Up @@ -56,38 +55,34 @@ def test_get_task_state(self, mock_conn) -> None:
assert EcsHook().get_task_state(cluster="cluster_name", task="task_name") == "ACTIVE"


class TestShouldRetry(unittest.TestCase):
class TestShouldRetry:
def test_return_true_on_valid_reason(self):
self.assertTrue(should_retry(EcsOperatorError([{"reason": "RESOURCE:MEMORY"}], "Foo")))
assert should_retry(EcsOperatorError([{"reason": "RESOURCE:MEMORY"}], "Foo"))

def test_return_false_on_invalid_reason(self):
self.assertFalse(should_retry(EcsOperatorError([{"reason": "CLUSTER_NOT_FOUND"}], "Foo")))
assert not should_retry(EcsOperatorError([{"reason": "CLUSTER_NOT_FOUND"}], "Foo"))


class TestShouldRetryEni(unittest.TestCase):
class TestShouldRetryEni:
def test_return_true_on_valid_reason(self):
self.assertTrue(
should_retry_eni(
EcsTaskFailToStart(
"The task failed to start due to: "
"Timeout waiting for network interface provisioning to complete."
)
assert should_retry_eni(
EcsTaskFailToStart(
"The task failed to start due to: "
"Timeout waiting for network interface provisioning to complete."
)
)

def test_return_false_on_invalid_reason(self):
self.assertFalse(
should_retry_eni(
EcsTaskFailToStart(
"The task failed to start due to: "
"CannotPullContainerError: "
"ref pull has been retried 5 time(s): failed to resolve reference"
)
assert not should_retry_eni(
EcsTaskFailToStart(
"The task failed to start due to: "
"CannotPullContainerError: "
"ref pull has been retried 5 time(s): failed to resolve reference"
)
)


class TestEcsTaskLogFetcher(unittest.TestCase):
class TestEcsTaskLogFetcher:
@mock.patch("logging.Logger")
def set_up_log_fetcher(self, logger_mock):
self.logger_mock = logger_mock
Expand All @@ -99,7 +94,7 @@ def set_up_log_fetcher(self, logger_mock):
logger=logger_mock,
)

def setUp(self):
def setup_method(self):
self.set_up_log_fetcher()

@mock.patch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

from unittest import TestCase
from unittest.mock import Mock

import pytest
Expand All @@ -26,7 +25,7 @@
from airflow.providers.amazon.aws.hooks.elasticache_replication_group import ElastiCacheReplicationGroupHook


class TestElastiCacheReplicationGroupHook(TestCase):
class TestElastiCacheReplicationGroupHook:
REPLICATION_GROUP_ID = "test-elasticache-replication-group-hook"

REPLICATION_GROUP_CONFIG = {
Expand All @@ -44,7 +43,7 @@ class TestElastiCacheReplicationGroupHook(TestCase):
{"creating", "available", "modifying", "deleting", "create - failed", "snapshotting"}
)

def setUp(self):
def setup_method(self):
self.hook = ElastiCacheReplicationGroupHook()
# noinspection PyPropertyAccess
self.hook.conn = Mock()
Expand Down
9 changes: 5 additions & 4 deletions tests/providers/amazon/aws/hooks/test_emr_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import unittest
from unittest import mock

from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
Expand Down Expand Up @@ -47,8 +46,8 @@
}


class TestEmrContainerHook(unittest.TestCase):
def setUp(self):
class TestEmrContainerHook:
def setup_method(self):
self.emr_containers = EmrContainerHook(virtual_cluster_id="vc1234")

def test_init(self):
Expand Down Expand Up @@ -110,7 +109,9 @@ def test_query_status_polling_with_timeout(self, mock_session):
mock_session.return_value = emr_session_mock
emr_client_mock.describe_job_run.return_value = JOB2_RUN_DESCRIPTION

query_status = self.emr_containers.poll_query_status(job_id="job123456", max_polling_attempts=2)
query_status = self.emr_containers.poll_query_status(
job_id="job123456", max_polling_attempts=2, poll_interval=0
)
# should poll until max_tries is reached since query is in non-terminal state
assert emr_client_mock.describe_job_run.call_count == 2
assert query_status == "RUNNING"
Loading