From c4201d96c2f8104716bec6976c51f497edc66d5f Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Thu, 1 Dec 2022 22:09:49 +0400 Subject: [PATCH] Migrate amazon provider hooks tests from `unittests` to `pytest` --- .../providers/amazon/aws/hooks/test_athena.py | 9 +-- .../amazon/aws/hooks/test_base_aws.py | 10 +-- .../amazon/aws/hooks/test_batch_client.py | 11 ++-- .../amazon/aws/hooks/test_batch_waiters.py | 5 +- .../amazon/aws/hooks/test_datasync.py | 25 +++---- .../amazon/aws/hooks/test_dms_task.py | 5 +- .../amazon/aws/hooks/test_dynamodb.py | 5 +- tests/providers/amazon/aws/hooks/test_ecs.py | 35 +++++----- .../test_elasticache_replication_group.py | 5 +- .../amazon/aws/hooks/test_emr_containers.py | 9 +-- .../amazon/aws/hooks/test_glacier.py | 65 ++++++++----------- .../amazon/aws/hooks/test_glue_crawler.py | 38 ++++------- .../amazon/aws/hooks/test_redshift_sql.py | 14 ++-- .../amazon/aws/utils/test_emailer.py | 4 +- .../amazon/aws/utils/test_redshift.py | 3 +- .../providers/amazon/aws/utils/test_utils.py | 3 +- 16 files changed, 96 insertions(+), 150 deletions(-) diff --git a/tests/providers/amazon/aws/hooks/test_athena.py b/tests/providers/amazon/aws/hooks/test_athena.py index 65d549f925909..7a2b3be99752e 100644 --- a/tests/providers/amazon/aws/hooks/test_athena.py +++ b/tests/providers/amazon/aws/hooks/test_athena.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.providers.amazon.aws.hooks.athena import AthenaHook @@ -48,8 +47,8 @@ } -class TestAthenaHook(unittest.TestCase): - def setUp(self): +class TestAthenaHook: + def setup_method(self): self.athena = AthenaHook(sleep_time=0) def test_init(self): @@ -175,7 +174,3 @@ def test_hook_get_output_location(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT result = self.athena.get_output_location(query_execution_id=MOCK_DATA["query_execution_id"]) assert result == "s3://test_bucket/test.csv" - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py b/tests/providers/amazon/aws/hooks/test_base_aws.py index ad5a0c147e227..837a3d2f899b3 100644 --- a/tests/providers/amazon/aws/hooks/test_base_aws.py +++ b/tests/providers/amazon/aws/hooks/test_base_aws.py @@ -19,7 +19,6 @@ import json import os -import unittest from base64 import b64encode from datetime import datetime, timedelta, timezone from unittest import mock @@ -606,7 +605,7 @@ def mock_refresh_credentials(): def test_connection_region_name( self, conn_type, connection_uri, region_name, env_region, expected_region_name ): - with unittest.mock.patch.dict( + with mock.patch.dict( "os.environ", AIRFLOW_CONN_TEST_CONN=connection_uri, AWS_DEFAULT_REGION=env_region ): if conn_type == "client": @@ -629,10 +628,7 @@ def test_connection_region_name( ], ) def test_connection_aws_partition(self, conn_type, connection_uri, expected_partition): - with unittest.mock.patch.dict( - "os.environ", - AIRFLOW_CONN_TEST_CONN=connection_uri, - ): + with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=connection_uri): if conn_type == "client": hook = AwsBaseHook(aws_conn_id="test_conn", client_type="dynamodb") elif conn_type == "resource": @@ -772,7 +768,7 @@ def test_resolve_verify(self, verify, conn_verify): extra={"verify": conn_verify} if conn_verify is not None else {}, ) - with unittest.mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=mock_conn.get_uri()): + with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=mock_conn.get_uri()): hook = AwsBaseHook(aws_conn_id="test_conn", verify=verify) expected = verify if verify is not None else conn_verify assert hook.verify == expected diff --git a/tests/providers/amazon/aws/hooks/test_batch_client.py b/tests/providers/amazon/aws/hooks/test_batch_client.py index c1ea153dfde47..13726e5518ff4 100644 --- a/tests/providers/amazon/aws/hooks/test_batch_client.py +++ b/tests/providers/amazon/aws/hooks/test_batch_client.py @@ -284,11 +284,11 @@ def test_job_no_awslogs_stream(self, caplog): } ] } - with caplog.at_level(level=logging.getLevelName("WARNING")): + + with caplog.at_level(level=logging.WARNING): assert self.batch_client.get_job_awslogs_info(JOB_ID) is None assert len(caplog.records) == 1 - log_record = caplog.records[0] - assert "doesn't create AWS CloudWatch Stream" in log_record.message + assert "doesn't create AWS CloudWatch Stream" in caplog.messages[0] def test_job_splunk_logs(self, caplog): self.client_mock.describe_jobs.return_value = { @@ -304,11 +304,10 @@ def test_job_splunk_logs(self, caplog): } ] } - with caplog.at_level(level=logging.getLevelName("WARNING")): + with caplog.at_level(level=logging.WARNING): assert self.batch_client.get_job_awslogs_info(JOB_ID) is None assert len(caplog.records) == 1 - log_record = caplog.records[0] - assert "uses logDriver (splunk). AWS CloudWatch logging disabled." in log_record.message + assert "uses logDriver (splunk). AWS CloudWatch logging disabled." in caplog.messages[0] class TestBatchClientDelays: diff --git a/tests/providers/amazon/aws/hooks/test_batch_waiters.py b/tests/providers/amazon/aws/hooks/test_batch_waiters.py index 3ff9a154fd223..c245ac4da4c9c 100644 --- a/tests/providers/amazon/aws/hooks/test_batch_waiters.py +++ b/tests/providers/amazon/aws/hooks/test_batch_waiters.py @@ -30,7 +30,6 @@ from __future__ import annotations import inspect -import unittest from typing import NamedTuple from unittest import mock @@ -317,12 +316,12 @@ def test_batch_job_waiting(aws_clients, aws_region, job_queue_name, job_definiti assert job_status == "SUCCEEDED" -class TestBatchWaiters(unittest.TestCase): +class TestBatchWaiters: @mock.patch.dict("os.environ", AWS_DEFAULT_REGION=AWS_REGION) @mock.patch.dict("os.environ", AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID) @mock.patch.dict("os.environ", AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY) @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type") - def setUp(self, get_client_type_mock): + def setup_method(self, method, get_client_type_mock): self.job_id = "8ba9d676-4108-4474-9dca-8bbac1da9b19" self.region_name = AWS_REGION diff --git a/tests/providers/amazon/aws/hooks/test_datasync.py b/tests/providers/amazon/aws/hooks/test_datasync.py index f68b441de8954..eeb976e4e0ab3 100644 --- a/tests/providers/amazon/aws/hooks/test_datasync.py +++ b/tests/providers/amazon/aws/hooks/test_datasync.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import boto3 @@ -29,7 +28,7 @@ @mock_datasync -class TestDataSyncHook(unittest.TestCase): +class TestDataSyncHook: def test_get_conn(self): hook = DataSyncHook(aws_conn_id="aws_default") assert hook.get_conn() is not None @@ -50,21 +49,13 @@ def test_get_conn(self): @mock_datasync @mock.patch.object(DataSyncHook, "get_conn") -class TestDataSyncHookMocked(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.source_server_hostname = "host" - self.source_subdirectory = "somewhere" - self.destination_bucket_name = "my_bucket" - self.destination_bucket_dir = "dir" +class TestDataSyncHookMocked: + source_server_hostname = "host" + source_subdirectory = "somewhere" + destination_bucket_name = "my_bucket" + destination_bucket_dir = "dir" - self.client = None - self.hook = None - self.source_location_arn = None - self.destination_location_arn = None - self.task_arn = None - - def setUp(self): + def setup_method(self, method): self.client = boto3.client("datasync", region_name="us-east-1") self.hook = DataSyncHook(aws_conn_id="aws_default", wait_interval_seconds=0) @@ -86,7 +77,7 @@ def setUp(self): DestinationLocationArn=self.destination_location_arn, )["TaskArn"] - def tearDown(self): + def teardown_method(self, method): # Delete all tasks: tasks = self.client.list_tasks() for task in tasks["Tasks"]: diff --git a/tests/providers/amazon/aws/hooks/test_dms_task.py b/tests/providers/amazon/aws/hooks/test_dms_task.py index efe5561cd3716..9d66df55c25eb 100644 --- a/tests/providers/amazon/aws/hooks/test_dms_task.py +++ b/tests/providers/amazon/aws/hooks/test_dms_task.py @@ -17,7 +17,6 @@ from __future__ import annotations import json -import unittest from typing import Any from unittest import mock @@ -68,8 +67,8 @@ MOCK_DELETE_RESPONSE: dict[str, Any] = {"ReplicationTask": {**MOCK_TASK_RESPONSE_DATA, "Status": "deleting"}} -class TestDmsHook(unittest.TestCase): - def setUp(self): +class TestDmsHook: + def setup_method(self): self.dms = DmsHook() def test_init(self): diff --git a/tests/providers/amazon/aws/hooks/test_dynamodb.py b/tests/providers/amazon/aws/hooks/test_dynamodb.py index 7c06c4c304bb5..8c5886639c0cc 100644 --- a/tests/providers/amazon/aws/hooks/test_dynamodb.py +++ b/tests/providers/amazon/aws/hooks/test_dynamodb.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest import uuid from moto import mock_dynamodb @@ -25,7 +24,7 @@ from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook -class TestDynamoDBHook(unittest.TestCase): +class TestDynamoDBHook: @mock_dynamodb def test_get_conn_returns_a_boto3_connection(self): hook = DynamoDBHook(aws_conn_id="aws_default") @@ -39,7 +38,7 @@ def test_insert_batch_items_dynamodb_table(self): ) # this table needs to be created in production - table = hook.get_conn().create_table( + hook.get_conn().create_table( TableName="test_airflow", KeySchema=[ {"AttributeName": "id", "KeyType": "HASH"}, diff --git a/tests/providers/amazon/aws/hooks/test_ecs.py b/tests/providers/amazon/aws/hooks/test_ecs.py index b7477c372eccf..d9a4f53fa8a75 100644 --- a/tests/providers/amazon/aws/hooks/test_ecs.py +++ b/tests/providers/amazon/aws/hooks/test_ecs.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import timedelta from unittest import mock @@ -56,38 +55,34 @@ def test_get_task_state(self, mock_conn) -> None: assert EcsHook().get_task_state(cluster="cluster_name", task="task_name") == "ACTIVE" -class TestShouldRetry(unittest.TestCase): +class TestShouldRetry: def test_return_true_on_valid_reason(self): - self.assertTrue(should_retry(EcsOperatorError([{"reason": "RESOURCE:MEMORY"}], "Foo"))) + assert should_retry(EcsOperatorError([{"reason": "RESOURCE:MEMORY"}], "Foo")) def test_return_false_on_invalid_reason(self): - self.assertFalse(should_retry(EcsOperatorError([{"reason": "CLUSTER_NOT_FOUND"}], "Foo"))) + assert not should_retry(EcsOperatorError([{"reason": "CLUSTER_NOT_FOUND"}], "Foo")) -class TestShouldRetryEni(unittest.TestCase): +class TestShouldRetryEni: def test_return_true_on_valid_reason(self): - self.assertTrue( - should_retry_eni( - EcsTaskFailToStart( - "The task failed to start due to: " - "Timeout waiting for network interface provisioning to complete." - ) + assert should_retry_eni( + EcsTaskFailToStart( + "The task failed to start due to: " + "Timeout waiting for network interface provisioning to complete." ) ) def test_return_false_on_invalid_reason(self): - self.assertFalse( - should_retry_eni( - EcsTaskFailToStart( - "The task failed to start due to: " - "CannotPullContainerError: " - "ref pull has been retried 5 time(s): failed to resolve reference" - ) + assert not should_retry_eni( + EcsTaskFailToStart( + "The task failed to start due to: " + "CannotPullContainerError: " + "ref pull has been retried 5 time(s): failed to resolve reference" ) ) -class TestEcsTaskLogFetcher(unittest.TestCase): +class TestEcsTaskLogFetcher: @mock.patch("logging.Logger") def set_up_log_fetcher(self, logger_mock): self.logger_mock = logger_mock @@ -99,7 +94,7 @@ def set_up_log_fetcher(self, logger_mock): logger=logger_mock, ) - def setUp(self): + def setup_method(self): self.set_up_log_fetcher() @mock.patch( diff --git a/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py b/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py index 18766a1e7f3e2..8c72720ddc82a 100644 --- a/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py +++ b/tests/providers/amazon/aws/hooks/test_elasticache_replication_group.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -from unittest import TestCase from unittest.mock import Mock import pytest @@ -26,7 +25,7 @@ from airflow.providers.amazon.aws.hooks.elasticache_replication_group import ElastiCacheReplicationGroupHook -class TestElastiCacheReplicationGroupHook(TestCase): +class TestElastiCacheReplicationGroupHook: REPLICATION_GROUP_ID = "test-elasticache-replication-group-hook" REPLICATION_GROUP_CONFIG = { @@ -44,7 +43,7 @@ class TestElastiCacheReplicationGroupHook(TestCase): {"creating", "available", "modifying", "deleting", "create - failed", "snapshotting"} ) - def setUp(self): + def setup_method(self): self.hook = ElastiCacheReplicationGroupHook() # noinspection PyPropertyAccess self.hook.conn = Mock() diff --git a/tests/providers/amazon/aws/hooks/test_emr_containers.py b/tests/providers/amazon/aws/hooks/test_emr_containers.py index 7bd1255b078b0..8a5f1303a6921 100644 --- a/tests/providers/amazon/aws/hooks/test_emr_containers.py +++ b/tests/providers/amazon/aws/hooks/test_emr_containers.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook @@ -47,8 +46,8 @@ } -class TestEmrContainerHook(unittest.TestCase): - def setUp(self): +class TestEmrContainerHook: + def setup_method(self): self.emr_containers = EmrContainerHook(virtual_cluster_id="vc1234") def test_init(self): @@ -110,7 +109,9 @@ def test_query_status_polling_with_timeout(self, mock_session): mock_session.return_value = emr_session_mock emr_client_mock.describe_job_run.return_value = JOB2_RUN_DESCRIPTION - query_status = self.emr_containers.poll_query_status(job_id="job123456", max_polling_attempts=2) + query_status = self.emr_containers.poll_query_status( + job_id="job123456", max_polling_attempts=2, poll_interval=0 + ) # should poll until max_tries is reached since query is in non-terminal state assert emr_client_mock.describe_job_run.call_count == 2 assert query_status == "RUNNING" diff --git a/tests/providers/amazon/aws/hooks/test_glacier.py b/tests/providers/amazon/aws/hooks/test_glacier.py index 28f0bca78ce11..cce5669960d78 100644 --- a/tests/providers/amazon/aws/hooks/test_glacier.py +++ b/tests/providers/amazon/aws/hooks/test_glacier.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -import unittest +import logging from unittest import mock from airflow.providers.amazon.aws.hooks.glacier import GlacierHook @@ -30,8 +30,8 @@ JOB_STATUS = {"Action": "", "StatusCode": "Succeeded"} -class TestAmazonGlacierHook(unittest.TestCase): - def setUp(self): +class TestAmazonGlacierHook: + def setup_method(self): with mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.__init__", return_value=None): self.hook = GlacierHook(aws_conn_id="aws_default") @@ -47,25 +47,21 @@ def test_retrieve_inventory_should_return_job_id(self, mock_conn): assert job_id == result @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn") - def test_retrieve_inventory_should_log_mgs(self, mock_conn): + def test_retrieve_inventory_should_log_mgs(self, mock_conn, caplog): # given job_id = {"jobId": "1234abcd"} # when - with self.assertLogs() as log: + + with caplog.at_level(logging.INFO, logger=self.hook.log.name): + caplog.clear() mock_conn.return_value.initiate_job.return_value = job_id self.hook.retrieve_inventory(VAULT_NAME) - # then - self.assertEqual( - log.output, - [ - "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:" - f"Retrieving inventory for vault: {VAULT_NAME}", - "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:" - f"Initiated inventory-retrieval job for: {VAULT_NAME}", - "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:" - f"Retrieval Job ID: {job_id.get('jobId')}", - ], - ) + # then + assert caplog.messages == [ + f"Retrieving inventory for vault: {VAULT_NAME}", + f"Initiated inventory-retrieval job for: {VAULT_NAME}", + f"Retrieval Job ID: {job_id.get('jobId')}", + ] @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn") def test_retrieve_inventory_results_should_return_response(self, mock_conn): @@ -77,19 +73,14 @@ def test_retrieve_inventory_results_should_return_response(self, mock_conn): assert response == RESPONSE_BODY @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn") - def test_retrieve_inventory_results_should_log_mgs(self, mock_conn): + def test_retrieve_inventory_results_should_log_mgs(self, mock_conn, caplog): # when - with self.assertLogs() as log: + with caplog.at_level(logging.INFO, logger=self.hook.log.name): + caplog.clear() mock_conn.return_value.get_job_output.return_value = REQUEST_RESULT self.hook.retrieve_inventory_results(VAULT_NAME, JOB_ID) - # then - self.assertEqual( - log.output, - [ - "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:" - f"Retrieving the job results for vault: {VAULT_NAME}...", - ], - ) + # then + assert caplog.messages == [f"Retrieving the job results for vault: {VAULT_NAME}..."] @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn") def test_describe_job_should_return_status_succeeded(self, mock_conn): @@ -101,18 +92,14 @@ def test_describe_job_should_return_status_succeeded(self, mock_conn): assert response == JOB_STATUS @mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.get_conn") - def test_describe_job_should_log_mgs(self, mock_conn): + def test_describe_job_should_log_mgs(self, mock_conn, caplog): # when - with self.assertLogs() as log: + with caplog.at_level(logging.INFO, logger=self.hook.log.name): + caplog.clear() mock_conn.return_value.describe_job.return_value = JOB_STATUS self.hook.describe_job(VAULT_NAME, JOB_ID) - # then - self.assertEqual( - log.output, - [ - "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:" - f"Retrieving status for vault: {VAULT_NAME} and job {JOB_ID}", - "INFO:airflow.providers.amazon.aws.hooks.glacier.GlacierHook:" - f"Job status: {JOB_STATUS.get('Action')}, code status: {JOB_STATUS.get('StatusCode')}", - ], - ) + # then + assert caplog.messages == [ + f"Retrieving status for vault: {VAULT_NAME} and job {JOB_ID}", + f"Job status: {JOB_STATUS.get('Action')}, code status: {JOB_STATUS.get('StatusCode')}", + ] diff --git a/tests/providers/amazon/aws/hooks/test_glue_crawler.py b/tests/providers/amazon/aws/hooks/test_glue_crawler.py index ec966cb68316f..ac2d3cba2cea0 100644 --- a/tests/providers/amazon/aws/hooks/test_glue_crawler.py +++ b/tests/providers/amazon/aws/hooks/test_glue_crawler.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from copy import deepcopy from unittest import mock @@ -83,18 +82,16 @@ } -class TestGlueCrawlerHook(unittest.TestCase): - @classmethod - def setUp(cls): - cls.hook = GlueCrawlerHook(aws_conn_id="aws_default") +class TestGlueCrawlerHook: + def setup_method(self): + self.hook = GlueCrawlerHook(aws_conn_id="aws_default") def test_init(self): - self.assertEqual(self.hook.aws_conn_id, "aws_default") + assert self.hook.aws_conn_id == "aws_default" @mock.patch.object(GlueCrawlerHook, "get_conn") def test_has_crawler(self, mock_get_conn): - response = self.hook.has_crawler(mock_crawler_name) - self.assertEqual(response, True) + assert self.hook.has_crawler(mock_crawler_name) is True mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name) @mock.patch.object(GlueCrawlerHook, "get_conn") @@ -104,8 +101,7 @@ class MockException(Exception): mock_get_conn.return_value.exceptions.EntityNotFoundException = MockException mock_get_conn.return_value.get_crawler.side_effect = MockException("AAA") - response = self.hook.has_crawler(mock_crawler_name) - self.assertEqual(response, False) + assert self.hook.has_crawler(mock_crawler_name) is False mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name) @mock.patch.object(GlueCrawlerHook, "get_conn") @@ -114,30 +110,28 @@ def test_update_crawler_needed(self, mock_get_conn): mock_config_two = deepcopy(mock_config) mock_config_two["Role"] = "test-2-role" - response = self.hook.update_crawler(**mock_config_two) - self.assertEqual(response, True) + assert self.hook.update_crawler(**mock_config_two) is True mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name) mock_get_conn.return_value.update_crawler.assert_called_once_with(**mock_config_two) @mock.patch.object(GlueCrawlerHook, "get_conn") def test_update_crawler_not_needed(self, mock_get_conn): mock_get_conn.return_value.get_crawler.return_value = {"Crawler": mock_config} - response = self.hook.update_crawler(**mock_config) - self.assertEqual(response, False) + assert self.hook.update_crawler(**mock_config) is False mock_get_conn.return_value.get_crawler.assert_called_once_with(Name=mock_crawler_name) @mock.patch.object(GlueCrawlerHook, "get_conn") def test_create_crawler(self, mock_get_conn): mock_get_conn.return_value.create_crawler.return_value = {"Crawler": {"Name": mock_crawler_name}} glue_crawler = self.hook.create_crawler(**mock_config) - self.assertIn("Crawler", glue_crawler) - self.assertIn("Name", glue_crawler["Crawler"]) - self.assertEqual(glue_crawler["Crawler"]["Name"], mock_crawler_name) + assert "Crawler" in glue_crawler + assert "Name" in glue_crawler["Crawler"] + assert glue_crawler["Crawler"]["Name"] == mock_crawler_name @mock.patch.object(GlueCrawlerHook, "get_conn") def test_start_crawler(self, mock_get_conn): result = self.hook.start_crawler(mock_crawler_name) - self.assertEqual(result, mock_get_conn.return_value.start_crawler.return_value) + assert result == mock_get_conn.return_value.start_crawler.return_value mock_get_conn.return_value.start_crawler.assert_called_once_with(Name=mock_crawler_name) @@ -159,7 +153,7 @@ def test_wait_for_crawler_completion_instant_ready(self, mock_get_conn, mock_get ] } result = self.hook.wait_for_crawler_completion(mock_crawler_name) - self.assertEqual(result, "MOCK_STATUS") + assert result == "MOCK_STATUS" mock_get_conn.assert_has_calls( [ mock.call(), @@ -195,7 +189,7 @@ def test_wait_for_crawler_completion_retry_two_times(self, mock_sleep, mock_get_ }, ] result = self.hook.wait_for_crawler_completion(mock_crawler_name) - self.assertEqual(result, "MOCK_STATUS") + assert result == "MOCK_STATUS" mock_get_conn.assert_has_calls( [ mock.call(), @@ -208,7 +202,3 @@ def test_wait_for_crawler_completion_retry_two_times(self, mock_sleep, mock_get_ mock.call(mock_crawler_name), ] ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/providers/amazon/aws/hooks/test_redshift_sql.py b/tests/providers/amazon/aws/hooks/test_redshift_sql.py index e3af91c9e7459..531d6a9b470a0 100644 --- a/tests/providers/amazon/aws/hooks/test_redshift_sql.py +++ b/tests/providers/amazon/aws/hooks/test_redshift_sql.py @@ -17,19 +17,16 @@ from __future__ import annotations import json -import unittest from unittest import mock -from parameterized import parameterized +import pytest from airflow.models import Connection from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook -class TestRedshiftSQLHookConn(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestRedshiftSQLHookConn: + def setup_method(self): self.connection = Connection( conn_type="redshift", login="login", password="password", host="host", port=5439, schema="dev" ) @@ -71,7 +68,8 @@ def test_get_conn_extra(self, mock_connect): iam=True, ) - @parameterized.expand( + @pytest.mark.parametrize( + "conn_params, conn_extra, expected_call_args", [ ({}, {}, {}), ({"login": "test"}, {}, {"user": "test"}), @@ -81,7 +79,7 @@ def test_get_conn_extra(self, mock_connect): ], ) @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.redshift_connector.connect") - def test_get_conn_overrides_correctly(self, conn_params, conn_extra, expected_call_args, mock_connect): + def test_get_conn_overrides_correctly(self, mock_connect, conn_params, conn_extra, expected_call_args): with mock.patch( "airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.conn", Connection(conn_type="redshift", extra=conn_extra, **conn_params), diff --git a/tests/providers/amazon/aws/utils/test_emailer.py b/tests/providers/amazon/aws/utils/test_emailer.py index 0e3b2ed7d4578..e51a885b67a45 100644 --- a/tests/providers/amazon/aws/utils/test_emailer.py +++ b/tests/providers/amazon/aws/utils/test_emailer.py @@ -17,14 +17,14 @@ # under the License. from __future__ import annotations -from unittest import TestCase, mock +from unittest import mock import pytest from airflow.providers.amazon.aws.utils.emailer import send_email -class TestSendEmailSes(TestCase): +class TestSendEmailSes: @mock.patch("airflow.providers.amazon.aws.utils.emailer.SesHook") def test_send_ses_email(self, mock_hook): send_email( diff --git a/tests/providers/amazon/aws/utils/test_redshift.py b/tests/providers/amazon/aws/utils/test_redshift.py index f255e6bfd0091..9d546c13d1392 100644 --- a/tests/providers/amazon/aws/utils/test_redshift.py +++ b/tests/providers/amazon/aws/utils/test_redshift.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from boto3.session import Session @@ -25,7 +24,7 @@ from airflow.providers.amazon.aws.utils.redshift import build_credentials_block -class TestS3ToRedshiftTransfer(unittest.TestCase): +class TestS3ToRedshiftTransfer: @mock.patch("boto3.session.Session") def test_build_credentials_block(self, mock_session): access_key = "aws_access_key_id" diff --git a/tests/providers/amazon/aws/utils/test_utils.py b/tests/providers/amazon/aws/utils/test_utils.py index ced274c6e7a69..6cf8bbb23ef25 100644 --- a/tests/providers/amazon/aws/utils/test_utils.py +++ b/tests/providers/amazon/aws/utils/test_utils.py @@ -17,7 +17,6 @@ from __future__ import annotations from datetime import datetime -from unittest import TestCase import pytz @@ -33,7 +32,7 @@ EPOCH = 946_684_800 -class TestUtils(TestCase): +class TestUtils: def test_trim_none_values(self): input_object = { "test": "test",