From d554922ceae597d81e4f45d4fb9e024eb0bc3c83 Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Tue, 15 Nov 2022 16:27:13 +0400 Subject: [PATCH 1/7] Replace `unittests` in providers tests by pure `pytest` [Wave-3] --- tests/providers/airbyte/hooks/test_airbyte.py | 35 ++++--- .../airbyte/operators/test_airbyte.py | 3 +- .../providers/airbyte/sensors/test_airbyte.py | 3 +- .../providers/alibaba/cloud/hooks/test_oss.py | 5 +- .../cloud/log/test_oss_task_handler.py | 5 +- .../alibaba/cloud/operators/test_oss.py | 13 ++- .../alibaba/cloud/sensors/test_oss_key.py | 5 +- .../providers/arangodb/hooks/test_arangodb.py | 6 +- .../arangodb/operators/test_arangodb.py | 3 +- .../arangodb/sensors/test_arangodb.py | 5 +- .../asana/operators/test_asana_tasks.py | 5 +- .../atlassian/jira/hooks/test_jira.py | 5 +- .../atlassian/jira/operators/test_jira.py | 5 +- .../atlassian/jira/sensors/test_jira.py | 5 +- .../celery/sensors/test_celery_queue.py | 5 +- .../providers/cloudant/hooks/test_cloudant.py | 5 +- .../operators/test_spark_kubernetes.py | 5 +- .../sensors/test_spark_kubernetes.py | 5 +- .../providers/common/sql/hooks/test_dbapi.py | 7 +- .../common/sql/operators/test_sql.py | 29 +++--- .../databricks/hooks/test_databricks.py | 37 ++++---- .../databricks/hooks/test_databricks_sql.py | 5 +- .../databricks/operators/test_databricks.py | 25 +++-- .../operators/test_databricks_repos.py | 7 +- .../operators/test_databricks_sql.py | 5 +- .../providers/databricks/utils/databricks.py | 6 +- tests/providers/datadog/hooks/test_datadog.py | 29 +++--- .../providers/datadog/sensors/test_datadog.py | 5 +- .../providers/dingding/hooks/test_dingding.py | 5 +- .../dingding/operators/test_dingding.py | 5 +- .../discord/hooks/test_discord_webhook.py | 5 +- .../discord/operators/test_discord_webhook.py | 6 +- tests/providers/docker/hooks/test_docker.py | 11 +-- .../providers/docker/operators/test_docker.py | 35 ++++--- .../docker/operators/test_docker_swarm.py | 8 +- 
.../elasticsearch/hooks/test_elasticsearch.py | 18 ++-- tests/providers/exasol/hooks/test_exasol.py | 13 +-- .../providers/exasol/operators/test_exasol.py | 3 +- tests/providers/ftp/hooks/test_ftp.py | 14 +-- tests/providers/ftp/sensors/test_ftp.py | 3 +- tests/providers/grpc/operators/test_grpc.py | 3 +- .../_internal_client/test_vault_client.py | 3 +- tests/providers/hashicorp/hooks/test_vault.py | 25 ++--- .../providers/hashicorp/secrets/test_vault.py | 4 +- tests/providers/http/hooks/test_http.py | 93 ++++++++----------- tests/providers/http/operators/test_http.py | 19 ++-- tests/providers/http/sensors/test_http.py | 5 +- tests/providers/imap/hooks/test_imap.py | 5 +- .../imap/sensors/test_imap_attachment.py | 11 +-- .../providers/influxdb/hooks/test_influxdb.py | 6 +- .../influxdb/operators/test_influxdb.py | 3 +- tests/providers/jdbc/operators/test_jdbc.py | 5 +- tests/providers/jenkins/hooks/test_jenkins.py | 9 +- .../operators/test_jenkins_job_trigger.py | 77 ++++----------- .../providers/jenkins/sensors/test_jenkins.py | 12 +-- tests/providers/mongo/hooks/test_mongo.py | 44 +++------ tests/providers/mongo/sensors/test_mongo.py | 8 +- tests/providers/mysql/hooks/test_mysql.py | 51 +++------- tests/providers/mysql/operators/test_mysql.py | 29 ++---- .../mysql/transfers/test_s3_to_mysql.py | 7 +- .../mysql/transfers/test_vertica_to_mysql.py | 5 +- tests/providers/neo4j/hooks/test_neo4j.py | 28 +++--- tests/providers/neo4j/operators/test_neo4j.py | 3 +- .../providers/openfaas/hooks/test_openfaas.py | 62 ++++++------- .../providers/opsgenie/hooks/test_opsgenie.py | 5 +- .../opsgenie/operators/test_opsgenie.py | 13 ++- tests/providers/oracle/hooks/test_oracle.py | 21 +---- .../providers/oracle/operators/test_oracle.py | 22 +++-- .../oracle/transfers/test_oracle_to_oracle.py | 6 +- .../papermill/operators/test_papermill.py | 3 +- .../providers/postgres/hooks/test_postgres.py | 28 ++---- .../postgres/operators/test_postgres.py | 8 +- 
tests/providers/presto/hooks/test_presto.py | 15 ++- .../presto/transfers/test_gcs_presto.py | 3 +- tests/providers/qubole/hooks/test_qubole.py | 4 +- .../qubole/hooks/test_qubole_check.py | 4 +- .../qubole/operators/test_qubole_check.py | 5 +- tests/providers/qubole/sensors/test_qubole.py | 5 +- tests/providers/redis/hooks/test_redis.py | 3 +- .../redis/operators/test_redis_publish.py | 3 +- .../providers/redis/sensors/test_redis_key.py | 6 +- .../redis/sensors/test_redis_pub_sub.py | 5 +- .../operators/test_salesforce_apex_rest.py | 3 +- .../providers/salesforce/sensors/__init__.py | 16 ---- tests/providers/samba/hooks/test_samba.py | 14 +-- tests/providers/segment/hooks/test_segment.py | 7 +- .../operators/test_segment_track_event.py | 8 +- .../providers/sendgrid/utils/test_emailer.py | 5 +- tests/providers/sftp/hooks/test_sftp.py | 34 +++---- tests/providers/sftp/sensors/test_sftp.py | 3 +- .../singularity/operators/test_singularity.py | 25 ++--- .../snowflake/operators/test_snowflake.py | 6 +- .../transfers/test_copy_into_snowflake.py | 3 +- tests/providers/sqlite/hooks/test_sqlite.py | 9 +- .../providers/sqlite/operators/test_sqlite.py | 8 +- tests/providers/ssh/hooks/test_ssh.py | 15 ++- tests/providers/tableau/hooks/test_tableau.py | 23 ++--- .../tableau/operators/test_tableau.py | 9 +- .../providers/tableau/sensors/test_tableau.py | 16 ++-- .../providers/telegram/hooks/test_telegram.py | 5 +- .../telegram/operators/test_telegram.py | 5 +- tests/providers/trino/hooks/test_trino.py | 20 ++-- tests/providers/trino/operators/test_trino.py | 18 ++-- .../trino/transfers/test_gcs_trino.py | 3 +- tests/providers/vertica/hooks/test_vertica.py | 13 +-- .../vertica/operators/test_vertica.py | 3 +- .../yandex/hooks/test_yandexcloud_dataproc.py | 9 +- .../operators/test_yandexcloud_dataproc.py | 5 +- 108 files changed, 526 insertions(+), 839 deletions(-) delete mode 100644 tests/providers/salesforce/sensors/__init__.py diff --git 
a/tests/providers/airbyte/hooks/test_airbyte.py b/tests/providers/airbyte/hooks/test_airbyte.py index af265e1ed136c..271b531822581 100644 --- a/tests/providers/airbyte/hooks/test_airbyte.py +++ b/tests/providers/airbyte/hooks/test_airbyte.py @@ -17,11 +17,9 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest -import requests_mock from airflow.exceptions import AirflowException from airflow.models import Connection @@ -29,7 +27,7 @@ from airflow.utils import db -class TestAirbyteHook(unittest.TestCase): +class TestAirbyteHook: """ Test all functions from Airbyte Hook """ @@ -46,7 +44,7 @@ class TestAirbyteHook(unittest.TestCase): _mock_job_status_success_response_body = {"job": {"status": "succeeded"}} _mock_job_cancel_status = "cancelled" - def setUp(self): + def setup_method(self): db.merge_conn( Connection( conn_id="airbyte_conn_id_test", conn_type="airbyte", host="http://test-airbyte", port=8001 @@ -59,25 +57,26 @@ def return_value_get_job(self, status): response.json.return_value = {"job": {"status": status}} return response - @requests_mock.mock() - def test_submit_sync_connection(self, m): - m.post( + def test_submit_sync_connection(self, requests_mock): + requests_mock.post( self.sync_connection_endpoint, status_code=200, json=self._mock_sync_conn_success_response_body ) resp = self.hook.submit_sync_connection(connection_id=self.connection_id) assert resp.status_code == 200 assert resp.json() == self._mock_sync_conn_success_response_body - @requests_mock.mock() - def test_get_job_status(self, m): - m.post(self.get_job_endpoint, status_code=200, json=self._mock_job_status_success_response_body) + def test_get_job_status(self, requests_mock): + requests_mock.post( + self.get_job_endpoint, status_code=200, json=self._mock_job_status_success_response_body + ) resp = self.hook.get_job(job_id=self.job_id) assert resp.status_code == 200 assert resp.json() == self._mock_job_status_success_response_body 
- @requests_mock.mock() - def test_cancel_job(self, m): - m.post(self.cancel_job_endpoint, status_code=200, json=self._mock_job_status_success_response_body) + def test_cancel_job(self, requests_mock): + requests_mock.post( + self.cancel_job_endpoint, status_code=200, json=self._mock_job_status_success_response_body + ) resp = self.hook.cancel_job(job_id=self.job_id) assert resp.status_code == 200 @@ -147,9 +146,8 @@ def test_wait_for_job_cancelled(self, mock_get_job): calls = [mock.call(job_id=self.job_id), mock.call(job_id=self.job_id)] mock_get_job.assert_has_calls(calls) - @requests_mock.mock() - def test_connection_success(self, m): - m.get( + def test_connection_success(self, requests_mock): + requests_mock.get( self.health_endpoint, status_code=200, ) @@ -158,9 +156,8 @@ def test_connection_success(self, m): assert status is True assert msg == "Connection successfully tested" - @requests_mock.mock() - def test_connection_failure(self, m): - m.get(self.health_endpoint, status_code=500, json={"message": "internal server error"}) + def test_connection_failure(self, requests_mock): + requests_mock.get(self.health_endpoint, status_code=500, json={"message": "internal server error"}) status, msg = self.hook.test_connection() assert status is False diff --git a/tests/providers/airbyte/operators/test_airbyte.py b/tests/providers/airbyte/operators/test_airbyte.py index d0f7fbe6b7462..f8ecd15615c8d 100644 --- a/tests/providers/airbyte/operators/test_airbyte.py +++ b/tests/providers/airbyte/operators/test_airbyte.py @@ -17,13 +17,12 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock from airflow.providers.airbyte.operators.airbyte import AirbyteTriggerSyncOperator -class TestAirbyteTriggerSyncOp(unittest.TestCase): +class TestAirbyteTriggerSyncOp: """ Test execute function from Airbyte Operator """ diff --git a/tests/providers/airbyte/sensors/test_airbyte.py b/tests/providers/airbyte/sensors/test_airbyte.py index 31e2f17de0c65..f6fd5ef972f53 100644 --- a/tests/providers/airbyte/sensors/test_airbyte.py +++ b/tests/providers/airbyte/sensors/test_airbyte.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -25,7 +24,7 @@ from airflow.providers.airbyte.sensors.airbyte import AirbyteJobSensor -class TestAirbyteJobSensor(unittest.TestCase): +class TestAirbyteJobSensor: task_id = "task-id" airbyte_conn_id = "airbyte-conn-test" diff --git a/tests/providers/alibaba/cloud/hooks/test_oss.py b/tests/providers/alibaba/cloud/hooks/test_oss.py index e03c0241fe842..1c47aa10c9741 100644 --- a/tests/providers/alibaba/cloud/hooks/test_oss.py +++ b/tests/providers/alibaba/cloud/hooks/test_oss.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.providers.alibaba.cloud.hooks.oss import OSSHook @@ -32,8 +31,8 @@ MOCK_FILE_PATH = "mock_file_path" -class TestOSSHook(unittest.TestCase): - def setUp(self): +class TestOSSHook: + def setup_method(self): with mock.patch( OSS_STRING.format("OSSHook.__init__"), new=mock_oss_hook_default_project_id, diff --git a/tests/providers/alibaba/cloud/log/test_oss_task_handler.py b/tests/providers/alibaba/cloud/log/test_oss_task_handler.py index e1d4cbf1c2066..2cf999849143f 100644 --- a/tests/providers/alibaba/cloud/log/test_oss_task_handler.py +++ b/tests/providers/alibaba/cloud/log/test_oss_task_handler.py @@ -17,7 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock from unittest.mock import PropertyMock @@ -32,8 +31,8 @@ MOCK_FILE_PATH = "mock_file_path" -class TestOSSTaskHandler(unittest.TestCase): - def setUp(self): +class TestOSSTaskHandler: + def setup_method(self): self.base_log_folder = "local/airflow/logs/1.log" self.oss_log_folder = f"oss://{MOCK_BUCKET_NAME}/airflow/logs" self.oss_task_handler = OSSTaskHandler(self.base_log_folder, self.oss_log_folder) diff --git a/tests/providers/alibaba/cloud/operators/test_oss.py b/tests/providers/alibaba/cloud/operators/test_oss.py index 0b42db1013dec..e95d5cbe34009 100644 --- a/tests/providers/alibaba/cloud/operators/test_oss.py +++ b/tests/providers/alibaba/cloud/operators/test_oss.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.providers.alibaba.cloud.operators.oss import ( @@ -38,7 +37,7 @@ MOCK_CONTENT = "mock_content" -class TestOSSCreateBucketOperator(unittest.TestCase): +class TestOSSCreateBucketOperator: @mock.patch("airflow.providers.alibaba.cloud.operators.oss.OSSHook") def test_execute(self, mock_hook): operator = OSSCreateBucketOperator( @@ -49,7 +48,7 @@ def test_execute(self, mock_hook): mock_hook.return_value.create_bucket.assert_called_once_with(bucket_name=MOCK_BUCKET) -class TestOSSDeleteBucketOperator(unittest.TestCase): +class TestOSSDeleteBucketOperator: @mock.patch("airflow.providers.alibaba.cloud.operators.oss.OSSHook") def test_execute(self, mock_hook): operator = OSSDeleteBucketOperator( @@ -60,7 +59,7 @@ def test_execute(self, mock_hook): mock_hook.return_value.delete_bucket.assert_called_once_with(bucket_name=MOCK_BUCKET) -class TestOSSUploadObjectOperator(unittest.TestCase): +class TestOSSUploadObjectOperator: @mock.patch("airflow.providers.alibaba.cloud.operators.oss.OSSHook") def test_execute(self, mock_hook): operator = OSSUploadObjectOperator( @@ -78,7 +77,7 @@ def test_execute(self, 
mock_hook): ) -class TestOSSDownloadObjectOperator(unittest.TestCase): +class TestOSSDownloadObjectOperator: @mock.patch("airflow.providers.alibaba.cloud.operators.oss.OSSHook") def test_execute(self, mock_hook): operator = OSSDownloadObjectOperator( @@ -96,7 +95,7 @@ def test_execute(self, mock_hook): ) -class TestOSSDeleteBatchObjectOperator(unittest.TestCase): +class TestOSSDeleteBatchObjectOperator: @mock.patch("airflow.providers.alibaba.cloud.operators.oss.OSSHook") def test_execute(self, mock_hook): operator = OSSDeleteBatchObjectOperator( @@ -111,7 +110,7 @@ def test_execute(self, mock_hook): mock_hook.return_value.delete_objects.assert_called_once_with(bucket_name=MOCK_BUCKET, key=MOCK_KEYS) -class TestOSSDeleteObjectOperator(unittest.TestCase): +class TestOSSDeleteObjectOperator: @mock.patch("airflow.providers.alibaba.cloud.operators.oss.OSSHook") def test_execute(self, mock_hook): operator = OSSDeleteObjectOperator( diff --git a/tests/providers/alibaba/cloud/sensors/test_oss_key.py b/tests/providers/alibaba/cloud/sensors/test_oss_key.py index e191598d565d2..4304f37a52fdb 100644 --- a/tests/providers/alibaba/cloud/sensors/test_oss_key.py +++ b/tests/providers/alibaba/cloud/sensors/test_oss_key.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from unittest.mock import PropertyMock @@ -33,8 +32,8 @@ MOCK_CONTENT = "mock_content" -class TestOSSKeySensor(unittest.TestCase): - def setUp(self): +class TestOSSKeySensor: + def setup_method(self): self.sensor = OSSKeySensor( bucket_key=MOCK_KEY, oss_conn_id=MOCK_OSS_CONN_ID, diff --git a/tests/providers/arangodb/hooks/test_arangodb.py b/tests/providers/arangodb/hooks/test_arangodb.py index 748cc22a28bec..707e8a59d20d9 100644 --- a/tests/providers/arangodb/hooks/test_arangodb.py +++ b/tests/providers/arangodb/hooks/test_arangodb.py @@ -16,7 +16,6 @@ # under the License. 
from __future__ import annotations -import unittest from unittest.mock import Mock, patch from airflow.models import Connection @@ -26,9 +25,8 @@ arangodb_client_mock = Mock(name="arangodb_client_for_test") -class TestArangoDBHook(unittest.TestCase): - def setUp(self): - super().setUp() +class TestArangoDBHook: + def setup_method(self): db.merge_conn( Connection( conn_id="arangodb_default", diff --git a/tests/providers/arangodb/operators/test_arangodb.py b/tests/providers/arangodb/operators/test_arangodb.py index 60f0bd76c83dd..190f5d9f5a23e 100644 --- a/tests/providers/arangodb/operators/test_arangodb.py +++ b/tests/providers/arangodb/operators/test_arangodb.py @@ -16,13 +16,12 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.providers.arangodb.operators.arangodb import AQLOperator -class TestAQLOperator(unittest.TestCase): +class TestAQLOperator: @mock.patch("airflow.providers.arangodb.operators.arangodb.ArangoDBHook") def test_arangodb_operator_test(self, mock_hook): diff --git a/tests/providers/arangodb/sensors/test_arangodb.py b/tests/providers/arangodb/sensors/test_arangodb.py index 95c948adc6856..5b9273b77439b 100644 --- a/tests/providers/arangodb/sensors/test_arangodb.py +++ b/tests/providers/arangodb/sensors/test_arangodb.py @@ -17,7 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest from unittest.mock import Mock, patch from airflow.models import Connection @@ -29,8 +28,8 @@ arangodb_hook_mock = Mock(name="arangodb_hook_for_test", **{"query.return_value.count.return_value": 1}) -class TestAQLSensor(unittest.TestCase): - def setUp(self): +class TestAQLSensor: + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} dag = DAG("test_dag_id", default_args=args) self.dag = dag diff --git a/tests/providers/asana/operators/test_asana_tasks.py b/tests/providers/asana/operators/test_asana_tasks.py index 157e8c8e700d0..6b7c373d69fcb 100644 --- a/tests/providers/asana/operators/test_asana_tasks.py +++ b/tests/providers/asana/operators/test_asana_tasks.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import unittest from unittest.mock import Mock, patch from airflow.models import Connection @@ -34,12 +33,12 @@ asana_client_mock = Mock(name="asana_client_for_test") -class TestAsanaTaskOperators(unittest.TestCase): +class TestAsanaTaskOperators: """ Test that the AsanaTaskOperators are using the python-asana methods as expected. """ - def setUp(self): + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} dag = DAG(TEST_DAG_ID, default_args=args) self.dag = dag diff --git a/tests/providers/atlassian/jira/hooks/test_jira.py b/tests/providers/atlassian/jira/hooks/test_jira.py index b5229fba67b3a..a8069357b4555 100644 --- a/tests/providers/atlassian/jira/hooks/test_jira.py +++ b/tests/providers/atlassian/jira/hooks/test_jira.py @@ -17,7 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest from unittest.mock import Mock, patch from airflow.models import Connection @@ -27,8 +26,8 @@ jira_client_mock = Mock(name="jira_client") -class TestJiraHook(unittest.TestCase): - def setUp(self): +class TestJiraHook: + def setup_method(self): db.merge_conn( Connection( conn_id="jira_default", diff --git a/tests/providers/atlassian/jira/operators/test_jira.py b/tests/providers/atlassian/jira/operators/test_jira.py index c8a4aaf43cdf9..76db8a7d692c2 100644 --- a/tests/providers/atlassian/jira/operators/test_jira.py +++ b/tests/providers/atlassian/jira/operators/test_jira.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest.mock import Mock, patch from airflow.models import Connection @@ -39,8 +38,8 @@ } -class TestJiraOperator(unittest.TestCase): - def setUp(self): +class TestJiraOperator: + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} dag = DAG("test_dag_id", default_args=args) self.dag = dag diff --git a/tests/providers/atlassian/jira/sensors/test_jira.py b/tests/providers/atlassian/jira/sensors/test_jira.py index bf814f8c2796a..ecd63ab3acb7b 100644 --- a/tests/providers/atlassian/jira/sensors/test_jira.py +++ b/tests/providers/atlassian/jira/sensors/test_jira.py @@ -17,7 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest from unittest.mock import Mock, patch from airflow.models import Connection @@ -46,8 +45,8 @@ class _TicketFields: ) -class TestJiraSensor(unittest.TestCase): - def setUp(self): +class TestJiraSensor: + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} dag = DAG("test_dag_id", default_args=args) self.dag = dag diff --git a/tests/providers/celery/sensors/test_celery_queue.py b/tests/providers/celery/sensors/test_celery_queue.py index 6faa08561ebc6..8d09085352adf 100644 --- a/tests/providers/celery/sensors/test_celery_queue.py +++ b/tests/providers/celery/sensors/test_celery_queue.py @@ -17,14 +17,13 @@ # under the License. from __future__ import annotations -import unittest from unittest.mock import patch from airflow.providers.celery.sensors.celery_queue import CeleryQueueSensor -class TestCeleryQueueSensor(unittest.TestCase): - def setUp(self): +class TestCeleryQueueSensor: + def setup_method(self): class TestCeleryqueueSensor(CeleryQueueSensor): def _check_task_id(self, context): return True diff --git a/tests/providers/cloudant/hooks/test_cloudant.py b/tests/providers/cloudant/hooks/test_cloudant.py index 04c527689bb82..22911cf1e0d30 100644 --- a/tests/providers/cloudant/hooks/test_cloudant.py +++ b/tests/providers/cloudant/hooks/test_cloudant.py @@ -17,7 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest from unittest.mock import patch import pytest @@ -27,8 +26,8 @@ from airflow.providers.cloudant.hooks.cloudant import CloudantHook -class TestCloudantHook(unittest.TestCase): - def setUp(self): +class TestCloudantHook: + def setup_method(self): self.cloudant_hook = CloudantHook() @patch( diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py index 113562e21a4b9..6989337a0b276 100644 --- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py +++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py @@ -17,7 +17,6 @@ from __future__ import annotations import json -import unittest from unittest.mock import patch from airflow import DAG @@ -239,8 +238,8 @@ @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object") @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object") @patch("airflow.utils.context.Context") -class TestSparkKubernetesOperator(unittest.TestCase): - def setUp(self): +class TestSparkKubernetesOperator: + def setup_method(self): db.merge_conn( Connection( conn_id="kubernetes_default_kube_config", diff --git a/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py index 1aa7d2d5ac92e..a45a1b8a75cb0 100644 --- a/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py +++ b/tests/providers/cncf/kubernetes/sensors/test_spark_kubernetes.py @@ -18,7 +18,6 @@ from __future__ import annotations import json -import unittest from unittest.mock import patch import pytest @@ -550,8 +549,8 @@ @patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook.get_conn") -class TestSparkKubernetesSensor(unittest.TestCase): - def setUp(self): +class TestSparkKubernetesSensor: + def setup_method(self): 
db.merge_conn(Connection(conn_id="kubernetes_default", conn_type="kubernetes", extra=json.dumps({}))) db.merge_conn( Connection( diff --git a/tests/providers/common/sql/hooks/test_dbapi.py b/tests/providers/common/sql/hooks/test_dbapi.py index beca713949203..45b40d8e48fdb 100644 --- a/tests/providers/common/sql/hooks/test_dbapi.py +++ b/tests/providers/common/sql/hooks/test_dbapi.py @@ -18,7 +18,6 @@ from __future__ import annotations import json -import unittest from unittest import mock import pytest @@ -36,10 +35,8 @@ class NonDbApiHook(BaseHook): pass -class TestDbApiHook(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestDbApiHook: + def setup_method(self): self.cur = mock.MagicMock( rowcount=0, spec=["description", "rowcount", "execute", "fetchall", "fetchone", "close"] ) diff --git a/tests/providers/common/sql/operators/test_sql.py b/tests/providers/common/sql/operators/test_sql.py index 3741a93ed36ac..1770ed8f5e659 100644 --- a/tests/providers/common/sql/operators/test_sql.py +++ b/tests/providers/common/sql/operators/test_sql.py @@ -18,7 +18,6 @@ from __future__ import annotations import datetime -import unittest from unittest import mock from unittest.mock import MagicMock @@ -43,7 +42,6 @@ from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import State -from tests.providers.apache.hive import TestHiveEnvironment class MockHook: @@ -55,7 +53,7 @@ def _get_mock_db_hook(): return MockHook() -class TestSQLExecuteQueryOperator(unittest.TestCase): +class TestSQLExecuteQueryOperator: def _construct_operator(self, sql, **kwargs): dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1)) return SQLExecuteQueryOperator( @@ -481,8 +479,8 @@ def test_sql_operator_hook_params_biguery(self, mock_get_conn): assert self._operator._hook.location == "us-east1" -class TestCheckOperator(unittest.TestCase): - def setUp(self): +class TestCheckOperator: + def setup_method(self): self._operator = 
SQLCheckOperator(task_id="test_task", sql="sql", parameters="parameters") @mock.patch.object(SQLCheckOperator, "get_db_hook") @@ -505,8 +503,8 @@ def test_sqlcheckoperator_parameters(self, mock_get_db_hook): mock_get_db_hook.return_value.get_first.assert_called_once_with("sql", "parameters") -class TestValueCheckOperator(unittest.TestCase): - def setUp(self): +class TestValueCheckOperator: + def setup_method(self): self.task_id = "test_task" self.conn_id = "default_conn" @@ -564,7 +562,7 @@ def test_execute_fail(self, mock_get_db_hook): operator.execute(context=MagicMock()) -class TestIntervalCheckOperator(unittest.TestCase): +class TestIntervalCheckOperator: def _construct_operator(self, table, metric_thresholds, ratio_formula, ignore_zero): return SQLIntervalCheckOperator( task_id="test_task", @@ -681,7 +679,7 @@ def returned_row(): operator.execute(context=MagicMock()) -class TestThresholdCheckOperator(unittest.TestCase): +class TestThresholdCheckOperator: def _construct_operator(self, sql, min_threshold, max_threshold): dag = DAG("test_dag", start_date=datetime.datetime(2017, 1, 1)) @@ -757,22 +755,19 @@ def test_fail_min_sql_max_value(self, mock_get_db_hook): operator.execute(context=MagicMock()) -class TestSqlBranch(TestHiveEnvironment, unittest.TestCase): +class TestSqlBranch: """ Test for SQL Branch Operator """ @classmethod - def setUpClass(cls): - super().setUpClass() - + def setup_class(cls): with create_session() as session: session.query(DagRun).delete() session.query(TI).delete() session.query(XCom).delete() - def setUp(self): - super().setUp() + def setup_method(self): self.dag = DAG( "sql_branch_operator_test", default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, @@ -782,9 +777,7 @@ def setUp(self): self.branch_2 = EmptyOperator(task_id="branch_2", dag=self.dag) self.branch_3 = None - def tearDown(self): - super().tearDown() - + def teardown_method(self): with create_session() as session: session.query(DagRun).delete() 
session.query(TI).delete() diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index b77b8da7ba922..895be832b4cd0 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -21,7 +21,6 @@ import json import sys import time -import unittest import aiohttp import pytest @@ -228,13 +227,13 @@ def setup_mock_requests(mock_requests, exception, status_code=500, error_count=N ] -class TestDatabricksHook(unittest.TestCase): +class TestDatabricksHook: """ Tests for DatabricksHook. """ @provide_session - def setUp(self, session=None): + def setup_method(self, method, session=None): conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() conn.host = HOST conn.login = LOGIN @@ -616,11 +615,11 @@ def test_uninstall_libs_on_cluster(self, mock_requests): def test_is_aad_token_valid_returns_true(self): aad_token = {"token": "my_token", "expires_on": int(time.time()) + TOKEN_REFRESH_LEAD_TIME + 10} - self.assertTrue(self.hook._is_aad_token_valid(aad_token)) + assert self.hook._is_aad_token_valid(aad_token) def test_is_aad_token_valid_returns_false(self): aad_token = {"token": "my_token", "expires_on": int(time.time())} - self.assertFalse(self.hook._is_aad_token_valid(aad_token)) + assert not self.hook._is_aad_token_valid(aad_token) @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") def test_list_jobs_success_single_page(self, mock_requests): @@ -756,13 +755,13 @@ def test_connection_failure(self, mock_requests): ) -class TestDatabricksHookToken(unittest.TestCase): +class TestDatabricksHookToken: """ Tests for DatabricksHook when auth is done with token. 
""" @provide_session - def setUp(self, session=None): + def setup_method(self, method, session=None): conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() conn.extra = json.dumps({"token": TOKEN, "host": HOST}) @@ -785,13 +784,13 @@ def test_submit_run(self, mock_requests): assert kwargs["auth"].token == TOKEN -class TestDatabricksHookTokenInPassword(unittest.TestCase): +class TestDatabricksHookTokenInPassword: """ Tests for DatabricksHook. """ @provide_session - def setUp(self, session=None): + def setup_method(self, method, session=None): conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() conn.host = HOST conn.login = None @@ -818,7 +817,7 @@ def test_submit_run(self, mock_requests): class TestDatabricksHookTokenWhenNoHostIsProvidedInExtra(TestDatabricksHookToken): @provide_session - def setUp(self, session=None): + def setup_method(self, method, session=None): conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() conn.extra = json.dumps({"token": TOKEN}) @@ -827,7 +826,7 @@ def setUp(self, session=None): self.hook = DatabricksHook() -class TestRunState(unittest.TestCase): +class TestRunState: def test_is_terminal_true(self): terminal_states = ["TERMINATED", "SKIPPED", "INTERNAL_ERROR"] for state in terminal_states: @@ -874,13 +873,13 @@ def create_aad_token_for_resource(resource: str) -> dict: } -class TestDatabricksHookAadToken(unittest.TestCase): +class TestDatabricksHookAadToken: """ Tests for DatabricksHook when auth is done with AAD token for SP as user inside workspace. 
""" @provide_session - def setUp(self, session=None): + def setup_method(self, method, session=None): conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() conn.login = "9ff815a6-4404-4ab8-85cb-cd0e6f879c1d" conn.password = "secret" @@ -911,14 +910,14 @@ def test_submit_run(self, mock_requests): assert kwargs["auth"].token == TOKEN -class TestDatabricksHookAadTokenOtherClouds(unittest.TestCase): +class TestDatabricksHookAadTokenOtherClouds: """ Tests for DatabricksHook when auth is done with AAD token for SP as user inside workspace and using non-global Azure cloud (China, GovCloud, Germany) """ @provide_session - def setUp(self, session=None): + def setup_method(self, method, session=None): self.tenant_id = "3ff810a6-5504-4ab8-85cb-cd0e6f879c1d" self.ad_endpoint = "https://login.microsoftonline.de" self.client_id = "9ff815a6-4404-4ab8-85cb-cd0e6f879c1d" @@ -958,13 +957,13 @@ def test_submit_run(self, mock_requests): assert kwargs["auth"].token == TOKEN -class TestDatabricksHookAadTokenSpOutside(unittest.TestCase): +class TestDatabricksHookAadTokenSpOutside: """ Tests for DatabricksHook when auth is done with AAD token for SP outside of workspace. 
""" @provide_session - def setUp(self, session=None): + def setup_method(self, method, session=None): conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() self.tenant_id = "3ff810a6-5504-4ab8-85cb-cd0e6f879c1d" self.client_id = "9ff815a6-4404-4ab8-85cb-cd0e6f879c1d" @@ -1011,13 +1010,13 @@ def test_submit_run(self, mock_requests): assert kwargs["headers"]["X-Databricks-Azure-SP-Management-Token"] == TOKEN -class TestDatabricksHookAadTokenManagedIdentity(unittest.TestCase): +class TestDatabricksHookAadTokenManagedIdentity: """ Tests for DatabricksHook when auth is done with AAD leveraging Managed Identity authentication """ @provide_session - def setUp(self, session=None): + def setup_method(self, method, session=None): conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() conn.host = HOST conn.extra = json.dumps( diff --git a/tests/providers/databricks/hooks/test_databricks_sql.py b/tests/providers/databricks/hooks/test_databricks_sql.py index bd52a64c98e40..b3ed8e87ff2c8 100644 --- a/tests/providers/databricks/hooks/test_databricks_sql.py +++ b/tests/providers/databricks/hooks/test_databricks_sql.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -34,13 +33,13 @@ TOKEN = "token" -class TestDatabricksSqlHookQueryByName(unittest.TestCase): +class TestDatabricksSqlHookQueryByName: """ Tests for DatabricksHook. 
""" @provide_session - def setUp(self, session=None): + def setup_method(self, method, session=None): conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first() conn.host = HOST conn.login = None diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index c3bca253a6e05..5236e67aa7aa5 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import datetime from unittest import mock from unittest.mock import MagicMock @@ -94,7 +93,7 @@ def make_run_with_state_mock( ) -class TestDatabricksSubmitRunOperator(unittest.TestCase): +class TestDatabricksSubmitRunOperator: def test_init_with_notebook_task_named_parameters(self): """ Test the initializer with the named parameters. @@ -437,7 +436,7 @@ def test_no_wait_for_termination(self, db_mock_class): db_mock.get_run.assert_not_called() -class TestDatabricksSubmitRunDeferrableOperator(unittest.TestCase): +class TestDatabricksSubmitRunDeferrableOperator: @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_execute_task_deferred(self, db_mock_class): """ @@ -454,8 +453,8 @@ def test_execute_task_deferred(self, db_mock_class): with pytest.raises(TaskDeferred) as exc: op.execute(None) - self.assertTrue(isinstance(exc.value.trigger, DatabricksExecutionTrigger)) - self.assertEqual(exc.value.method_name, "execute_complete") + assert isinstance(exc.value.trigger, DatabricksExecutionTrigger) + assert exc.value.method_name == "execute_complete" expected = utils.normalise_json_content( {"new_cluster": NEW_CLUSTER, "notebook_task": NOTEBOOK_TASK, "run_name": TASK_ID} @@ -470,7 +469,7 @@ def test_execute_task_deferred(self, db_mock_class): db_mock.submit_run.assert_called_once_with(expected) 
db_mock.get_run_page_url.assert_called_once_with(RUN_ID) - self.assertEqual(RUN_ID, op.run_id) + assert op.run_id == RUN_ID def test_execute_complete_success(self): """ @@ -487,7 +486,7 @@ def test_execute_complete_success(self): } op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run) - self.assertIsNone(op.execute_complete(context=None, event=event)) + assert op.execute_complete(context=None, event=event) is None @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_execute_complete_failure(self, db_mock_class): @@ -523,7 +522,7 @@ def test_execute_complete_incorrect_event_validation_failure(self): op.execute_complete(context=None, event=event) -class TestDatabricksRunNowOperator(unittest.TestCase): +class TestDatabricksRunNowOperator: def test_init_with_named_parameters(self): """ Test the initializer with the named parameters. @@ -874,7 +873,7 @@ def test_exec_failure_if_job_id_not_found(self, db_mock_class): db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME) -class TestDatabricksRunNowDeferrableOperator(unittest.TestCase): +class TestDatabricksRunNowDeferrableOperator: @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_execute_task_deferred(self, db_mock_class): """ @@ -888,8 +887,8 @@ def test_execute_task_deferred(self, db_mock_class): with pytest.raises(TaskDeferred) as exc: op.execute(None) - self.assertTrue(isinstance(exc.value.trigger, DatabricksExecutionTrigger)) - self.assertEqual(exc.value.method_name, "execute_complete") + assert isinstance(exc.value.trigger, DatabricksExecutionTrigger) + assert exc.value.method_name == "execute_complete" expected = utils.normalise_json_content( { @@ -910,7 +909,7 @@ def test_execute_task_deferred(self, db_mock_class): db_mock.run_now.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) - self.assertEqual(RUN_ID, op.run_id) + assert op.run_id == RUN_ID def 
test_execute_complete_success(self): """ @@ -924,7 +923,7 @@ def test_execute_complete_success(self): } op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) - self.assertIsNone(op.execute_complete(context=None, event=event)) + assert op.execute_complete(context=None, event=event) is None @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_execute_complete_failure(self, db_mock_class): diff --git a/tests/providers/databricks/operators/test_databricks_repos.py b/tests/providers/databricks/operators/test_databricks_repos.py index 3c593b82fdd30..9e8770d3d9653 100644 --- a/tests/providers/databricks/operators/test_databricks_repos.py +++ b/tests/providers/databricks/operators/test_databricks_repos.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -33,7 +32,7 @@ DEFAULT_CONN_ID = "databricks_default" -class TestDatabricksReposUpdateOperator(unittest.TestCase): +class TestDatabricksReposUpdateOperator: @mock.patch("airflow.providers.databricks.operators.databricks_repos.DatabricksHook") def test_update_with_id(self, db_mock_class): """ @@ -98,7 +97,7 @@ def test_init_exception(self): DatabricksReposUpdateOperator(task_id=TASK_ID, repo_id="123") -class TestDatabricksReposDeleteOperator(unittest.TestCase): +class TestDatabricksReposDeleteOperator: @mock.patch("airflow.providers.databricks.operators.databricks_repos.DatabricksHook") def test_delete_with_id(self, db_mock_class): """ @@ -153,7 +152,7 @@ def test_init_exception(self): DatabricksReposDeleteOperator(task_id=TASK_ID) -class TestDatabricksReposCreateOperator(unittest.TestCase): +class TestDatabricksReposCreateOperator: @mock.patch("airflow.providers.databricks.operators.databricks_repos.DatabricksHook") def test_create_plus_checkout(self, db_mock_class): """ diff --git a/tests/providers/databricks/operators/test_databricks_sql.py 
b/tests/providers/databricks/operators/test_databricks_sql.py index 0064c0f7f6d5f..420376dc3d148 100644 --- a/tests/providers/databricks/operators/test_databricks_sql.py +++ b/tests/providers/databricks/operators/test_databricks_sql.py @@ -19,7 +19,6 @@ import os import tempfile -import unittest from unittest import mock import pytest @@ -38,7 +37,7 @@ COPY_FILE_LOCATION = "s3://my-bucket/jsonData" -class TestDatabricksSqlOperator(unittest.TestCase): +class TestDatabricksSqlOperator: @mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") def test_exec_success(self, db_mock_class): """ @@ -113,7 +112,7 @@ def test_exec_write_file(self, db_mock_class): ) -class TestDatabricksSqlCopyIntoOperator(unittest.TestCase): +class TestDatabricksSqlCopyIntoOperator: def test_copy_with_files(self): op = DatabricksCopyIntoOperator( file_location=COPY_FILE_LOCATION, diff --git a/tests/providers/databricks/utils/databricks.py b/tests/providers/databricks/utils/databricks.py index b918a2b0f2c63..7619bcb8ad07f 100644 --- a/tests/providers/databricks/utils/databricks.py +++ b/tests/providers/databricks/utils/databricks.py @@ -17,8 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest - import pytest from airflow.exceptions import AirflowException @@ -29,7 +27,7 @@ RUN_PAGE_URL = "run-page-url" -class TestDatabricksOperatorSharedFunctions(unittest.TestCase): +class TestDatabricksOperatorSharedFunctions: def test_normalise_json_content(self): test_json = { "test_bool": True, @@ -56,7 +54,7 @@ def test_validate_trigger_event_success(self): "run_page_url": RUN_PAGE_URL, "run_state": RunState("TERMINATED", "SUCCESS", "").to_json(), } - self.assertIsNone(validate_trigger_event(event)) + assert validate_trigger_event(event) is None def test_validate_trigger_event_failure(self): event = {} diff --git a/tests/providers/datadog/hooks/test_datadog.py b/tests/providers/datadog/hooks/test_datadog.py index 5b026ff1c0bbd..b5a4d5aaf2290 100644 --- a/tests/providers/datadog/hooks/test_datadog.py +++ b/tests/providers/datadog/hooks/test_datadog.py @@ -18,7 +18,6 @@ from __future__ import annotations import json -import unittest from unittest import mock import pytest @@ -46,20 +45,20 @@ DEVICE_NAME = "device-name" -class TestDatadogHook(unittest.TestCase): - @mock.patch("airflow.providers.datadog.hooks.datadog.initialize") - @mock.patch("airflow.providers.datadog.hooks.datadog.DatadogHook.get_connection") - def setUp(self, mock_get_connection, mock_initialize): - mock_get_connection.return_value = Connection( - extra=json.dumps( - { - "app_key": APP_KEY, - "api_key": API_KEY, - "api_host": API_HOST, - } - ) - ) - self.hook = DatadogHook() +class TestDatadogHook: + def setup_method(self): + with mock.patch("airflow.providers.datadog.hooks.datadog.initialize"): + with mock.patch("airflow.providers.datadog.hooks.datadog.DatadogHook.get_connection") as m: + m.return_value = Connection( + extra=json.dumps( + { + "app_key": APP_KEY, + "api_key": API_KEY, + "api_host": API_HOST, + } + ) + ) + self.hook = DatadogHook() @mock.patch("airflow.providers.datadog.hooks.datadog.initialize") 
@mock.patch("airflow.providers.datadog.hooks.datadog.DatadogHook.get_connection") diff --git a/tests/providers/datadog/sensors/test_datadog.py b/tests/providers/datadog/sensors/test_datadog.py index ae33cc136d233..e8c146f1c955d 100644 --- a/tests/providers/datadog/sensors/test_datadog.py +++ b/tests/providers/datadog/sensors/test_datadog.py @@ -18,7 +18,6 @@ from __future__ import annotations import json -import unittest from unittest.mock import patch from airflow.models import Connection @@ -63,8 +62,8 @@ zero_events: list = [] -class TestDatadogSensor(unittest.TestCase): - def setUp(self): +class TestDatadogSensor: + def setup_method(self): db.merge_conn( Connection( conn_id="datadog_default", diff --git a/tests/providers/dingding/hooks/test_dingding.py b/tests/providers/dingding/hooks/test_dingding.py index 202c74a410a5f..745fc65bf5397 100644 --- a/tests/providers/dingding/hooks/test_dingding.py +++ b/tests/providers/dingding/hooks/test_dingding.py @@ -18,7 +18,6 @@ from __future__ import annotations import json -import unittest import pytest @@ -27,10 +26,10 @@ from airflow.utils import db -class TestDingdingHook(unittest.TestCase): +class TestDingdingHook: conn_id = "dingding_conn_id_test" - def setUp(self): + def setup_method(self): db.merge_conn( Connection( conn_id=self.conn_id, diff --git a/tests/providers/dingding/operators/test_dingding.py b/tests/providers/dingding/operators/test_dingding.py index 3858316ef9f66..d2b25c242e2ce 100644 --- a/tests/providers/dingding/operators/test_dingding.py +++ b/tests/providers/dingding/operators/test_dingding.py @@ -17,7 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock from airflow.models.dag import DAG @@ -27,7 +26,7 @@ DEFAULT_DATE = timezone.datetime(2017, 1, 1) -class TestDingdingOperator(unittest.TestCase): +class TestDingdingOperator: _config = { "dingding_conn_id": "dingding_default", "message_type": "text", @@ -36,7 +35,7 @@ class TestDingdingOperator(unittest.TestCase): "at_all": False, } - def setUp(self): + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} self.dag = DAG("test_dag_id", default_args=args) diff --git a/tests/providers/discord/hooks/test_discord_webhook.py b/tests/providers/discord/hooks/test_discord_webhook.py index 40dab6ec73a25..d1e0406ec87f4 100644 --- a/tests/providers/discord/hooks/test_discord_webhook.py +++ b/tests/providers/discord/hooks/test_discord_webhook.py @@ -18,7 +18,6 @@ from __future__ import annotations import json -import unittest import pytest @@ -28,7 +27,7 @@ from airflow.utils import db -class TestDiscordWebhookHook(unittest.TestCase): +class TestDiscordWebhookHook: _config = { "http_conn_id": "default-discord-webhook", @@ -49,7 +48,7 @@ class TestDiscordWebhookHook(unittest.TestCase): expected_payload = json.dumps(expected_payload_dict) - def setUp(self): + def setup_method(self): db.merge_conn( Connection( conn_id="default-discord-webhook", diff --git a/tests/providers/discord/operators/test_discord_webhook.py b/tests/providers/discord/operators/test_discord_webhook.py index 3e7b2352908d7..27cbe7d6d6532 100644 --- a/tests/providers/discord/operators/test_discord_webhook.py +++ b/tests/providers/discord/operators/test_discord_webhook.py @@ -17,8 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest - from airflow.models.dag import DAG from airflow.providers.discord.operators.discord_webhook import DiscordWebhookOperator from airflow.utils import timezone @@ -26,7 +24,7 @@ DEFAULT_DATE = timezone.datetime(2018, 1, 1) -class TestDiscordWebhookOperator(unittest.TestCase): +class TestDiscordWebhookOperator: _config = { "http_conn_id": "discord-webhook-default", "webhook_endpoint": "webhooks/11111/some-discord-token_111", @@ -37,7 +35,7 @@ class TestDiscordWebhookOperator(unittest.TestCase): "proxy": "https://proxy.proxy.com:8888", } - def setUp(self): + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} self.dag = DAG("test_dag_id", default_args=args) diff --git a/tests/providers/docker/hooks/test_docker.py b/tests/providers/docker/hooks/test_docker.py index 18c4c27f1ccf1..b7d24aa908a99 100644 --- a/tests/providers/docker/hooks/test_docker.py +++ b/tests/providers/docker/hooks/test_docker.py @@ -17,24 +17,19 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock import pytest from airflow.exceptions import AirflowException from airflow.models import Connection +from airflow.providers.docker.hooks.docker import DockerHook from airflow.utils import db -try: - from airflow.providers.docker.hooks.docker import DockerHook -except ImportError: - pass - @mock.patch("airflow.providers.docker.hooks.docker.APIClient", autospec=True) -class TestDockerHook(unittest.TestCase): - def setUp(self): +class TestDockerHook: + def setup_method(self): db.merge_conn( Connection( conn_id="docker_default", diff --git a/tests/providers/docker/operators/test_docker.py b/tests/providers/docker/operators/test_docker.py index c3506eed3b841..f17c6a01ebc0b 100644 --- a/tests/providers/docker/operators/test_docker.py +++ b/tests/providers/docker/operators/test_docker.py @@ -18,31 +18,24 @@ from __future__ import annotations import logging -import unittest from unittest import mock from unittest.mock import call import pytest +from docker import APIClient from docker.constants import DEFAULT_TIMEOUT_SECONDS from docker.errors import APIError +from docker.types import DeviceRequest, LogConfig, Mount from airflow.exceptions import AirflowException - -try: - from docker import APIClient - from docker.types import DeviceRequest, LogConfig, Mount - - from airflow.providers.docker.hooks.docker import DockerHook - from airflow.providers.docker.operators.docker import DockerOperator -except ImportError: - pass - +from airflow.providers.docker.hooks.docker import DockerHook +from airflow.providers.docker.operators.docker import DockerOperator TEMPDIR_MOCK_RETURN_VALUE = "/mkdtemp" -class TestDockerOperator(unittest.TestCase): - def setUp(self): +class TestDockerOperator: + def setup_method(self): self.tempdir_patcher = mock.patch("airflow.providers.docker.operators.docker.TemporaryDirectory") self.tempdir_mock = self.tempdir_patcher.start() 
self.tempdir_mock.return_value.__enter__.return_value = TEMPDIR_MOCK_RETURN_VALUE @@ -81,7 +74,7 @@ def dotenv_mock_return_value(**kwargs): self.dotenv_mock = self.dotenv_patcher.start() self.dotenv_mock.side_effect = dotenv_mock_return_value - def tearDown(self) -> None: + def teardown_method(self) -> None: self.tempdir_patcher.stop() self.client_class_patcher.stop() self.dotenv_patcher.stop() @@ -241,7 +234,7 @@ def test_execute_no_temp_dir(self): self.dotenv_mock.assert_called_once_with(stream="ENV=FILE\nVAR=VALUE") stringio_patcher.stop() - def test_execute_fallback_temp_dir(self): + def test_execute_fallback_temp_dir(self, caplog): self.client_mock.create_container.side_effect = [ APIError(message="wrong path: " + TEMPDIR_MOCK_RETURN_VALUE), {"Id": "some_id"}, @@ -270,12 +263,16 @@ def test_execute_fallback_temp_dir(self): container_name="test_container", tty=True, ) - with self.assertLogs(operator.log, level=logging.WARNING) as captured: + caplog.clear() + with caplog.at_level(logging.WARNING, logger=operator.log.name): operator.execute(None) - assert ( - "WARNING:airflow.task.operators:Using remote engine or docker-in-docker " - "and mounting temporary volume from host is not supported" in captured.output[0] + warning_message = ( + "Using remote engine or docker-in-docker and mounting temporary volume from host " + "is not supported. Falling back to `mount_tmp_dir=False` mode. 
" + "You can set `mount_tmp_dir` parameter to False to disable mounting and remove the warning" ) + assert warning_message in caplog.messages + self.client_class_mock.assert_called_once_with( base_url="unix://var/run/docker.sock", tls=None, version="1.19", timeout=DEFAULT_TIMEOUT_SECONDS ) diff --git a/tests/providers/docker/operators/test_docker_swarm.py b/tests/providers/docker/operators/test_docker_swarm.py index ed7217b4eeb0a..d12be721f3652 100644 --- a/tests/providers/docker/operators/test_docker_swarm.py +++ b/tests/providers/docker/operators/test_docker_swarm.py @@ -17,19 +17,17 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest from docker import APIClient, types from docker.constants import DEFAULT_TIMEOUT_SECONDS -from parameterized import parameterized from airflow.exceptions import AirflowException from airflow.providers.docker.operators.docker_swarm import DockerSwarmOperator -class TestDockerSwarmOperator(unittest.TestCase): +class TestDockerSwarmOperator: @mock.patch("airflow.providers.docker.operators.docker.APIClient") @mock.patch("airflow.providers.docker.operators.docker_swarm.types") def test_execute(self, types_mock, client_class_mock): @@ -162,10 +160,10 @@ def test_no_auto_remove(self, types_mock, client_class_mock): client_mock.remove_service.call_count == 0 ), "Docker service being removed even when `auto_remove` set to `False`" - @parameterized.expand([("failed",), ("shutdown",), ("rejected",), ("orphaned",), ("remove",)]) + @pytest.mark.parametrize("status", ["failed", "shutdown", "rejected", "orphaned", "remove"]) @mock.patch("airflow.providers.docker.operators.docker.APIClient") @mock.patch("airflow.providers.docker.operators.docker_swarm.types") - def test_non_complete_service_raises_error(self, status, types_mock, client_class_mock): + def test_non_complete_service_raises_error(self, types_mock, client_class_mock, status): mock_obj = mock.Mock() diff --git 
a/tests/providers/elasticsearch/hooks/test_elasticsearch.py b/tests/providers/elasticsearch/hooks/test_elasticsearch.py index b853b18f42902..c80ccb890b129 100644 --- a/tests/providers/elasticsearch/hooks/test_elasticsearch.py +++ b/tests/providers/elasticsearch/hooks/test_elasticsearch.py @@ -17,9 +17,9 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock +import pytest from elasticsearch import Elasticsearch from airflow.models import Connection @@ -30,7 +30,7 @@ ) -class TestElasticsearchHook(unittest.TestCase): +class TestElasticsearchHook: def test_throws_warning(self): self.cur = mock.MagicMock(rowcount=0) self.conn = mock.MagicMock() @@ -38,7 +38,7 @@ def test_throws_warning(self): conn = self.conn self.connection = Connection(host="localhost", port=9200, schema="http") - with self.assertWarns(DeprecationWarning): + with pytest.warns(DeprecationWarning): class UnitTestElasticsearchHook(ElasticsearchHook): conn_name_attr = "test_conn_id" @@ -49,10 +49,8 @@ def get_conn(self): self.db_hook = UnitTestElasticsearchHook() -class TestElasticsearchSQLHookConn(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestElasticsearchSQLHookConn: + def setup_method(self): self.connection = Connection(host="localhost", port=9200, schema="http") class UnitTestElasticsearchHook(ElasticsearchSQLHook): @@ -69,10 +67,8 @@ def test_get_conn(self, mock_connect): mock_connect.assert_called_with(host="localhost", port=9200, scheme="http", user=None, password=None) -class TestElasticsearcSQLhHook(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestElasticsearchSQLHook: + def setup_method(self): self.cur = mock.MagicMock(rowcount=0) self.conn = mock.MagicMock() self.conn.cursor.return_value = self.cur diff --git a/tests/providers/exasol/hooks/test_exasol.py b/tests/providers/exasol/hooks/test_exasol.py index 8345b36b25175..f208450826487 100644 --- a/tests/providers/exasol/hooks/test_exasol.py +++ 
b/tests/providers/exasol/hooks/test_exasol.py @@ -18,7 +18,6 @@ from __future__ import annotations import json -import unittest from unittest import mock import pytest @@ -27,10 +26,8 @@ from airflow.providers.exasol.hooks.exasol import ExasolHook -class TestExasolHookConn(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestExasolHookConn: + def setup_method(self): self.connection = models.Connection( login="login", password="password", @@ -66,10 +63,8 @@ def test_get_conn_extra_args(self, mock_pyexasol): assert kwargs["encryption"] is True -class TestExasolHook(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestExasolHook: + def setup_method(self): self.cur = mock.MagicMock(rowcount=0) self.conn = mock.MagicMock() self.conn.execute.return_value = self.cur diff --git a/tests/providers/exasol/operators/test_exasol.py b/tests/providers/exasol/operators/test_exasol.py index 6f1d00c999a96..ffc06f6eb601a 100644 --- a/tests/providers/exasol/operators/test_exasol.py +++ b/tests/providers/exasol/operators/test_exasol.py @@ -17,14 +17,13 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.exasol.operators.exasol import ExasolOperator -class TestExasol(unittest.TestCase): +class TestExasol: @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook") def test_overwrite_autocommit(self, mock_get_db_hook): operator = ExasolOperator(task_id="TEST", sql="SELECT 1", autocommit=True) diff --git a/tests/providers/ftp/hooks/test_ftp.py b/tests/providers/ftp/hooks/test_ftp.py index e62344ee699d3..dcb2e4ac512c9 100644 --- a/tests/providers/ftp/hooks/test_ftp.py +++ b/tests/providers/ftp/hooks/test_ftp.py @@ -18,15 +18,13 @@ from __future__ import annotations import io -import unittest from unittest import mock from airflow.providers.ftp.hooks import ftp as fh -class TestFTPHook(unittest.TestCase): - def setUp(self): - super().setUp() +class TestFTPHook: + def setup_method(self): self.path = "/some/path" self.conn_mock = mock.MagicMock(name="conn") self.get_conn_orig = fh.FTPHook.get_conn @@ -37,9 +35,8 @@ def _get_conn_mock(hook): fh.FTPHook.get_conn = _get_conn_mock - def tearDown(self): + def teardown_method(self): fh.FTPHook.get_conn = self.get_conn_orig - super().tearDown() def test_close_conn(self): ftp_hook = fh.FTPHook() @@ -137,9 +134,8 @@ def test_connection_failure(self): assert msg == "Test" -class TestIntegrationFTPHook(unittest.TestCase): - def setUp(self): - super().setUp() +class TestIntegrationFTPHook: + def setup_method(self): from airflow.models import Connection from airflow.utils import db diff --git a/tests/providers/ftp/sensors/test_ftp.py b/tests/providers/ftp/sensors/test_ftp.py index 59a02ad543d3b..e79b46526aaee 100644 --- a/tests/providers/ftp/sensors/test_ftp.py +++ b/tests/providers/ftp/sensors/test_ftp.py @@ -17,7 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest from ftplib import error_perm from unittest import mock @@ -27,7 +26,7 @@ from airflow.providers.ftp.sensors.ftp import FTPSensor -class TestFTPSensor(unittest.TestCase): +class TestFTPSensor: @mock.patch("airflow.providers.ftp.sensors.ftp.FTPHook", spec=FTPHook) def test_poke(self, mock_hook): op = FTPSensor(path="foobar.json", ftp_conn_id="bob_ftp", task_id="test_task") diff --git a/tests/providers/grpc/operators/test_grpc.py b/tests/providers/grpc/operators/test_grpc.py index e69a23bc3cd6b..62540bbfeae22 100644 --- a/tests/providers/grpc/operators/test_grpc.py +++ b/tests/providers/grpc/operators/test_grpc.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.providers.grpc.operators.grpc import GrpcOperator @@ -30,7 +29,7 @@ def stream_call(self, data): pass -class TestGrpcOperator(unittest.TestCase): +class TestGrpcOperator: def custom_conn_func(self, connection): pass diff --git a/tests/providers/hashicorp/_internal_client/test_vault_client.py b/tests/providers/hashicorp/_internal_client/test_vault_client.py index 15aa1f19a8bc4..1bab652dc01d9 100644 --- a/tests/providers/hashicorp/_internal_client/test_vault_client.py +++ b/tests/providers/hashicorp/_internal_client/test_vault_client.py @@ -17,7 +17,6 @@ from __future__ import annotations from unittest import mock -from unittest.case import TestCase from unittest.mock import mock_open, patch import pytest @@ -26,7 +25,7 @@ from airflow.providers.hashicorp._internal_client.vault_client import _VaultClient -class TestVaultClient(TestCase): +class TestVaultClient: @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_version_wrong(self, mock_hvac): mock_client = mock.MagicMock() diff --git a/tests/providers/hashicorp/hooks/test_vault.py b/tests/providers/hashicorp/hooks/test_vault.py index c91b1a2d1f1dd..bad350fe65528 100644 --- 
a/tests/providers/hashicorp/hooks/test_vault.py +++ b/tests/providers/hashicorp/hooks/test_vault.py @@ -17,17 +17,15 @@ from __future__ import annotations from unittest import mock -from unittest.case import TestCase from unittest.mock import PropertyMock, mock_open, patch import pytest from hvac.exceptions import VaultError -from parameterized import parameterized from airflow.providers.hashicorp.hooks.vault import VaultHook -class TestVaultHook(TestCase): +class TestVaultHook: @staticmethod def get_mock_connection( conn_type="vault", schema="secret", host="localhost", port=8180, user="user", password="pass" @@ -58,15 +56,16 @@ def test_version_not_int(self, mock_hvac, mock_get_connection): with pytest.raises(VaultError, match="The version is not an int: text"): VaultHook(**kwargs) - @parameterized.expand( + @pytest.mark.parametrize( + "version, expected_version", [ ("2", 2), (1, 1), - ] + ], ) @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_version(self, version, expected_version, mock_hvac, mock_get_connection): + def test_version(self, mock_hvac, mock_get_connection, version, expected_version): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client mock_connection = self.get_mock_connection() @@ -156,16 +155,17 @@ def test_version_one_dejson(self, mock_hvac, mock_get_connection): test_hook = VaultHook(**kwargs) assert 1 == test_hook.vault_client.kv_engine_version - @parameterized.expand( + @pytest.mark.parametrize( + "protocol, expected_url", [ ("vaults", "https://localhost:8180"), ("http", "http://localhost:8180"), ("https", "https://localhost:8180"), - ] + ], ) @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_protocol(self, protocol, expected_url, mock_hvac, mock_get_connection): + def 
test_protocol(self, mock_hvac, mock_get_connection, protocol, expected_url): mock_client = mock.MagicMock() mock_hvac.Client.return_value = mock_client mock_connection = self.get_mock_connection(conn_type=protocol) @@ -1154,15 +1154,16 @@ def test_create_or_update_secret_v2_cas(self, mock_hvac, mock_get_connection): mount_point="secret", secret_path="path", secret={"key": "value"}, cas=10 ) - @parameterized.expand( + @pytest.mark.parametrize( + "method, expected_method", [ (None, None), ("post", "post"), - ] + ], ) @mock.patch("airflow.providers.hashicorp.hooks.vault.VaultHook.get_connection") @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") - def test_create_or_update_secret_v1(self, method, expected_method, mock_hvac, mock_get_connection): + def test_create_or_update_secret_v1(self, mock_hvac, mock_get_connection, method, expected_method): mock_connection = self.get_mock_connection() mock_get_connection.return_value = mock_connection mock_client = mock.MagicMock() diff --git a/tests/providers/hashicorp/secrets/test_vault.py b/tests/providers/hashicorp/secrets/test_vault.py index 82c54cbeabd99..a29e6dc21e8c3 100644 --- a/tests/providers/hashicorp/secrets/test_vault.py +++ b/tests/providers/hashicorp/secrets/test_vault.py @@ -16,7 +16,7 @@ # under the License. 
from __future__ import annotations -from unittest import TestCase, mock +from unittest import mock import pytest from hvac.exceptions import InvalidPath, VaultError @@ -24,7 +24,7 @@ from airflow.providers.hashicorp.secrets.vault import VaultBackend -class TestVaultSecrets(TestCase): +class TestVaultSecrets: @mock.patch("airflow.providers.hashicorp._internal_client.vault_client.hvac") def test_get_conn_uri(self, mock_hvac): mock_client = mock.MagicMock() diff --git a/tests/providers/http/hooks/test_http.py b/tests/providers/http/hooks/test_http.py index 9f3efd340f55f..03bf09075c3c8 100644 --- a/tests/providers/http/hooks/test_http.py +++ b/tests/providers/http/hooks/test_http.py @@ -19,16 +19,13 @@ import json import os -import unittest from collections import OrderedDict from http import HTTPStatus from unittest import mock import pytest import requests -import requests_mock import tenacity -from parameterized import parameterized from requests.adapters import Response from airflow.exceptions import AirflowException @@ -44,10 +41,12 @@ def get_airflow_connection_with_port(unused_conn_id=None): return Connection(conn_id="http_default", conn_type="http", host="test.com", port=1234) -class TestHttpHook(unittest.TestCase): +class TestHttpHook: """Test get, post and raise_for_status""" - def setUp(self): + def setup_method(self): + import requests_mock + session = requests.Session() adapter = requests_mock.Adapter() session.mount("mock", adapter) @@ -55,18 +54,17 @@ def setUp(self): self.get_lowercase_hook = HttpHook(method="get") self.post_hook = HttpHook(method="POST") - @requests_mock.mock() - def test_raise_for_status_with_200(self, m): - - m.get("http://test:8080/v1/test", status_code=200, text='{"status":{"status": 200}}', reason="OK") + def test_raise_for_status_with_200(self, requests_mock): + requests_mock.get( + "http://test:8080/v1/test", status_code=200, text='{"status":{"status": 200}}', reason="OK" + ) with 
mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): resp = self.get_hook.run("v1/test") assert resp.text == '{"status":{"status": 200}}' - @requests_mock.mock() - @mock.patch("requests.Session") @mock.patch("requests.Request") - def test_get_request_with_port(self, mock_requests, request_mock, mock_session): + @mock.patch("requests.Session") + def test_get_request_with_port(self, mock_session, mock_request): from requests.exceptions import MissingSchema with mock.patch( @@ -80,16 +78,14 @@ def test_get_request_with_port(self, mock_requests, request_mock, mock_session): except MissingSchema: pass - request_mock.assert_called_once_with( + mock_request.assert_called_once_with( mock.ANY, expected_url, headers=mock.ANY, params=mock.ANY ) - request_mock.reset_mock() + mock_request.reset_mock() - @requests_mock.mock() - def test_get_request_do_not_raise_for_status_if_check_response_is_false(self, m): - - m.get( + def test_get_request_do_not_raise_for_status_if_check_response_is_false(self, requests_mock): + requests_mock.get( "http://test:8080/v1/test", status_code=404, text='{"status":{"status": 404}}', @@ -100,17 +96,15 @@ def test_get_request_do_not_raise_for_status_if_check_response_is_false(self, m) resp = self.get_hook.run("v1/test", extra_options={"check_response": False}) assert resp.text == '{"status":{"status": 404}}' - @requests_mock.mock() - def test_hook_contains_header_from_extra_field(self, mock_requests): + def test_hook_contains_header_from_extra_field(self): with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): expected_conn = get_airflow_connection() conn = self.get_hook.get_conn() assert dict(conn.headers, **json.loads(expected_conn.extra)) == conn.headers assert conn.headers.get("bareer") == "test" - @requests_mock.mock() @mock.patch("requests.Request") - def test_hook_with_method_in_lowercase(self, mock_requests, request_mock): + def 
test_hook_with_method_in_lowercase(self, mock_requests): from requests.exceptions import InvalidURL, MissingSchema with mock.patch( @@ -121,27 +115,23 @@ def test_hook_with_method_in_lowercase(self, mock_requests, request_mock): self.get_lowercase_hook.run("v1/test", data=data) except (MissingSchema, InvalidURL): pass - request_mock.assert_called_once_with(mock.ANY, mock.ANY, headers=mock.ANY, params=data) + mock_requests.assert_called_once_with(mock.ANY, mock.ANY, headers=mock.ANY, params=data) - @requests_mock.mock() - def test_hook_uses_provided_header(self, mock_requests): + def test_hook_uses_provided_header(self): conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"}) assert conn.headers.get("bareer") == "newT0k3n" - @requests_mock.mock() - def test_hook_has_no_header_from_extra(self, mock_requests): + def test_hook_has_no_header_from_extra(self): conn = self.get_hook.get_conn() assert conn.headers.get("bareer") is None - @requests_mock.mock() - def test_hooks_header_from_extra_is_overridden(self, mock_requests): + def test_hooks_header_from_extra_is_overridden(self): with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"}) assert conn.headers.get("bareer") == "newT0k3n" - @requests_mock.mock() - def test_post_request(self, mock_requests): - mock_requests.post( + def test_post_request(self, requests_mock): + requests_mock.post( "http://test:8080/v1/test", status_code=200, text='{"status":{"status": 200}}', reason="OK" ) @@ -149,9 +139,8 @@ def test_post_request(self, mock_requests): resp = self.post_hook.run("v1/test") assert resp.status_code == 200 - @requests_mock.mock() - def test_post_request_with_error_code(self, mock_requests): - mock_requests.post( + def test_post_request_with_error_code(self, requests_mock): + requests_mock.post( "http://test:8080/v1/test", status_code=418, text='{"status":{"status": 418}}', @@ -162,9 +151,8 @@ def 
test_post_request_with_error_code(self, mock_requests): with pytest.raises(AirflowException): self.post_hook.run("v1/test") - @requests_mock.mock() - def test_post_request_do_not_raise_for_status_if_check_response_is_false(self, mock_requests): - mock_requests.post( + def test_post_request_do_not_raise_for_status_if_check_response_is_false(self, requests_mock): + requests_mock.post( "http://test:8080/v1/test", status_code=418, text='{"status":{"status": 418}}', @@ -193,10 +181,9 @@ def send_and_raise(unused_request, **kwargs): self.get_hook.run_with_advanced_retry(endpoint="v1/test", _retry_args=retry_args) assert self.get_hook._retry_obj.stop.max_attempt_number + 1 == mocked_session.call_count - @requests_mock.mock() - def test_run_with_advanced_retry(self, m): + def test_run_with_advanced_retry(self, requests_mock): - m.get("http://test:8080/v1/test", status_code=200, reason="OK") + requests_mock.get("http://test:8080/v1/test", status_code=200, reason="OK") retry_args = dict( wait=tenacity.wait_none(), @@ -266,20 +253,14 @@ def test_connection_without_host(self, mock_get_connection): hook.get_conn({}) assert hook.base_url == "http://" - @parameterized.expand( - [ - "GET", - "POST", - ] - ) - @requests_mock.mock() - def test_json_request(self, method, mock_requests): + @pytest.mark.parametrize("method", ["GET", "POST"]) + def test_json_request(self, method, requests_mock): obj1 = {"a": 1, "b": "abc", "c": [1, 2, {"d": 10}]} def match_obj1(request): return request.json() == obj1 - mock_requests.request(method=method, url="//test:8080/v1/test", additional_matcher=match_obj1) + requests_mock.request(method=method, url="//test:8080/v1/test", additional_matcher=match_obj1) with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): # will raise NoMockAddress exception if obj1 != request.json() @@ -357,17 +338,17 @@ def test_verify_false_parameter_overwrites_set_requests_ca_bundle_env_var(self, verify=False, ) - 
@requests_mock.mock() - def test_connection_success(self, m): - m.get("http://test:8080", status_code=200, json={"status": {"status": 200}}, reason="OK") + def test_connection_success(self, requests_mock): + requests_mock.get("http://test:8080", status_code=200, json={"status": {"status": 200}}, reason="OK") with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): status, msg = self.get_hook.test_connection() assert status is True assert msg == "Connection successfully tested" - @requests_mock.mock() - def test_connection_failure(self, m): - m.get("http://test:8080", status_code=500, json={"message": "internal server error"}, reason="NOT_OK") + def test_connection_failure(self, requests_mock): + requests_mock.get( + "http://test:8080", status_code=500, json={"message": "internal server error"}, reason="NOT_OK" + ) with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection): status, msg = self.get_hook.test_connection() assert status is False diff --git a/tests/providers/http/operators/test_http.py b/tests/providers/http/operators/test_http.py index 227ea03b273a6..fca910a7680eb 100644 --- a/tests/providers/http/operators/test_http.py +++ b/tests/providers/http/operators/test_http.py @@ -17,26 +17,23 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock import pytest -import requests_mock from airflow.exceptions import AirflowException from airflow.providers.http.operators.http import SimpleHttpOperator @mock.patch.dict("os.environ", AIRFLOW_CONN_HTTP_EXAMPLE="http://www.example.com") -class TestSimpleHttpOp(unittest.TestCase): - @requests_mock.mock() - def test_response_in_logs(self, m): +class TestSimpleHttpOp: + def test_response_in_logs(self, requests_mock): """ Test that when using SimpleHttpOperator with 'GET', the log contains 'Example Domain' in it """ - m.get("http://www.example.com", text="Example.com fake response") + requests_mock.get("http://www.example.com", text="Example.com fake response") operator = SimpleHttpOperator( task_id="test_HTTP_op", method="GET", @@ -48,8 +45,7 @@ def test_response_in_logs(self, m): result = operator.execute("Example.com fake response") assert result == "Example.com fake response" - @requests_mock.mock() - def test_response_in_logs_after_failed_check(self, m): + def test_response_in_logs_after_failed_check(self, requests_mock): """ Test that when using SimpleHttpOperator with log_response=True, the response is logged even if request_check fails @@ -58,7 +54,7 @@ def test_response_in_logs_after_failed_check(self, m): def response_check(response): return response.text != "invalid response" - m.get("http://www.example.com", text="invalid response") + requests_mock.get("http://www.example.com", text="invalid response") operator = SimpleHttpOperator( task_id="test_HTTP_op", method="GET", @@ -74,9 +70,8 @@ def response_check(response): calls = [mock.call("Calling HTTP method"), mock.call("invalid response")] mock_info.assert_has_calls(calls, any_order=True) - @requests_mock.mock() - def test_filters_response(self, m): - m.get("http://www.example.com", json={"value": 5}) + def test_filters_response(self, requests_mock): + requests_mock.get("http://www.example.com", json={"value": 5}) operator = 
SimpleHttpOperator( task_id="test_HTTP_op", method="GET", diff --git a/tests/providers/http/sensors/test_http.py b/tests/providers/http/sensors/test_http.py index 7b801e28fc21e..0463a80a04fa6 100644 --- a/tests/providers/http/sensors/test_http.py +++ b/tests/providers/http/sensors/test_http.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from unittest.mock import patch @@ -206,8 +205,8 @@ def mount(self, prefix, adapter): pass -class TestHttpOpSensor(unittest.TestCase): - def setUp(self): +class TestHttpOpSensor: + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE_ISO} dag = DAG(TEST_DAG_ID, default_args=args) self.dag = dag diff --git a/tests/providers/imap/hooks/test_imap.py b/tests/providers/imap/hooks/test_imap.py index 1912c37c5efaf..c2be420be6e59 100644 --- a/tests/providers/imap/hooks/test_imap.py +++ b/tests/providers/imap/hooks/test_imap.py @@ -19,7 +19,6 @@ import imaplib import json -import unittest from unittest.mock import Mock, mock_open, patch import pytest @@ -61,8 +60,8 @@ def _create_fake_imap(mock_imaplib, with_mail=False, attachment_name="test1.csv" return mock_conn -class TestImapHook(unittest.TestCase): - def setUp(self): +class TestImapHook: + def setup_method(self): db.merge_conn( Connection( conn_id="imap_default", diff --git a/tests/providers/imap/sensors/test_imap_attachment.py b/tests/providers/imap/sensors/test_imap_attachment.py index 6fd9db8b19837..41cda774fcf46 100644 --- a/tests/providers/imap/sensors/test_imap_attachment.py +++ b/tests/providers/imap/sensors/test_imap_attachment.py @@ -17,16 +17,15 @@ # under the License. 
from __future__ import annotations -import unittest from unittest.mock import Mock, patch -from parameterized import parameterized +import pytest from airflow.providers.imap.sensors.imap_attachment import ImapAttachmentSensor -class TestImapAttachmentSensor(unittest.TestCase): - def setUp(self): +class TestImapAttachmentSensor: + def setup_method(self): self.kwargs = dict( attachment_name="test_file", check_regex=False, @@ -36,9 +35,9 @@ def setUp(self): dag=None, ) - @parameterized.expand([(True,), (False,)]) + @pytest.mark.parametrize("has_attachment_return_value", [True, False]) @patch("airflow.providers.imap.sensors.imap_attachment.ImapHook") - def test_poke(self, has_attachment_return_value, mock_imap_hook): + def test_poke(self, mock_imap_hook, has_attachment_return_value): mock_imap_hook.return_value.__enter__ = Mock(return_value=mock_imap_hook) mock_imap_hook.has_mail_attachment.return_value = has_attachment_return_value diff --git a/tests/providers/influxdb/hooks/test_influxdb.py b/tests/providers/influxdb/hooks/test_influxdb.py index 7b0f2d7b78fab..630fe26263680 100644 --- a/tests/providers/influxdb/hooks/test_influxdb.py +++ b/tests/providers/influxdb/hooks/test_influxdb.py @@ -16,16 +16,14 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.models import Connection from airflow.providers.influxdb.hooks.influxdb import InfluxDBHook -class TestInfluxDbHook(unittest.TestCase): - def setUp(self): - super().setUp() +class TestInfluxDbHook: + def setup_method(self): self.influxdb_hook = InfluxDBHook() extra = {} extra["token"] = "123456789" diff --git a/tests/providers/influxdb/operators/test_influxdb.py b/tests/providers/influxdb/operators/test_influxdb.py index 2ae8807ac9e86..f39e99e12a403 100644 --- a/tests/providers/influxdb/operators/test_influxdb.py +++ b/tests/providers/influxdb/operators/test_influxdb.py @@ -16,13 +16,12 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock from airflow.providers.influxdb.operators.influxdb import InfluxDBOperator -class TestInfluxDBOperator(unittest.TestCase): +class TestInfluxDBOperator: @mock.patch("airflow.providers.influxdb.operators.influxdb.InfluxDBHook") def test_influxdb_operator_test(self, mock_hook): diff --git a/tests/providers/jdbc/operators/test_jdbc.py b/tests/providers/jdbc/operators/test_jdbc.py index b9339ce584850..e027bdb96df01 100644 --- a/tests/providers/jdbc/operators/test_jdbc.py +++ b/tests/providers/jdbc/operators/test_jdbc.py @@ -17,15 +17,14 @@ # under the License. from __future__ import annotations -import unittest from unittest.mock import patch from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.jdbc.operators.jdbc import JdbcOperator -class TestJdbcOperator(unittest.TestCase): - def setUp(self): +class TestJdbcOperator: + def setup_method(self): self.kwargs = dict(sql="sql", task_id="test_jdbc_operator", dag=None) @patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook") diff --git a/tests/providers/jenkins/hooks/test_jenkins.py b/tests/providers/jenkins/hooks/test_jenkins.py index 53895f8572856..0ad0ba7603304 100644 --- a/tests/providers/jenkins/hooks/test_jenkins.py +++ b/tests/providers/jenkins/hooks/test_jenkins.py @@ -17,15 +17,14 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock -from parameterized import parameterized +import pytest from airflow.providers.jenkins.hooks.jenkins import JenkinsHook -class TestJenkinsHook(unittest.TestCase): +class TestJenkinsHook: @mock.patch("airflow.hooks.base.BaseHook.get_connection") def test_client_created_default_http(self, get_connection_mock): """tests `init` method to validate http client creation when all parameters are passed""" @@ -69,12 +68,12 @@ def test_client_created_default_https(self, get_connection_mock): assert hook.jenkins_server is not None assert hook.jenkins_server.server == complete_url - @parameterized.expand([(True,), (False,)]) + @pytest.mark.parametrize("param_building", [True, False]) @mock.patch("airflow.hooks.base.BaseHook.get_connection") @mock.patch("jenkins.Jenkins.get_job_info") @mock.patch("jenkins.Jenkins.get_build_info") def test_get_build_building_state( - self, param_building, mock_get_build_info, mock_get_job_info, get_connection_mock + self, mock_get_build_info, mock_get_job_info, get_connection_mock, param_building ): mock_get_build_info.return_value = {"building": param_building} diff --git a/tests/providers/jenkins/operators/test_jenkins_job_trigger.py b/tests/providers/jenkins/operators/test_jenkins_job_trigger.py index d4c328ea44e4d..53a32def164e0 100644 --- a/tests/providers/jenkins/operators/test_jenkins_job_trigger.py +++ b/tests/providers/jenkins/operators/test_jenkins_job_trigger.py @@ -17,36 +17,25 @@ # under the License. 
from __future__ import annotations -import unittest from unittest.mock import Mock, patch import jenkins import pytest -from parameterized import parameterized from airflow.exceptions import AirflowException from airflow.providers.jenkins.hooks.jenkins import JenkinsHook from airflow.providers.jenkins.operators.jenkins_job_trigger import JenkinsJobTriggerOperator +TEST_PARAMETERS = ( + pytest.param({"a_param": "blip", "another_param": "42"}, id="dict params"), + pytest.param('{"second_param": "beep", "third_param": "153"}', id="string params"), + pytest.param(["final_one", "bop", "real_final", "eggs"], id="list params"), +) -class TestJenkinsOperator(unittest.TestCase): - @parameterized.expand( - [ - ( - "dict params", - {"a_param": "blip", "another_param": "42"}, - ), - ( - "string params", - '{"second_param": "beep", "third_param": "153"}', - ), - ( - "list params", - ["final_one", "bop", "real_final", "eggs"], - ), - ] - ) - def test_execute(self, _, parameters): + +class TestJenkinsOperator: + @pytest.mark.parametrize("parameters", TEST_PARAMETERS) + def test_execute(self, parameters): jenkins_mock = Mock(spec=jenkins.Jenkins, auth="secret") jenkins_mock.get_build_info.return_value = { "result": "SUCCESS", @@ -80,23 +69,8 @@ def test_execute(self, _, parameters): assert jenkins_mock.get_build_info.call_count == 1 jenkins_mock.get_build_info.assert_called_once_with(name="a_job_on_jenkins", number="1") - @parameterized.expand( - [ - ( - "dict params", - {"a_param": "blip", "another_param": "42"}, - ), - ( - "string params", - '{"second_param": "beep", "third_param": "153"}', - ), - ( - "list params", - ["final_one", "bop", "real_final", "eggs"], - ), - ] - ) - def test_execute_job_polling_loop(self, _, parameters): + @pytest.mark.parametrize("parameters", TEST_PARAMETERS) + def test_execute_job_polling_loop(self, parameters): jenkins_mock = Mock(spec=jenkins.Jenkins, auth="secret") jenkins_mock.get_job_info.return_value = {"nextBuildNumber": "1"} 
jenkins_mock.get_build_info.side_effect = [ @@ -129,23 +103,8 @@ def test_execute_job_polling_loop(self, _, parameters): operator.execute(None) assert jenkins_mock.get_build_info.call_count == 2 - @parameterized.expand( - [ - ( - "dict params", - {"a_param": "blip", "another_param": "42"}, - ), - ( - "string params", - '{"second_param": "beep", "third_param": "153"}', - ), - ( - "list params", - ["final_one", "bop", "real_final", "eggs"], - ), - ] - ) - def test_execute_job_failure(self, _, parameters): + @pytest.mark.parametrize("parameters", TEST_PARAMETERS) + def test_execute_job_failure(self, parameters): jenkins_mock = Mock(spec=jenkins.Jenkins, auth="secret") jenkins_mock.get_job_info.return_value = {"nextBuildNumber": "1"} jenkins_mock.get_build_info.return_value = { @@ -178,7 +137,8 @@ def test_execute_job_failure(self, _, parameters): with pytest.raises(AirflowException): operator.execute(None) - @parameterized.expand( + @pytest.mark.parametrize( + "state, allowed_jenkins_states", [ ( "SUCCESS", @@ -196,7 +156,7 @@ def test_execute_job_failure(self, _, parameters): "SUCCESS", None, ), - ] + ], ) def test_allowed_jenkins_states(self, state, allowed_jenkins_states): jenkins_mock = Mock(spec=jenkins.Jenkins, auth="secret") @@ -233,7 +193,8 @@ def test_allowed_jenkins_states(self, state, allowed_jenkins_states): except AirflowException: pytest.fail(f"Job failed with state={state} while allowed states={allowed_jenkins_states}") - @parameterized.expand( + @pytest.mark.parametrize( + "state, allowed_jenkins_states", [ ( "FAILURE", @@ -255,7 +216,7 @@ def test_allowed_jenkins_states(self, state, allowed_jenkins_states): "UNSTABLE", None, ), - ] + ], ) def test_allowed_jenkins_states_failure(self, state, allowed_jenkins_states): jenkins_mock = Mock(spec=jenkins.Jenkins, auth="secret") diff --git a/tests/providers/jenkins/sensors/test_jenkins.py b/tests/providers/jenkins/sensors/test_jenkins.py index a3e9f0b5a4f58..8f69c41510d85 100644 --- 
a/tests/providers/jenkins/sensors/test_jenkins.py +++ b/tests/providers/jenkins/sensors/test_jenkins.py @@ -17,17 +17,17 @@ # under the License. from __future__ import annotations -import unittest from unittest.mock import MagicMock, patch -from parameterized import parameterized +import pytest from airflow.providers.jenkins.hooks.jenkins import JenkinsHook from airflow.providers.jenkins.sensors.jenkins import JenkinsBuildSensor -class TestJenkinsBuildSensor(unittest.TestCase): - @parameterized.expand( +class TestJenkinsBuildSensor: + @pytest.mark.parametrize( + "build_number, build_state", [ ( 1, @@ -41,10 +41,10 @@ class TestJenkinsBuildSensor(unittest.TestCase): 3, True, ), - ] + ], ) @patch("jenkins.Jenkins") - def test_poke(self, build_number, build_state, mock_jenkins): + def test_poke(self, mock_jenkins, build_number, build_state): target_build_number = build_number if build_number else 10 jenkins_mock = MagicMock() diff --git a/tests/providers/mongo/hooks/test_mongo.py b/tests/providers/mongo/hooks/test_mongo.py index e366a7d2c8f7d..0a106eaad3439 100644 --- a/tests/providers/mongo/hooks/test_mongo.py +++ b/tests/providers/mongo/hooks/test_mongo.py @@ -18,10 +18,10 @@ from __future__ import annotations import importlib -import unittest from types import ModuleType import pymongo +import pytest from airflow.models import Connection from airflow.providers.mongo.hooks.mongo import MongoHook @@ -48,8 +48,9 @@ def get_collection(self, mock_collection, mongo_db=None): return mock_collection -class TestMongoHook(unittest.TestCase): - def setUp(self): +@pytest.mark.skipif(mongomock is None, reason="mongomock package not present") +class TestMongoHook: + def setup_method(self): self.hook = MongoHookTest(conn_id="mongo_default", mongo_db="default") self.conn = self.hook.get_conn() db.merge_conn( @@ -57,22 +58,19 @@ def setUp(self): conn_id="mongo_default_with_srv", conn_type="mongo", host="mongo", - port="27017", + port=27017, extra='{"srv": true}', ) ) - 
@unittest.skipIf(mongomock is None, "mongomock package not present") def test_get_conn(self): assert self.hook.connection.port == 27017 assert isinstance(self.conn, pymongo.MongoClient) - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_srv(self): hook = MongoHook(conn_id="mongo_default_with_srv") assert hook.uri.startswith("mongodb+srv://") - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_insert_one(self): collection = mongomock.MongoClient().db.collection obj = {"test_insert_one": "test_value"} @@ -82,7 +80,6 @@ def test_insert_one(self): assert obj == result_obj - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_insert_many(self): collection = mongomock.MongoClient().db.collection objs = [{"test_insert_many_1": "test_value"}, {"test_insert_many_2": "test_value"}] @@ -92,7 +89,6 @@ def test_insert_many(self): result_objs = list(collection.find()) assert len(result_objs) == 2 - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_update_one(self): collection = mongomock.MongoClient().db.collection obj = {"_id": "1", "field": 0} @@ -106,7 +102,6 @@ def test_update_one(self): result_obj = collection.find_one(filter="1") assert 123 == result_obj["field"] - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_update_one_with_upsert(self): collection = mongomock.MongoClient().db.collection @@ -118,7 +113,6 @@ def test_update_one_with_upsert(self): result_obj = collection.find_one(filter="1") assert 123 == result_obj["field"] - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_update_many(self): collection = mongomock.MongoClient().db.collection obj1 = {"_id": "1", "field": 0} @@ -136,7 +130,6 @@ def test_update_many(self): result_obj = collection.find_one(filter="2") assert 123 == result_obj["field"] - @unittest.skipIf(mongomock is None, "mongomock package not present") def 
test_update_many_with_upsert(self): collection = mongomock.MongoClient().db.collection @@ -148,7 +141,6 @@ def test_update_many_with_upsert(self): result_obj = collection.find_one(filter="1") assert 123 == result_obj["field"] - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_replace_one(self): collection = mongomock.MongoClient().db.collection obj1 = {"_id": "1", "field": "test_value_1"} @@ -165,7 +157,6 @@ def test_replace_one(self): result_obj = collection.find_one(filter="2") assert "test_value_2" == result_obj["field"] - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_replace_one_with_filter(self): collection = mongomock.MongoClient().db.collection obj1 = {"_id": "1", "field": "test_value_1"} @@ -182,7 +173,6 @@ def test_replace_one_with_filter(self): result_obj = collection.find_one(filter="2") assert "test_value_2" == result_obj["field"] - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_replace_one_with_upsert(self): collection = mongomock.MongoClient().db.collection @@ -192,7 +182,6 @@ def test_replace_one_with_upsert(self): result_obj = collection.find_one(filter="1") assert "test_value_1" == result_obj["field"] - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_replace_many(self): collection = mongomock.MongoClient().db.collection obj1 = {"_id": "1", "field": "test_value_1"} @@ -209,7 +198,6 @@ def test_replace_many(self): result_obj = collection.find_one(filter="2") assert "test_value_2_updated" == result_obj["field"] - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_replace_many_with_upsert(self): collection = mongomock.MongoClient().db.collection obj1 = {"_id": "1", "field": "test_value_1"} @@ -223,7 +211,6 @@ def test_replace_many_with_upsert(self): result_obj = collection.find_one(filter="2") assert "test_value_2" == result_obj["field"] - @unittest.skipIf(mongomock is None, "mongomock package not 
present") def test_delete_one(self): collection = mongomock.MongoClient().db.collection obj = {"_id": "1"} @@ -233,7 +220,6 @@ def test_delete_one(self): assert 0 == collection.count() - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_delete_many(self): collection = mongomock.MongoClient().db.collection obj1 = {"_id": "1", "field": "value"} @@ -244,7 +230,6 @@ def test_delete_many(self): assert 0 == collection.count() - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_find_one(self): collection = mongomock.MongoClient().db.collection obj = {"test_find_one": "test_value"} @@ -254,7 +239,6 @@ def test_find_one(self): result_obj = {result: result_obj[result] for result in result_obj} assert obj == result_obj - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_find_many(self): collection = mongomock.MongoClient().db.collection objs = [{"_id": 1, "test_find_many_1": "test_value"}, {"_id": 2, "test_find_many_2": "test_value"}] @@ -264,7 +248,6 @@ def test_find_many(self): assert len(list(result_objs)) > 1 - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_find_many_with_projection(self): collection = mongomock.MongoClient().db.collection objs = [ @@ -277,10 +260,8 @@ def test_find_many_with_projection(self): result_objs = self.hook.find( mongo_collection=collection, query={}, projection=projection, find_one=False ) + assert "_id" not in result_objs[0] - self.assertRaises(KeyError, lambda x: x[0]["_id"], result_objs) - - @unittest.skipIf(mongomock is None, "mongomock package not present") def test_aggregate(self): collection = mongomock.MongoClient().db.collection objs = [ @@ -296,11 +277,12 @@ def test_aggregate(self): results = self.hook.aggregate(collection, aggregate_query) assert len(list(results)) == 2 - def test_context_manager(self): - with MongoHook(conn_id="mongo_default", mongo_db="default") as ctx_hook: - ctx_hook.get_conn() - assert 
isinstance(ctx_hook, MongoHook) - assert ctx_hook.client is not None +def test_context_manager(): + with MongoHook(conn_id="mongo_default", mongo_db="default") as ctx_hook: + ctx_hook.get_conn() + + assert isinstance(ctx_hook, MongoHook) + assert ctx_hook.client is not None - assert ctx_hook.client is None + assert ctx_hook.client is None diff --git a/tests/providers/mongo/sensors/test_mongo.py b/tests/providers/mongo/sensors/test_mongo.py index dafc9942a1885..98eaec52738f7 100644 --- a/tests/providers/mongo/sensors/test_mongo.py +++ b/tests/providers/mongo/sensors/test_mongo.py @@ -17,8 +17,6 @@ # under the License. from __future__ import annotations -import unittest - import pytest from airflow.models import Connection @@ -31,10 +29,10 @@ @pytest.mark.integration("mongo") -class TestMongoSensor(unittest.TestCase): - def setUp(self): +class TestMongoSensor: + def setup_method(self): db.merge_conn( - Connection(conn_id="mongo_test", conn_type="mongo", host="mongo", port="27017", schema="test") + Connection(conn_id="mongo_test", conn_type="mongo", host="mongo", port=27017, schema="test") ) args = {"owner": "airflow", "start_date": DEFAULT_DATE} diff --git a/tests/providers/mysql/hooks/test_mysql.py b/tests/providers/mysql/hooks/test_mysql.py index d7270dafdd8ad..3a2724cd8f406 100644 --- a/tests/providers/mysql/hooks/test_mysql.py +++ b/tests/providers/mysql/hooks/test_mysql.py @@ -19,14 +19,12 @@ import json import os -import unittest import uuid from contextlib import closing from unittest import mock import MySQLdb.cursors import pytest -from parameterized import parameterized from airflow.models import Connection from airflow.models.dag import DAG @@ -36,10 +34,8 @@ SSL_DICT = {"cert": "/tmp/client-cert.pem", "ca": "/tmp/server-ca.pem", "key": "/tmp/client-key.pem"} -class TestMySqlHookConn(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestMySqlHookConn: + def setup_method(self): self.connection = Connection( conn_type="mysql", login="login", 
@@ -169,10 +165,8 @@ def test_get_conn_rds_iam(self, mock_client, mock_connect): ) -class TestMySqlHookConnMySqlConnectorPython(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestMySqlHookConnMySqlConnectorPython: + def setup_method(self): self.connection = Connection( login="login", password="password", @@ -232,10 +226,8 @@ def autocommit(self, autocommit): self._autocommit = autocommit -class TestMySqlHook(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestMySqlHook: + def setup_method(self): self.cur = mock.MagicMock(rowcount=0) self.conn = mock.MagicMock() self.conn.cursor.return_value = self.cur @@ -249,7 +241,7 @@ def get_conn(self): self.db_hook = SubMySqlHook() - @parameterized.expand([(True,), (False,)]) + @pytest.mark.parametrize("autocommit", [True, False]) def test_set_autocommit_mysql_connector(self, autocommit): conn = MockMySQLConnectorConnection() self.db_hook.set_autocommit(conn, autocommit) @@ -369,25 +361,20 @@ def __exit__(self, exc_type, exc_val, exc_tb): @pytest.mark.backend("mysql") -class TestMySql(unittest.TestCase): - def setUp(self): +class TestMySql: + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} dag = DAG(TEST_DAG_ID, default_args=args) self.dag = dag - def tearDown(self): + def teardown_method(self): drop_tables = {"test_mysql_to_mysql", "test_airflow"} with closing(MySqlHook().get_conn()) as conn: with closing(conn.cursor()) as cursor: for table in drop_tables: cursor.execute(f"DROP TABLE IF EXISTS {table}") - @parameterized.expand( - [ - ("mysqlclient",), - ("mysql-connector-python",), - ] - ) + @pytest.mark.parametrize("client", ["mysqlclient", "mysql-connector-python"]) @mock.patch.dict( "os.environ", { @@ -420,12 +407,7 @@ def test_mysql_hook_test_bulk_load(self, client): results = tuple(result[0] for result in cursor.fetchall()) assert sorted(results) == sorted(records) - @parameterized.expand( - [ - ("mysqlclient",), - ("mysql-connector-python",), - ] - ) + 
@pytest.mark.parametrize("client", ["mysqlclient", "mysql-connector-python"]) def test_mysql_hook_test_bulk_dump(self, client): with MySqlContext(client): hook = MySqlHook("airflow_db") @@ -442,14 +424,9 @@ def test_mysql_hook_test_bulk_dump(self, client): else: raise pytest.skip("Skip test_mysql_hook_test_bulk_load since file output is not permitted") - @parameterized.expand( - [ - ("mysqlclient",), - ("mysql-connector-python",), - ] - ) + @pytest.mark.parametrize("client", ["mysqlclient", "mysql-connector-python"]) @mock.patch("airflow.providers.mysql.hooks.mysql.MySqlHook.get_conn") - def test_mysql_hook_test_bulk_dump_mock(self, client, mock_get_conn): + def test_mysql_hook_test_bulk_dump_mock(self, mock_get_conn, client): with MySqlContext(client): mock_execute = mock.MagicMock() mock_get_conn.return_value.cursor.return_value.execute = mock_execute diff --git a/tests/providers/mysql/operators/test_mysql.py b/tests/providers/mysql/operators/test_mysql.py index a272f1f3d489b..379459985aa6f 100644 --- a/tests/providers/mysql/operators/test_mysql.py +++ b/tests/providers/mysql/operators/test_mysql.py @@ -18,12 +18,10 @@ from __future__ import annotations import os -import unittest from contextlib import closing from tempfile import NamedTemporaryFile import pytest -from parameterized import parameterized from airflow.models.dag import DAG from airflow.providers.mysql.hooks.mysql import MySqlHook @@ -38,25 +36,20 @@ @pytest.mark.backend("mysql") -class TestMySql(unittest.TestCase): - def setUp(self): +class TestMySql: + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} dag = DAG(TEST_DAG_ID, default_args=args) self.dag = dag - def tearDown(self): + def teardown_method(self): drop_tables = {"test_mysql_to_mysql", "test_airflow"} with closing(MySqlHook().get_conn()) as conn: with closing(conn.cursor()) as cursor: for table in drop_tables: cursor.execute(f"DROP TABLE IF EXISTS {table}") - @parameterized.expand( - [ - ("mysqlclient",), - 
("mysql-connector-python",), - ] - ) + @pytest.mark.parametrize("client", ["mysqlclient", "mysql-connector-python"]) def test_mysql_operator_test(self, client): with MySqlContext(client): sql = """ @@ -67,12 +60,7 @@ def test_mysql_operator_test(self, client): op = MySqlOperator(task_id="basic_mysql", sql=sql, dag=self.dag) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @parameterized.expand( - [ - ("mysqlclient",), - ("mysql-connector-python",), - ] - ) + @pytest.mark.parametrize("client", ["mysqlclient", "mysql-connector-python"]) def test_mysql_operator_test_multi(self, client): with MySqlContext(client): sql = [ @@ -87,12 +75,7 @@ def test_mysql_operator_test_multi(self, client): ) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @parameterized.expand( - [ - ("mysqlclient",), - ("mysql-connector-python",), - ] - ) + @pytest.mark.parametrize("client", ["mysqlclient", "mysql-connector-python"]) def test_overwrite_schema(self, client): """ Verifies option to overwrite connection schema diff --git a/tests/providers/mysql/transfers/test_s3_to_mysql.py b/tests/providers/mysql/transfers/test_s3_to_mysql.py index 743225d3c06f7..50e62bca9ce30 100644 --- a/tests/providers/mysql/transfers/test_s3_to_mysql.py +++ b/tests/providers/mysql/transfers/test_s3_to_mysql.py @@ -16,7 +16,6 @@ # under the License. 
from __future__ import annotations -import unittest from unittest.mock import patch import pytest @@ -28,8 +27,8 @@ from airflow.utils.session import create_session -class TestS3ToMySqlTransfer(unittest.TestCase): - def setUp(self): +class TestS3ToMySqlTransfer: + def setup_method(self): configuration.conf.load_test_config() db.merge_conn( @@ -99,7 +98,7 @@ def test_execute_exception(self, mock_remove, mock_bulk_load_custom, mock_downlo ) mock_remove.assert_called_once_with(mock_download_file.return_value) - def tearDown(self): + def teardown_method(self): with create_session() as session: ( session.query(models.Connection) diff --git a/tests/providers/mysql/transfers/test_vertica_to_mysql.py b/tests/providers/mysql/transfers/test_vertica_to_mysql.py index 9af6203866fdb..e13c5d05781c9 100644 --- a/tests/providers/mysql/transfers/test_vertica_to_mysql.py +++ b/tests/providers/mysql/transfers/test_vertica_to_mysql.py @@ -18,7 +18,6 @@ from __future__ import annotations import datetime -import unittest from unittest import mock from airflow.models.dag import DAG @@ -40,8 +39,8 @@ def mock_get_conn(): return conn_mock -class TestVerticaToMySqlTransfer(unittest.TestCase): - def setUp(self): +class TestVerticaToMySqlTransfer: + def setup_method(self): args = {"owner": "airflow", "start_date": datetime.datetime(2017, 1, 1)} self.dag = DAG("test_dag_id", default_args=args) diff --git a/tests/providers/neo4j/hooks/test_neo4j.py b/tests/providers/neo4j/hooks/test_neo4j.py index 1ac536335326f..2191b326b1ded 100644 --- a/tests/providers/neo4j/hooks/test_neo4j.py +++ b/tests/providers/neo4j/hooks/test_neo4j.py @@ -16,23 +16,23 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock -from parameterized import parameterized +import pytest from airflow.models import Connection from airflow.providers.neo4j.hooks.neo4j import Neo4jHook -class TestNeo4jHookConn(unittest.TestCase): - @parameterized.expand( +class TestNeo4jHookConn: + @pytest.mark.parametrize( + "conn_extra, expected_uri", [ - [{}, "bolt://host:7687"], - [{"bolt_scheme": True}, "bolt://host:7687"], - [{"certs_self_signed": True, "bolt_scheme": True}, "bolt+ssc://host:7687"], - [{"certs_trusted_ca": True, "bolt_scheme": True}, "bolt+s://host:7687"], - ] + ({}, "bolt://host:7687"), + ({"bolt_scheme": True}, "bolt://host:7687"), + ({"certs_self_signed": True, "bolt_scheme": True}, "bolt+ssc://host:7687"), + ({"certs_trusted_ca": True, "bolt_scheme": True}, "bolt+s://host:7687"), + ], ) def test_get_uri_neo4j_scheme(self, conn_extra, expected_uri): connection = Connection( @@ -75,10 +75,7 @@ def test_run_with_schema(self, mock_graph_database): ] ) session = mock_graph_database.driver.return_value.session.return_value.__enter__.return_value - self.assertEqual( - session.run.return_value.data.return_value, - op_result, - ) + assert op_result == session.run.return_value.data.return_value @mock.patch("airflow.providers.neo4j.hooks.neo4j.GraphDatabase") def test_run_without_schema(self, mock_graph_database): @@ -103,7 +100,4 @@ def test_run_without_schema(self, mock_graph_database): ] ) session = mock_graph_database.driver.return_value.session.return_value.__enter__.return_value - self.assertEqual( - session.run.return_value.data.return_value, - op_result, - ) + assert op_result == session.run.return_value.data.return_value diff --git a/tests/providers/neo4j/operators/test_neo4j.py b/tests/providers/neo4j/operators/test_neo4j.py index a38960ebf27ea..2cf95ec3374f8 100644 --- a/tests/providers/neo4j/operators/test_neo4j.py +++ b/tests/providers/neo4j/operators/test_neo4j.py @@ -16,7 +16,6 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock from airflow.providers.neo4j.operators.neo4j import Neo4jOperator @@ -28,7 +27,7 @@ TEST_DAG_ID = "unit_test_dag" -class TestNeo4jOperator(unittest.TestCase): +class TestNeo4jOperator: @mock.patch("airflow.providers.neo4j.operators.neo4j.Neo4jHook") def test_neo4j_operator_test(self, mock_hook): diff --git a/tests/providers/openfaas/hooks/test_openfaas.py b/tests/providers/openfaas/hooks/test_openfaas.py index a5a8a99e80ae3..17172cc383425 100644 --- a/tests/providers/openfaas/hooks/test_openfaas.py +++ b/tests/providers/openfaas/hooks/test_openfaas.py @@ -17,11 +17,9 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest -import requests_mock from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook @@ -31,21 +29,20 @@ FUNCTION_NAME = "function_name" -class TestOpenFaasHook(unittest.TestCase): +class TestOpenFaasHook: GET_FUNCTION = "/system/function/" INVOKE_ASYNC_FUNCTION = "/async-function/" INVOKE_FUNCTION = "/function/" DEPLOY_FUNCTION = "/system/functions" UPDATE_FUNCTION = "/system/functions" - def setUp(self): + def setup_method(self): self.hook = OpenFaasHook(function_name=FUNCTION_NAME) self.mock_response = {"ans": "a"} @mock.patch.object(BaseHook, "get_connection") - @requests_mock.mock() - def test_is_function_exist_false(self, mock_get_connection, m): - m.get( + def test_is_function_exist_false(self, mock_get_connection, requests_mock): + requests_mock.get( "http://open-faas.io" + self.GET_FUNCTION + FUNCTION_NAME, json=self.mock_response, status_code=404, @@ -57,9 +54,8 @@ def test_is_function_exist_false(self, mock_get_connection, m): assert not does_function_exist @mock.patch.object(BaseHook, "get_connection") - @requests_mock.mock() - def test_is_function_exist_true(self, mock_get_connection, m): - m.get( + def test_is_function_exist_true(self, mock_get_connection, requests_mock): + 
requests_mock.get( "http://open-faas.io" + self.GET_FUNCTION + FUNCTION_NAME, json=self.mock_response, status_code=202, @@ -71,18 +67,20 @@ def test_is_function_exist_true(self, mock_get_connection, m): assert does_function_exist @mock.patch.object(BaseHook, "get_connection") - @requests_mock.mock() - def test_update_function_true(self, mock_get_connection, m): - m.put("http://open-faas.io" + self.UPDATE_FUNCTION, json=self.mock_response, status_code=202) + def test_update_function_true(self, mock_get_connection, requests_mock): + requests_mock.put( + "http://open-faas.io" + self.UPDATE_FUNCTION, json=self.mock_response, status_code=202 + ) mock_connection = Connection(host="http://open-faas.io") mock_get_connection.return_value = mock_connection self.hook.update_function({}) # returns None @mock.patch.object(BaseHook, "get_connection") - @requests_mock.mock() - def test_update_function_false(self, mock_get_connection, m): - m.put("http://open-faas.io" + self.UPDATE_FUNCTION, json=self.mock_response, status_code=400) + def test_update_function_false(self, mock_get_connection, requests_mock): + requests_mock.put( + "http://open-faas.io" + self.UPDATE_FUNCTION, json=self.mock_response, status_code=400 + ) mock_connection = Connection(host="http://open-faas.io") mock_get_connection.return_value = mock_connection @@ -91,9 +89,8 @@ def test_update_function_false(self, mock_get_connection, m): assert "failed to update " + FUNCTION_NAME in str(ctx.value) @mock.patch.object(BaseHook, "get_connection") - @requests_mock.mock() - def test_invoke_function_false(self, mock_get_connection, m): - m.post( + def test_invoke_function_false(self, mock_get_connection, requests_mock): + requests_mock.post( "http://open-faas.io" + self.INVOKE_FUNCTION + FUNCTION_NAME, json=self.mock_response, status_code=400, @@ -106,9 +103,8 @@ def test_invoke_function_false(self, mock_get_connection, m): assert "failed to invoke function" in str(ctx.value) @mock.patch.object(BaseHook, 
"get_connection") - @requests_mock.mock() - def test_invoke_function_true(self, mock_get_connection, m): - m.post( + def test_invoke_function_true(self, mock_get_connection, requests_mock): + requests_mock.post( "http://open-faas.io" + self.INVOKE_FUNCTION + FUNCTION_NAME, json=self.mock_response, status_code=200, @@ -118,9 +114,8 @@ def test_invoke_function_true(self, mock_get_connection, m): assert self.hook.invoke_function({}) is None @mock.patch.object(BaseHook, "get_connection") - @requests_mock.mock() - def test_invoke_async_function_false(self, mock_get_connection, m): - m.post( + def test_invoke_async_function_false(self, mock_get_connection, requests_mock): + requests_mock.post( "http://open-faas.io" + self.INVOKE_ASYNC_FUNCTION + FUNCTION_NAME, json=self.mock_response, status_code=400, @@ -133,9 +128,8 @@ def test_invoke_async_function_false(self, mock_get_connection, m): assert "failed to invoke function" in str(ctx.value) @mock.patch.object(BaseHook, "get_connection") - @requests_mock.mock() - def test_invoke_async_function_true(self, mock_get_connection, m): - m.post( + def test_invoke_async_function_true(self, mock_get_connection, requests_mock): + requests_mock.post( "http://open-faas.io" + self.INVOKE_ASYNC_FUNCTION + FUNCTION_NAME, json=self.mock_response, status_code=202, @@ -145,17 +139,17 @@ def test_invoke_async_function_true(self, mock_get_connection, m): assert self.hook.invoke_async_function({}) is None @mock.patch.object(BaseHook, "get_connection") - @requests_mock.mock() - def test_deploy_function_function_already_exist(self, mock_get_connection, m): - m.put("http://open-faas.io/" + self.UPDATE_FUNCTION, json=self.mock_response, status_code=202) + def test_deploy_function_function_already_exist(self, mock_get_connection, requests_mock): + requests_mock.put( + "http://open-faas.io/" + self.UPDATE_FUNCTION, json=self.mock_response, status_code=202 + ) mock_connection = Connection(host="http://open-faas.io/") mock_get_connection.return_value 
= mock_connection assert self.hook.deploy_function(True, {}) is None @mock.patch.object(BaseHook, "get_connection") - @requests_mock.mock() - def test_deploy_function_function_not_exist(self, mock_get_connection, m): - m.post("http://open-faas.io" + self.DEPLOY_FUNCTION, json={}, status_code=202) + def test_deploy_function_function_not_exist(self, mock_get_connection, requests_mock): + requests_mock.post("http://open-faas.io" + self.DEPLOY_FUNCTION, json={}, status_code=202) mock_connection = Connection(host="http://open-faas.io") mock_get_connection.return_value = mock_connection assert self.hook.deploy_function(False, {}) is None diff --git a/tests/providers/opsgenie/hooks/test_opsgenie.py b/tests/providers/opsgenie/hooks/test_opsgenie.py index f2d51476349c6..8b3a518356f12 100644 --- a/tests/providers/opsgenie/hooks/test_opsgenie.py +++ b/tests/providers/opsgenie/hooks/test_opsgenie.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -29,7 +28,7 @@ from airflow.utils import db -class TestOpsgenieAlertHook(unittest.TestCase): +class TestOpsgenieAlertHook: conn_id = "opsgenie_conn_id_test" opsgenie_alert_endpoint = "https://api.opsgenie.com/v2/alerts" _create_alert_payload = { @@ -67,7 +66,7 @@ class TestOpsgenieAlertHook(unittest.TestCase): "request_id": "43a29c5c-3dbf-4fa4-9c26-f4f71023e120", } - def setUp(self): + def setup_method(self): db.merge_conn( Connection( conn_id=self.conn_id, diff --git a/tests/providers/opsgenie/operators/test_opsgenie.py b/tests/providers/opsgenie/operators/test_opsgenie.py index ee06dfc1ca1ff..0194660323f7f 100644 --- a/tests/providers/opsgenie/operators/test_opsgenie.py +++ b/tests/providers/opsgenie/operators/test_opsgenie.py @@ -17,7 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock from airflow.models.dag import DAG @@ -31,7 +30,7 @@ DEFAULT_DATE = timezone.datetime(2017, 1, 1) -class TestOpsgenieCreateAlertOperator(unittest.TestCase): +class TestOpsgenieCreateAlertOperator: _config = { "message": "An example alert message", "alias": "Life is too short for no alias", @@ -78,7 +77,7 @@ class TestOpsgenieCreateAlertOperator(unittest.TestCase): "note": _config["note"], } - def setUp(self): + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} self.dag = DAG("test_dag_id", default_args=args) @@ -111,7 +110,7 @@ def test_properties(self): assert self._config["note"] == operator.note -class TestOpsgenieCloseAlertOperator(unittest.TestCase): +class TestOpsgenieCloseAlertOperator: _config = {"user": "example_user", "note": "my_closing_note", "source": "some_source"} expected_payload_dict = { "user": _config["user"], @@ -119,7 +118,7 @@ class TestOpsgenieCloseAlertOperator(unittest.TestCase): "source": _config["source"], } - def setUp(self): + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} self.dag = DAG("test_dag_id", default_args=args) @@ -145,8 +144,8 @@ def test_properties(self): assert self._config["source"] == operator.source -class TestOpsgenieDeleteAlertOperator(unittest.TestCase): - def setUp(self): +class TestOpsgenieDeleteAlertOperator: + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} self.dag = DAG("test_dag_id", default_args=args) diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py index 24c492d10c28c..9f49618b45435 100644 --- a/tests/providers/oracle/hooks/test_oracle.py +++ b/tests/providers/oracle/hooks/test_oracle.py @@ -18,27 +18,19 @@ from __future__ import annotations import json -import unittest from datetime import datetime from unittest import mock import numpy +import oracledb import pytest from 
airflow.models import Connection from airflow.providers.oracle.hooks.oracle import OracleHook -try: - import oracledb -except ImportError: - oracledb = None # type: ignore - - -@unittest.skipIf(oracledb is None, "oracledb package not present") -class TestOracleHookConn(unittest.TestCase): - def setUp(self): - super().setUp() +class TestOracleHookConn: + def setup_method(self): self.connection = Connection( login="login", password="password", host="host", schema="schema", port=1521 ) @@ -265,11 +257,8 @@ def test_type_checking_thick_mode_config_dir(self): self.db_hook.get_conn() -@unittest.skipIf(oracledb is None, "oracledb package not present") -class TestOracleHook(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestOracleHook: + def setup_method(self): self.cur = mock.MagicMock(rowcount=0) self.conn = mock.MagicMock() self.conn.cursor.return_value = self.cur diff --git a/tests/providers/oracle/operators/test_oracle.py b/tests/providers/oracle/operators/test_oracle.py index ba5e82df3999b..2e5a8e10e63dc 100644 --- a/tests/providers/oracle/operators/test_oracle.py +++ b/tests/providers/oracle/operators/test_oracle.py @@ -16,15 +16,16 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock +import pytest + from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.oracle.hooks.oracle import OracleHook from airflow.providers.oracle.operators.oracle import OracleOperator, OracleStoredProcedureOperator -class TestOracleOperator(unittest.TestCase): +class TestOracleOperator: @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook") def test_execute(self, mock_get_db_hook): sql = "SELECT * FROM test_table" @@ -34,13 +35,14 @@ def test_execute(self, mock_get_db_hook): context = "test_context" task_id = "test_task_id" - operator = OracleOperator( - sql=sql, - oracle_conn_id=oracle_conn_id, - parameters=parameters, - autocommit=autocommit, - task_id=task_id, - ) + with pytest.warns(DeprecationWarning, match="This class is deprecated.*"): + operator = OracleOperator( + sql=sql, + oracle_conn_id=oracle_conn_id, + parameters=parameters, + autocommit=autocommit, + task_id=task_id, + ) operator.execute(context=context) mock_get_db_hook.return_value.run.assert_called_once_with( sql=sql, @@ -52,7 +54,7 @@ def test_execute(self, mock_get_db_hook): ) -class TestOracleStoredProcedureOperator(unittest.TestCase): +class TestOracleStoredProcedureOperator: @mock.patch.object(OracleHook, "run", autospec=OracleHook.run) def test_execute(self, mock_run): procedure = "test" diff --git a/tests/providers/oracle/transfers/test_oracle_to_oracle.py b/tests/providers/oracle/transfers/test_oracle_to_oracle.py index 587238620aebc..e2e66706da852 100644 --- a/tests/providers/oracle/transfers/test_oracle_to_oracle.py +++ b/tests/providers/oracle/transfers/test_oracle_to_oracle.py @@ -17,16 +17,14 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock from unittest.mock import MagicMock from airflow.providers.oracle.transfers.oracle_to_oracle import OracleToOracleOperator -class TestOracleToOracleTransfer(unittest.TestCase): - @staticmethod - def test_execute(): +class TestOracleToOracleTransfer: + def test_execute(self): oracle_destination_conn_id = "oracle_destination_conn_id" destination_table = "destination_table" oracle_source_conn_id = "oracle_source_conn_id" diff --git a/tests/providers/papermill/operators/test_papermill.py b/tests/providers/papermill/operators/test_papermill.py index dc072b9e3f2d8..2ab23280a73be 100644 --- a/tests/providers/papermill/operators/test_papermill.py +++ b/tests/providers/papermill/operators/test_papermill.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest.mock import patch from airflow.models import DAG, DagRun, TaskInstance @@ -27,7 +26,7 @@ DEFAULT_DATE = timezone.datetime(2021, 1, 1) -class TestPapermillOperator(unittest.TestCase): +class TestPapermillOperator: @patch("airflow.providers.papermill.operators.papermill.pm") def test_execute(self, mock_papermill): in_nb = "/tmp/does_not_exist" diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py index bc0c1974381d8..70ca7823835ac 100644 --- a/tests/providers/postgres/hooks/test_postgres.py +++ b/tests/providers/postgres/hooks/test_postgres.py @@ -18,7 +18,6 @@ from __future__ import annotations import json -import unittest from tempfile import NamedTemporaryFile from unittest import mock @@ -31,8 +30,7 @@ class TestPostgresHookConn: - @pytest.fixture(autouse=True) - def setup(self): + def setup_method(self): self.connection = Connection(login="login", password="password", host="host", schema="database") class UnitTestPostgresHook(PostgresHook): @@ -258,14 +256,11 @@ def test_schema_kwarg_database_kwarg_compatibility(self): assert 
hook.database == database -class TestPostgresHook(unittest.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.table = "test_postgres_hook_table" - - def setUp(self): - super().setUp() +@pytest.mark.backend("postgres") +class TestPostgresHook: + table = "test_postgres_hook_table" + def setup_method(self): self.cur = mock.MagicMock(rowcount=0) self.conn = conn = mock.MagicMock() self.conn.cursor.return_value = self.cur @@ -278,14 +273,11 @@ def get_conn(self): self.db_hook = UnitTestPostgresHook() - def tearDown(self): - super().tearDown() - + def teardown_method(self): with PostgresHook().get_conn() as conn: with conn.cursor() as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table}") - @pytest.mark.backend("postgres") def test_copy_expert(self): open_mock = mock.mock_open(read_data='{"some": "json"}') with mock.patch("airflow.providers.postgres.hooks.postgres.open", open_mock): @@ -302,7 +294,6 @@ def test_copy_expert(self): self.cur.copy_expert.assert_called_once_with(statement, open_mock.return_value) assert open_mock.call_args[0] == (filename, "r+") - @pytest.mark.backend("postgres") def test_bulk_load(self): hook = PostgresHook() input_data = ["foo", "bar", "baz"] @@ -322,7 +313,6 @@ def test_bulk_load(self): assert sorted(input_data) == sorted(results) - @pytest.mark.backend("postgres") def test_bulk_dump(self): hook = PostgresHook() input_data = ["foo", "bar", "baz"] @@ -341,7 +331,6 @@ def test_bulk_dump(self): assert sorted(input_data) == sorted(results) - @pytest.mark.backend("postgres") def test_insert_rows(self): table = "table" rows = [("hello",), ("world",)] @@ -358,7 +347,6 @@ def test_insert_rows(self): for row in rows: self.cur.execute.assert_any_call(sql, row) - @pytest.mark.backend("postgres") def test_insert_rows_replace(self): table = "table" rows = [ @@ -388,7 +376,6 @@ def test_insert_rows_replace(self): for row in rows: self.cur.execute.assert_any_call(sql, row) - @pytest.mark.backend("postgres") def 
test_insert_rows_replace_missing_target_field_arg(self): table = "table" rows = [ @@ -407,7 +394,6 @@ def test_insert_rows_replace_missing_target_field_arg(self): assert str(ctx.value) == "PostgreSQL ON CONFLICT upsert syntax requires column names" - @pytest.mark.backend("postgres") def test_insert_rows_replace_missing_replace_index_arg(self): table = "table" rows = [ @@ -426,7 +412,6 @@ def test_insert_rows_replace_missing_replace_index_arg(self): assert str(ctx.value) == "PostgreSQL ON CONFLICT upsert syntax requires an unique index" - @pytest.mark.backend("postgres") def test_insert_rows_replace_all_index(self): table = "table" rows = [ @@ -456,7 +441,6 @@ def test_insert_rows_replace_all_index(self): for row in rows: self.cur.execute.assert_any_call(sql, row) - @pytest.mark.backend("postgres") def test_rowcount(self): hook = PostgresHook() input_data = ["foo", "bar", "baz"] diff --git a/tests/providers/postgres/operators/test_postgres.py b/tests/providers/postgres/operators/test_postgres.py index 394cfc2618fd6..4b615917ea28e 100644 --- a/tests/providers/postgres/operators/test_postgres.py +++ b/tests/providers/postgres/operators/test_postgres.py @@ -17,8 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest - import pytest from airflow.models.dag import DAG @@ -32,13 +30,13 @@ @pytest.mark.backend("postgres") -class TestPostgres(unittest.TestCase): - def setUp(self): +class TestPostgres: + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} dag = DAG(TEST_DAG_ID, default_args=args) self.dag = dag - def tearDown(self): + def teardown_method(self): tables_to_drop = ["test_postgres_to_postgres", "test_airflow"] from airflow.providers.postgres.hooks.postgres import PostgresHook diff --git a/tests/providers/presto/hooks/test_presto.py b/tests/providers/presto/hooks/test_presto.py index e1b1365f5360b..fa3062a8b2785 100644 --- a/tests/providers/presto/hooks/test_presto.py +++ b/tests/providers/presto/hooks/test_presto.py @@ -19,12 +19,10 @@ import json import re -import unittest from unittest import mock from unittest.mock import patch import pytest -from parameterized import parameterized from prestodb.transaction import IsolationLevel from airflow import AirflowException @@ -56,7 +54,7 @@ def test_generate_airflow_presto_client_info_header(): assert generate_presto_client_info() == expected -class TestPrestoHookConn(unittest.TestCase): +class TestPrestoHookConn: @patch("airflow.providers.presto.hooks.presto.prestodb.auth.BasicAuthentication") @patch("airflow.providers.presto.hooks.presto.prestodb.dbapi.connect") @patch("airflow.providers.presto.hooks.presto.PrestoHook.get_connection") @@ -190,14 +188,15 @@ def test_http_headers( mock_basic_auth.assert_called_once_with("login", "password") assert mock_connect.return_value == conn - @parameterized.expand( + @pytest.mark.parametrize( + "current_verify, expected_verify", [ ("False", False), ("false", False), ("true", True), ("true", True), ("/tmp/cert.crt", "/tmp/cert.crt"), - ] + ], ) def test_get_conn_verify(self, current_verify, expected_verify): patcher_connect = patch("airflow.providers.presto.hooks.presto.prestodb.dbapi.connect") @@ -215,10 
+214,8 @@ def test_get_conn_verify(self, current_verify, expected_verify): assert mock_connect.return_value == conn -class TestPrestoHook(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestPrestoHook: + def setup_method(self): self.cur = mock.MagicMock(rowcount=0) self.conn = mock.MagicMock() self.conn.cursor.return_value = self.cur diff --git a/tests/providers/presto/transfers/test_gcs_presto.py b/tests/providers/presto/transfers/test_gcs_presto.py index 247bb414bc149..fe0689099bbae 100644 --- a/tests/providers/presto/transfers/test_gcs_presto.py +++ b/tests/providers/presto/transfers/test_gcs_presto.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.providers.presto.transfers.gcs_to_presto import GCSToPrestoOperator @@ -33,7 +32,7 @@ SCHEMA_JSON = "path/to/file.json" -class TestGCSToPrestoOperator(unittest.TestCase): +class TestGCSToPrestoOperator: @mock.patch("airflow.providers.presto.transfers.gcs_to_presto.PrestoHook") @mock.patch("airflow.providers.presto.transfers.gcs_to_presto.GCSHook") @mock.patch("airflow.providers.presto.transfers.gcs_to_presto.NamedTemporaryFile") diff --git a/tests/providers/qubole/hooks/test_qubole.py b/tests/providers/qubole/hooks/test_qubole.py index aee6eb70690da..a037b8110a87e 100644 --- a/tests/providers/qubole/hooks/test_qubole.py +++ b/tests/providers/qubole/hooks/test_qubole.py @@ -17,7 +17,7 @@ # under the License. 
from __future__ import annotations -from unittest import TestCase, mock +from unittest import mock from qds_sdk.commands import PrestoCommand @@ -38,7 +38,7 @@ def get_result_mock(fp, inline, delim, fetch, arguments): fp.write(bytearray(RESULTS_WITH_NO_HEADER, "utf-8")) -class TestQuboleHook(TestCase): +class TestQuboleHook: def test_add_string_to_tags(self): tags = {"dag_id", "task_id"} add_tags(tags, "string") diff --git a/tests/providers/qubole/hooks/test_qubole_check.py b/tests/providers/qubole/hooks/test_qubole_check.py index a7bedc52cc66d..8c5b27da520ab 100644 --- a/tests/providers/qubole/hooks/test_qubole_check.py +++ b/tests/providers/qubole/hooks/test_qubole_check.py @@ -17,12 +17,10 @@ # under the License. from __future__ import annotations -import unittest - from airflow.providers.qubole.hooks.qubole_check import parse_first_row -class TestQuboleCheckHook(unittest.TestCase): +class TestQuboleCheckHook: def test_single_row_bool(self): query_result = ["true\ttrue"] record_list = parse_first_row(query_result) diff --git a/tests/providers/qubole/operators/test_qubole_check.py b/tests/providers/qubole/operators/test_qubole_check.py index 3e8ccfc968dff..85b5788fda200 100644 --- a/tests/providers/qubole/operators/test_qubole_check.py +++ b/tests/providers/qubole/operators/test_qubole_check.py @@ -17,7 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest from datetime import datetime from unittest import mock from unittest.mock import MagicMock @@ -95,8 +94,8 @@ def test_execute_fail(self, mock_handle_airflow_exception, operator_class, kwarg mock_handle_airflow_exception.assert_called_once() -class TestQuboleValueCheckOperator(unittest.TestCase): - def setUp(self): +class TestQuboleValueCheckOperator: + def setup_method(self): self.task_id = "test_task" self.conn_id = "default_conn" diff --git a/tests/providers/qubole/sensors/test_qubole.py b/tests/providers/qubole/sensors/test_qubole.py index b0212974613a0..02fd29ba7fc65 100644 --- a/tests/providers/qubole/sensors/test_qubole.py +++ b/tests/providers/qubole/sensors/test_qubole.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import datetime from unittest.mock import patch @@ -35,8 +34,8 @@ DEFAULT_DATE = datetime(2017, 1, 1) -class TestQuboleSensor(unittest.TestCase): - def setUp(self): +class TestQuboleSensor: + def setup_method(self): db.merge_conn(Connection(conn_id=DEFAULT_CONN, conn_type="HTTP")) @patch("airflow.providers.qubole.sensors.qubole.QuboleFileSensor.poke") diff --git a/tests/providers/redis/hooks/test_redis.py b/tests/providers/redis/hooks/test_redis.py index efc7521d7fd88..0eb4edd76fc0a 100644 --- a/tests/providers/redis/hooks/test_redis.py +++ b/tests/providers/redis/hooks/test_redis.py @@ -17,7 +17,6 @@ # under the License. 
 from __future__ import annotations
 
-import unittest
 from unittest import mock
 
 import pytest
@@ -26,7 +25,7 @@
 from airflow.providers.redis.hooks.redis import RedisHook
 
 
-class TestRedisHook(unittest.TestCase):
+class TestRedisHook:
     def test_get_conn(self):
         hook = RedisHook(redis_conn_id="redis_default")
         assert hook.redis is None
diff --git a/tests/providers/redis/operators/test_redis_publish.py b/tests/providers/redis/operators/test_redis_publish.py
index 4fca70b3a6cc6..cc1468db9ee39 100644
--- a/tests/providers/redis/operators/test_redis_publish.py
+++ b/tests/providers/redis/operators/test_redis_publish.py
@@ -17,7 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
 from unittest.mock import MagicMock
 
 import pytest
@@ -31,7 +30,7 @@
 
 
 @pytest.mark.integration("redis")
-class TestRedisPublishOperator(unittest.TestCase):
-    def setUp(self):
+class TestRedisPublishOperator:
+    def setup_method(self):
 
         args = {"owner": "airflow", "start_date": DEFAULT_DATE}
diff --git a/tests/providers/redis/sensors/test_redis_key.py b/tests/providers/redis/sensors/test_redis_key.py
index 6a338d429179f..f54e816c29a50 100644
--- a/tests/providers/redis/sensors/test_redis_key.py
+++ b/tests/providers/redis/sensors/test_redis_key.py
@@ -17,8 +17,6 @@
 # under the License.
 from __future__ import annotations
 
-import unittest
-
 import pytest
 
 from airflow.models.dag import DAG
@@ -30,8 +28,8 @@
 
 
 @pytest.mark.integration("redis")
-class TestRedisSensor(unittest.TestCase):
-    def setUp(self):
+class TestRedisSensor:
+    def setup_method(self):
         args = {"owner": "airflow", "start_date": DEFAULT_DATE}
 
         self.dag = DAG("test_dag_id", default_args=args)
diff --git a/tests/providers/redis/sensors/test_redis_pub_sub.py b/tests/providers/redis/sensors/test_redis_pub_sub.py
index 7ef56da33efb8..5ed0c40db71bd 100644
--- a/tests/providers/redis/sensors/test_redis_pub_sub.py
+++ b/tests/providers/redis/sensors/test_redis_pub_sub.py
@@ -17,7 +17,6 @@
 # under the License.
from __future__ import annotations -import unittest from time import sleep from unittest.mock import MagicMock, call, patch @@ -31,8 +30,8 @@ DEFAULT_DATE = timezone.datetime(2017, 1, 1) -class TestRedisPubSubSensor(unittest.TestCase): - def setUp(self): +class TestRedisPubSubSensor: + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} self.dag = DAG("test_dag_id", default_args=args) diff --git a/tests/providers/salesforce/operators/test_salesforce_apex_rest.py b/tests/providers/salesforce/operators/test_salesforce_apex_rest.py index 6bb84d3ae09ea..822d3667be924 100644 --- a/tests/providers/salesforce/operators/test_salesforce_apex_rest.py +++ b/tests/providers/salesforce/operators/test_salesforce_apex_rest.py @@ -16,13 +16,12 @@ # under the License. from __future__ import annotations -import unittest from unittest.mock import Mock, patch from airflow.providers.salesforce.operators.salesforce_apex_rest import SalesforceApexRestOperator -class TestSalesforceApexRestOperator(unittest.TestCase): +class TestSalesforceApexRestOperator: """ Test class for SalesforceApexRestOperator """ diff --git a/tests/providers/salesforce/sensors/__init__.py b/tests/providers/salesforce/sensors/__init__.py deleted file mode 100644 index 13a83393a9124..0000000000000 --- a/tests/providers/salesforce/sensors/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. diff --git a/tests/providers/samba/hooks/test_samba.py b/tests/providers/samba/hooks/test_samba.py index 944122eb2148c..06ec32cb48193 100644 --- a/tests/providers/samba/hooks/test_samba.py +++ b/tests/providers/samba/hooks/test_samba.py @@ -17,12 +17,10 @@ # under the License. from __future__ import annotations -import unittest from inspect import getfullargspec from unittest import mock import pytest -from parameterized import parameterized from airflow.exceptions import AirflowException from airflow.models import Connection @@ -38,7 +36,7 @@ ) -class TestSambaHook(unittest.TestCase): +class TestSambaHook: def test_get_conn_should_fail_if_conn_id_does_not_exist(self): with pytest.raises(AirflowException): SambaHook("conn") @@ -65,7 +63,8 @@ def test_context_manager(self, get_conn_mock, register_session): # Test that the connection was disconnected upon exit. assert len(mock_connection.disconnect.mock_calls) == 1 - @parameterized.expand( + @pytest.mark.parametrize( + "name", [ "getxattr", "link", @@ -94,7 +93,7 @@ def test_context_manager(self, get_conn_mock, register_session): ], ) @mock.patch("airflow.hooks.base.BaseHook.get_connection") - def test_method(self, name, get_conn_mock): + def test_method(self, get_conn_mock, name): get_conn_mock.return_value = CONNECTION hook = SambaHook("samba_default") connection_settings = { @@ -132,14 +131,15 @@ def test_method(self, name, get_conn_mock): # We expect keyword arguments to include the connection settings. 
assert dict(kwargs, **connection_settings) == p_kwargs - @parameterized.expand( + @pytest.mark.parametrize( + "path, full_path", [ ("/start/path/with/slash", "//ip/share/start/path/with/slash"), ("start/path/without/slash", "//ip/share/start/path/without/slash"), ], ) @mock.patch("airflow.hooks.base.BaseHook.get_connection") - def test__join_path(self, path, full_path, get_conn_mock): + def test__join_path(self, get_conn_mock, path, full_path): get_conn_mock.return_value = CONNECTION hook = SambaHook("samba_default") assert hook._join_path(path) == full_path diff --git a/tests/providers/segment/hooks/test_segment.py b/tests/providers/segment/hooks/test_segment.py index b766caf639adc..47cc5884f5820 100644 --- a/tests/providers/segment/hooks/test_segment.py +++ b/tests/providers/segment/hooks/test_segment.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -29,10 +28,8 @@ WRITE_KEY = "foo" -class TestSegmentHook(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestSegmentHook: + def setup_method(self): self.conn = conn = mock.MagicMock() conn.write_key = WRITE_KEY self.expected_write_key = WRITE_KEY diff --git a/tests/providers/segment/operators/test_segment_track_event.py b/tests/providers/segment/operators/test_segment_track_event.py index b5ae582026831..bf7e4bd0d0d26 100644 --- a/tests/providers/segment/operators/test_segment_track_event.py +++ b/tests/providers/segment/operators/test_segment_track_event.py @@ -17,7 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock import pytest @@ -30,9 +29,8 @@ WRITE_KEY = "foo" -class TestSegmentHook(unittest.TestCase): - def setUp(self): - super().setUp() +class TestSegmentHook: + def setup_method(self): self.conn = conn = mock.MagicMock() conn.write_key = WRITE_KEY @@ -59,7 +57,7 @@ def test_on_error(self): self.test_hook.on_error("error", ["items"]) -class TestSegmentTrackEventOperator(unittest.TestCase): +class TestSegmentTrackEventOperator: @mock.patch("airflow.providers.segment.operators.segment_track_event.SegmentHook") def test_execute(self, mock_hook): # Given diff --git a/tests/providers/sendgrid/utils/test_emailer.py b/tests/providers/sendgrid/utils/test_emailer.py index eace39557e47e..e3428e349efa7 100644 --- a/tests/providers/sendgrid/utils/test_emailer.py +++ b/tests/providers/sendgrid/utils/test_emailer.py @@ -20,15 +20,14 @@ import copy import os import tempfile -import unittest from unittest import mock from airflow.providers.sendgrid.utils.emailer import send_email -class TestSendEmailSendGrid(unittest.TestCase): +class TestSendEmailSendGrid: # Unit test for sendgrid.send_email() - def setUp(self): + def setup_method(self): self.recipients = ["foo@foo.com", "bar@bar.com"] self.subject = "sendgrid-send-email unit test" self.html_content = "Foo bar" diff --git a/tests/providers/sftp/hooks/test_sftp.py b/tests/providers/sftp/hooks/test_sftp.py index b471dc88b053f..d2855b601fda6 100644 --- a/tests/providers/sftp/hooks/test_sftp.py +++ b/tests/providers/sftp/hooks/test_sftp.py @@ -20,13 +20,11 @@ import json import os import shutil -import unittest from io import StringIO from unittest import mock import paramiko import pytest -from parameterized import parameterized from airflow.exceptions import AirflowException from airflow.models import Connection @@ -58,7 +56,7 @@ def generate_host_key(pkey: paramiko.PKey): TEST_KEY_FILE = "~/.ssh/id_rsa" -class TestSFTPHook(unittest.TestCase): +class 
TestSFTPHook: @provide_session def update_connection(self, login, session=None): connection = session.query(Connection).filter(Connection.conn_id == "sftp_default").first() @@ -72,7 +70,7 @@ def _create_additional_test_file(self, file_name): with open(os.path.join(TMP_PATH, file_name), "a") as file: file.write("Test file") - def setUp(self): + def setup_method(self): self.old_login = self.update_connection(SFTP_CONNECTION_USER) self.hook = SFTPHook() os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, SUB_DIR)) @@ -290,19 +288,21 @@ def test_key_file(self, get_connection): hook = SFTPHook() assert hook.key_file == TEST_KEY_FILE - @parameterized.expand( + @pytest.mark.parametrize( + "path, exists", [ (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS), True), (os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), True), (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS + "abc"), False), (os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS, "abc"), False), - ] + ], ) def test_path_exists(self, path, exists): result = self.hook.path_exists(path) assert result == exists - @parameterized.expand( + @pytest.mark.parametrize( + "path, prefix, delimiter, match", [ ("test/path/file.bin", None, None, True), ("test/path/file.bin", "test", None, True), @@ -315,7 +315,7 @@ def test_path_exists(self, path, exists): ("test/path/file.bin", "test//", None, False), ("test/path/file.bin", None, ".txt", False), ("test/path/file.bin", "diff", ".txt", False), - ] + ], ) def test_path_match(self, path, prefix, delimiter, match): result = self.hook._is_path_match(path=path, prefix=prefix, delimiter=delimiter) @@ -366,11 +366,11 @@ def test_deprecation_ftp_conn_id(self, mock_get_connection): connection = Connection(conn_id="ftp_default", login="login", host="host") mock_get_connection.return_value = connection # If `ftp_conn_id` is provided, it will be used but would show a deprecation warning. 
- with self.assertWarnsRegex(DeprecationWarning, "Parameter `ftp_conn_id` is deprecated"): + with pytest.warns(DeprecationWarning, match=r"Parameter `ftp_conn_id` is deprecated"): assert SFTPHook(ftp_conn_id="ftp_default").ssh_conn_id == "ftp_default" # If both are provided, ftp_conn_id will be used but would show a deprecation warning. - with self.assertWarnsRegex(DeprecationWarning, "Parameter `ftp_conn_id` is deprecated"): + with pytest.warns(DeprecationWarning, match=r"Parameter `ftp_conn_id` is deprecated"): assert ( SFTPHook(ftp_conn_id="ftp_default", ssh_conn_id="sftp_default").ssh_conn_id == "ftp_default" ) @@ -397,29 +397,29 @@ def test_valid_ssh_hook(self, mock_get_connection): def test_get_suffix_pattern_match(self): output = self.hook.get_file_by_pattern(TMP_PATH, "*.txt") - self.assertTrue(output, TMP_FILE_FOR_TESTS) + assert output == TMP_FILE_FOR_TESTS def test_get_prefix_pattern_match(self): output = self.hook.get_file_by_pattern(TMP_PATH, "test*") - self.assertTrue(output, TMP_FILE_FOR_TESTS) + assert output == TMP_FILE_FOR_TESTS def test_get_pattern_not_match(self): output = self.hook.get_file_by_pattern(TMP_PATH, "*.text") - self.assertFalse(output) + assert output == "" def test_get_several_pattern_match(self): output = self.hook.get_file_by_pattern(TMP_PATH, "*.log") - self.assertEqual(LOG_FILE_FOR_TESTS, output) + assert LOG_FILE_FOR_TESTS == output def test_get_first_pattern_match(self): output = self.hook.get_file_by_pattern(TMP_PATH, "test_*.txt") - self.assertEqual(TMP_FILE_FOR_TESTS, output) + assert TMP_FILE_FOR_TESTS == output def test_get_middle_pattern_match(self): output = self.hook.get_file_by_pattern(TMP_PATH, "*_file_*.txt") - self.assertEqual(ANOTHER_FILE_FOR_TESTS, output) + assert ANOTHER_FILE_FOR_TESTS == output - def tearDown(self): + def teardown_method(self): shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) for file_name in [TMP_FILE_FOR_TESTS, ANOTHER_FILE_FOR_TESTS, LOG_FILE_FOR_TESTS]: 
os.remove(os.path.join(TMP_PATH, file_name)) diff --git a/tests/providers/sftp/sensors/test_sftp.py b/tests/providers/sftp/sensors/test_sftp.py index 575f739477d3a..585b10cf892e3 100644 --- a/tests/providers/sftp/sensors/test_sftp.py +++ b/tests/providers/sftp/sensors/test_sftp.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from datetime import datetime from unittest.mock import patch @@ -28,7 +27,7 @@ from airflow.providers.sftp.sensors.sftp import SFTPSensor -class TestSFTPSensor(unittest.TestCase): +class TestSFTPSensor: @patch("airflow.providers.sftp.sensors.sftp.SFTPHook") def test_file_present(self, sftp_hook_mock): sftp_hook_mock.return_value.get_mod_time.return_value = "19700101000000" diff --git a/tests/providers/singularity/operators/test_singularity.py b/tests/providers/singularity/operators/test_singularity.py index 7ac389e0bee8a..34508e7b01ba1 100644 --- a/tests/providers/singularity/operators/test_singularity.py +++ b/tests/providers/singularity/operators/test_singularity.py @@ -17,18 +17,16 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock import pytest -from parameterized import parameterized from spython.instance import Instance from airflow.exceptions import AirflowException from airflow.providers.singularity.operators.singularity import SingularityOperator -class SingularityOperatorTestCase(unittest.TestCase): +class TestSingularityOperator: @mock.patch("airflow.providers.singularity.operators.singularity.Client") def test_execute(self, client_mock): instance = mock.Mock( @@ -55,12 +53,7 @@ def test_execute(self, client_mock): instance.start.assert_called_once_with() instance.stop.assert_called_once_with() - @parameterized.expand( - [ - ("",), - (None,), - ] - ) + @pytest.mark.parametrize("command", [pytest.param("", id="empty"), pytest.param(None, id="none")]) def test_command_is_required(self, command): task = SingularityOperator(task_id="task-id", image="docker://busybox", command=command) with pytest.raises(AirflowException, match="You must define a command."): @@ -95,7 +88,8 @@ def test_image_should_be_pulled_when_not_exists(self, client_mock): client_mock.pull.assert_called_once_with("docker://busybox", stream=True, pull_folder="/tmp") client_mock.execute.assert_called_once_with(mock.ANY, "echo hello", return_result=True) - @parameterized.expand( + @pytest.mark.parametrize( + "volumes, expected_options", [ ( None, @@ -117,10 +111,10 @@ def test_image_should_be_pulled_when_not_exists(self, client_mock): ["AAA", "BBB", "CCC"], ["--bind", "AAA", "--bind", "BBB", "--bind", "CCC"], ), - ] + ], ) @mock.patch("airflow.providers.singularity.operators.singularity.Client") - def test_bind_options(self, volumes, expected_options, client_mock): + def test_bind_options(self, client_mock, volumes, expected_options): instance = mock.Mock( autospec=Instance, **{ @@ -145,7 +139,8 @@ def test_bind_options(self, volumes, expected_options, client_mock): "docker://busybox", options=expected_options, args=None, start=False ) - 
@parameterized.expand( + @pytest.mark.parametrize( + "working_dir, expected_working_dir", [ ( None, @@ -159,10 +154,10 @@ def test_bind_options(self, volumes, expected_options, client_mock): "/work-dir/", ["--workdir", "/work-dir/"], ), - ] + ], ) @mock.patch("airflow.providers.singularity.operators.singularity.Client") - def test_working_dir(self, working_dir, expected_working_dir, client_mock): + def test_working_dir(self, client_mock, working_dir, expected_working_dir): instance = mock.Mock( autospec=Instance, **{ diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py index 7cd2d040cfbc7..dde2f9addeb60 100644 --- a/tests/providers/snowflake/operators/test_snowflake.py +++ b/tests/providers/snowflake/operators/test_snowflake.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -37,9 +36,8 @@ TEST_DAG_ID = "unit_test_dag" -class TestSnowflakeOperator(unittest.TestCase): - def setUp(self): - super().setUp() +class TestSnowflakeOperator: + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} dag = DAG(TEST_DAG_ID, default_args=args) self.dag = dag diff --git a/tests/providers/snowflake/transfers/test_copy_into_snowflake.py b/tests/providers/snowflake/transfers/test_copy_into_snowflake.py index 821c883913b92..f0e7a61fd5311 100644 --- a/tests/providers/snowflake/transfers/test_copy_into_snowflake.py +++ b/tests/providers/snowflake/transfers/test_copy_into_snowflake.py @@ -16,13 +16,12 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock from airflow.providers.snowflake.transfers.copy_into_snowflake import CopyFromExternalStageToSnowflakeOperator -class TestCopyFromExternalStageToSnowflake(unittest.TestCase): +class TestCopyFromExternalStageToSnowflake: @mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook") def test_execute(self, mock_hook): CopyFromExternalStageToSnowflakeOperator( diff --git a/tests/providers/sqlite/hooks/test_sqlite.py b/tests/providers/sqlite/hooks/test_sqlite.py index 0a1dde81e2fd1..af61442ece4fe 100644 --- a/tests/providers/sqlite/hooks/test_sqlite.py +++ b/tests/providers/sqlite/hooks/test_sqlite.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from unittest.mock import patch @@ -27,8 +26,8 @@ from airflow.providers.sqlite.hooks.sqlite import SqliteHook -class TestSqliteHookConn(unittest.TestCase): - def setUp(self): +class TestSqliteHookConn: + def setup_method(self): self.connection = Connection(host="host") @@ -52,8 +51,8 @@ def test_get_conn_non_default_id(self, mock_connect): self.db_hook.get_connection.assert_called_once_with("non_default") -class TestSqliteHook(unittest.TestCase): - def setUp(self): +class TestSqliteHook: + def setup_method(self): self.cur = mock.MagicMock(rowcount=0) self.conn = mock.MagicMock() diff --git a/tests/providers/sqlite/operators/test_sqlite.py b/tests/providers/sqlite/operators/test_sqlite.py index d0de4ca32ff88..1584bd909d746 100644 --- a/tests/providers/sqlite/operators/test_sqlite.py +++ b/tests/providers/sqlite/operators/test_sqlite.py @@ -17,8 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest - import pytest from airflow.models.dag import DAG @@ -32,13 +30,13 @@ @pytest.mark.backend("sqlite") -class TestSqliteOperator(unittest.TestCase): - def setUp(self): +class TestSqliteOperator: + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} dag = DAG(TEST_DAG_ID, default_args=args) self.dag = dag - def tearDown(self): + def teardown_method(self): tables_to_drop = ["test_airflow", "test_airflow2"] from airflow.providers.sqlite.hooks.sqlite import SqliteHook diff --git a/tests/providers/ssh/hooks/test_ssh.py b/tests/providers/ssh/hooks/test_ssh.py index 324db86b744d3..6448d88efe497 100644 --- a/tests/providers/ssh/hooks/test_ssh.py +++ b/tests/providers/ssh/hooks/test_ssh.py @@ -21,13 +21,11 @@ import random import string import textwrap -import unittest from io import StringIO from unittest import mock import paramiko import pytest -from parameterized import parameterized from airflow import settings from airflow.exceptions import AirflowException @@ -83,7 +81,7 @@ def generate_host_key(pkey: paramiko.PKey): TEST_CIPHERS = ["aes128-ctr", "aes192-ctr", "aes256-ctr"] -class TestSSHHook(unittest.TestCase): +class TestSSHHook: CONN_SSH_WITH_NO_EXTRA = "ssh_with_no_extra" CONN_SSH_WITH_PRIVATE_KEY_EXTRA = "ssh_with_private_key_extra" CONN_SSH_WITH_PRIVATE_KEY_ECDSA_EXTRA = "ssh_with_private_key_ecdsa_extra" @@ -112,7 +110,7 @@ class TestSSHHook(unittest.TestCase): ) @classmethod - def tearDownClass(cls) -> None: + def teardown_class(cls) -> None: with create_session() as session: conns_to_reset = [ cls.CONN_SSH_WITH_NO_EXTRA, @@ -139,7 +137,7 @@ def tearDownClass(cls) -> None: session.commit() @classmethod - def setUpClass(cls) -> None: + def setup_class(cls) -> None: db.merge_conn( Connection( conn_id=cls.CONN_SSH_WITH_NO_EXTRA, @@ -741,7 +739,8 @@ def test_ssh_connection_with_timeout_extra_and_conn_timeout_extra(self, ssh_mock look_for_keys=True, ) - @parameterized.expand( + 
@pytest.mark.parametrize( + "timeout, conn_timeout, timeoutextra, conn_timeoutextra, expected_value", [ (TEST_TIMEOUT, TEST_CONN_TIMEOUT, True, True, TEST_CONN_TIMEOUT), (TEST_TIMEOUT, TEST_CONN_TIMEOUT, True, False, TEST_CONN_TIMEOUT), @@ -759,11 +758,11 @@ def test_ssh_connection_with_timeout_extra_and_conn_timeout_extra(self, ssh_mock (None, None, True, False, TEST_TIMEOUT), (None, None, False, True, TEST_CONN_TIMEOUT), (None, None, False, False, 10), - ] + ], ) @mock.patch("airflow.providers.ssh.hooks.ssh.paramiko.SSHClient") def test_ssh_connection_with_all_timeout_param_and_extra_combinations( - self, timeout, conn_timeout, timeoutextra, conn_timeoutextra, expected_value, ssh_mock + self, ssh_mock, timeout, conn_timeout, timeoutextra, conn_timeoutextra, expected_value ): if timeoutextra and conn_timeoutextra: diff --git a/tests/providers/tableau/hooks/test_tableau.py b/tests/providers/tableau/hooks/test_tableau.py index fd69cb4c6beda..b463fce31017c 100644 --- a/tests/providers/tableau/hooks/test_tableau.py +++ b/tests/providers/tableau/hooks/test_tableau.py @@ -16,25 +16,21 @@ # under the License. 
from __future__ import annotations -import unittest from unittest.mock import MagicMock, patch -from parameterized import parameterized +import pytest from airflow import configuration, models from airflow.providers.tableau.hooks.tableau import TableauHook, TableauJobFinishCode from airflow.utils import db -class TestTableauHook(unittest.TestCase): +class TestTableauHook: """ Test class for TableauHook """ - def setUp(self): - """ - setup - """ + def setup_method(self): configuration.conf.load_test_config() db.merge_conn( @@ -222,15 +218,16 @@ def test_get_all(self, mock_pager, mock_server, mock_tableau_auth): mock_pager.assert_called_once_with(mock_server.return_value.jobs.get) - @parameterized.expand( + @pytest.mark.parametrize( + "finish_code, expected_status", [ - (0, TableauJobFinishCode.SUCCESS), - (1, TableauJobFinishCode.ERROR), - (2, TableauJobFinishCode.CANCELED), - ] + pytest.param(0, TableauJobFinishCode.SUCCESS, id="SUCCESS"), + pytest.param(1, TableauJobFinishCode.ERROR, id="ERROR"), + pytest.param(2, TableauJobFinishCode.CANCELED, id="CANCELED"), + ], ) @patch("airflow.providers.tableau.hooks.tableau.Server") - def test_get_job_status(self, finish_code, expected_status, mock_tableau_server): + def test_get_job_status(self, mock_tableau_server, finish_code, expected_status): """ Test get job status """ diff --git a/tests/providers/tableau/operators/test_tableau.py b/tests/providers/tableau/operators/test_tableau.py index a97abb81b1951..994aafaaf023d 100644 --- a/tests/providers/tableau/operators/test_tableau.py +++ b/tests/providers/tableau/operators/test_tableau.py @@ -16,7 +16,6 @@ # under the License. 
from __future__ import annotations -import unittest from unittest.mock import Mock, patch import pytest @@ -26,16 +25,12 @@ from airflow.providers.tableau.operators.tableau import TableauOperator -class TestTableauOperator(unittest.TestCase): +class TestTableauOperator: """ Test class for TableauOperator """ - def setUp(self): - """ - setup - """ - + def setup_method(self): self.mocked_workbooks = [] self.mock_datasources = [] diff --git a/tests/providers/tableau/sensors/test_tableau.py b/tests/providers/tableau/sensors/test_tableau.py index 22159991a2f3d..0c873e1f33eeb 100644 --- a/tests/providers/tableau/sensors/test_tableau.py +++ b/tests/providers/tableau/sensors/test_tableau.py @@ -16,11 +16,9 @@ # under the License. from __future__ import annotations -import unittest from unittest.mock import Mock, patch import pytest -from parameterized import parameterized from airflow.providers.tableau.sensors.tableau import ( TableauJobFailedException, @@ -29,12 +27,12 @@ ) -class TestTableauJobStatusSensor(unittest.TestCase): +class TestTableauJobStatusSensor: """ Test Class for JobStatusSensor """ - def setUp(self): + def setup_method(self): self.kwargs = {"job_id": "job_2", "site_id": "test_site", "task_id": "task", "dag": None} @patch("airflow.providers.tableau.sensors.tableau.TableauHook") @@ -51,9 +49,15 @@ def test_poke(self, mock_tableau_hook): assert job_finished mock_tableau_hook.get_job_status.assert_called_once_with(job_id=sensor.job_id) - @parameterized.expand([(TableauJobFinishCode.ERROR,), (TableauJobFinishCode.CANCELED,)]) + @pytest.mark.parametrize( + "finish_code", + [ + pytest.param(TableauJobFinishCode.ERROR, id="ERROR"), + pytest.param(TableauJobFinishCode.CANCELED, id="CANCELED"), + ], + ) @patch("airflow.providers.tableau.sensors.tableau.TableauHook") - def test_poke_failed(self, finish_code, mock_tableau_hook): + def test_poke_failed(self, mock_tableau_hook, finish_code): """ Test poke failed """ diff --git 
a/tests/providers/telegram/hooks/test_telegram.py b/tests/providers/telegram/hooks/test_telegram.py index a915f6d3285ee..de4722c78f43e 100644 --- a/tests/providers/telegram/hooks/test_telegram.py +++ b/tests/providers/telegram/hooks/test_telegram.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -31,8 +30,8 @@ TELEGRAM_TOKEN = "dummy token" -class TestTelegramHook(unittest.TestCase): - def setUp(self): +class TestTelegramHook: + def setup_method(self): db.merge_conn( Connection( conn_id="telegram-webhook-without-token", diff --git a/tests/providers/telegram/operators/test_telegram.py b/tests/providers/telegram/operators/test_telegram.py index 6eef156602fc7..f375ca3ad4557 100644 --- a/tests/providers/telegram/operators/test_telegram.py +++ b/tests/providers/telegram/operators/test_telegram.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock import pytest @@ -31,8 +30,8 @@ TELEGRAM_TOKEN = "xxx:xxx" -class TestTelegramOperator(unittest.TestCase): - def setUp(self): +class TestTelegramOperator: + def setup_method(self): db.merge_conn( Connection( conn_id="telegram_default", diff --git a/tests/providers/trino/hooks/test_trino.py b/tests/providers/trino/hooks/test_trino.py index 9f17ec69a37db..ed2341e5642f7 100644 --- a/tests/providers/trino/hooks/test_trino.py +++ b/tests/providers/trino/hooks/test_trino.py @@ -19,12 +19,10 @@ import json import re -import unittest from unittest import mock from unittest.mock import patch import pytest -from parameterized import parameterized from trino.transaction import IsolationLevel from airflow import AirflowException @@ -179,18 +177,19 @@ def test_get_conn_client_tags(self, mock_connect, mock_get_connection): self.assert_connection_called_with(mock_connect, client_tags=extras["client_tags"]) - @parameterized.expand( + @pytest.mark.parametrize( + "current_verify, 
expected_verify", [ ("False", False), ("false", False), ("true", True), ("true", True), ("/tmp/cert.crt", "/tmp/cert.crt"), - ] + ], ) @patch(HOOK_GET_CONNECTION) @patch(TRINO_DBAPI_CONNECT) - def test_get_conn_verify(self, current_verify, expected_verify, mock_connect, mock_get_connection): + def test_get_conn_verify(self, mock_connect, mock_get_connection, current_verify, expected_verify): extras = {"verify": current_verify} self.set_get_connection_return_value(mock_get_connection, extra=json.dumps(extras)) TrinoHook().get_conn() @@ -224,10 +223,8 @@ def assert_connection_called_with( ) -class TestTrinoHook(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestTrinoHook: + def setup_method(self): self.cur = mock.MagicMock(rowcount=0) self.conn = mock.MagicMock() self.conn.cursor.return_value = self.cur @@ -312,8 +309,8 @@ def test_connection_failure(self, mock_conn): assert msg == "Test" -class TestTrinoHookIntegration(unittest.TestCase): - @pytest.mark.integration("trino") +@pytest.mark.integration("trino") +class TestTrinoHookIntegration: @mock.patch.dict("os.environ", AIRFLOW_CONN_TRINO_DEFAULT="trino://airflow@trino:8080/") def test_should_record_records(self): hook = TrinoHook() @@ -321,7 +318,6 @@ def test_should_record_records(self): records = hook.get_records(sql) assert [["Customer#000000001"], ["Customer#000000002"], ["Customer#000000003"]] == records - @pytest.mark.integration("trino") @pytest.mark.integration("kerberos") def test_should_record_records_with_kerberos_auth(self): conn_url = ( diff --git a/tests/providers/trino/operators/test_trino.py b/tests/providers/trino/operators/test_trino.py index 3caa2dad8ee6c..8fef756ecf8f2 100644 --- a/tests/providers/trino/operators/test_trino.py +++ b/tests/providers/trino/operators/test_trino.py @@ -17,26 +17,28 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock +import pytest + from airflow.providers.trino.operators.trino import TrinoOperator TRINO_CONN_ID = "test_trino" TASK_ID = "test_trino_task" -class TestTrinoOperator(unittest.TestCase): +class TestTrinoOperator: @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook") def test_execute(self, mock_get_db_hook): """Asserts that the run method is called when a TrinoOperator task is executed""" - op = TrinoOperator( - task_id=TASK_ID, - sql="SELECT 1;", - trino_conn_id=TRINO_CONN_ID, - handler=list, - ) + with pytest.warns(DeprecationWarning, match="This class is deprecated.*"): + op = TrinoOperator( + task_id=TASK_ID, + sql="SELECT 1;", + trino_conn_id=TRINO_CONN_ID, + handler=list, + ) op.execute(None) mock_get_db_hook.return_value.run.assert_called_once_with( diff --git a/tests/providers/trino/transfers/test_gcs_trino.py b/tests/providers/trino/transfers/test_gcs_trino.py index 624d7167b281f..ce5d2f8d10aea 100644 --- a/tests/providers/trino/transfers/test_gcs_trino.py +++ b/tests/providers/trino/transfers/test_gcs_trino.py @@ -17,7 +17,6 @@ # under the License. from __future__ import annotations -import unittest from unittest import mock from airflow.providers.trino.transfers.gcs_to_trino import GCSToTrinoOperator @@ -33,7 +32,7 @@ SCHEMA_JSON = "path/to/file.json" -class TestGCSToTrinoOperator(unittest.TestCase): +class TestGCSToTrinoOperator: @mock.patch("airflow.providers.trino.transfers.gcs_to_trino.TrinoHook") @mock.patch("airflow.providers.trino.transfers.gcs_to_trino.GCSHook") @mock.patch("airflow.providers.trino.transfers.gcs_to_trino.NamedTemporaryFile") diff --git a/tests/providers/vertica/hooks/test_vertica.py b/tests/providers/vertica/hooks/test_vertica.py index 93d20cd2384b9..e78c2a0c5c813 100644 --- a/tests/providers/vertica/hooks/test_vertica.py +++ b/tests/providers/vertica/hooks/test_vertica.py @@ -17,7 +17,6 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock from unittest.mock import patch @@ -25,10 +24,8 @@ from airflow.providers.vertica.hooks.vertica import VerticaHook -class TestVerticaHookConn(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestVerticaHookConn: + def setup_method(self): self.connection = Connection( login="login", password="password", @@ -51,10 +48,8 @@ def test_get_conn(self, mock_connect): ) -class TestVerticaHook(unittest.TestCase): - def setUp(self): - super().setUp() - +class TestVerticaHook: + def setup_method(self): self.cur = mock.MagicMock(rowcount=0) self.conn = mock.MagicMock() self.conn.cursor.return_value = self.cur diff --git a/tests/providers/vertica/operators/test_vertica.py b/tests/providers/vertica/operators/test_vertica.py index 836f324cb8289..3fd6aa3e52d84 100644 --- a/tests/providers/vertica/operators/test_vertica.py +++ b/tests/providers/vertica/operators/test_vertica.py @@ -17,14 +17,13 @@ # under the License. 
from __future__ import annotations -import unittest from unittest import mock from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.vertica.operators.vertica import VerticaOperator -class TestVerticaOperator(unittest.TestCase): +class TestVerticaOperator: @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook") def test_execute(self, mock_get_db_hook): sql = "select a, b, c" diff --git a/tests/providers/yandex/hooks/test_yandexcloud_dataproc.py b/tests/providers/yandex/hooks/test_yandexcloud_dataproc.py index b93a499a13fa6..cf436a9b8c02d 100644 --- a/tests/providers/yandex/hooks/test_yandexcloud_dataproc.py +++ b/tests/providers/yandex/hooks/test_yandexcloud_dataproc.py @@ -17,9 +17,10 @@ from __future__ import annotations import json -import unittest from unittest.mock import patch +import pytest + from airflow.models import Connection try: @@ -64,14 +65,14 @@ HAS_CREDENTIALS = OAUTH_TOKEN != "my_oauth_token" -@unittest.skipIf(yandexcloud is None, "Skipping Yandex.Cloud hook test: no yandexcloud module") -class TestYandexCloudDataprocHook(unittest.TestCase): +@pytest.mark.skipif(yandexcloud is None, reason="Skipping Yandex.Cloud hook test: no yandexcloud module") +class TestYandexCloudDataprocHook: def _init_hook(self): with patch("airflow.hooks.base.BaseHook.get_connection") as get_connection_mock: get_connection_mock.return_value = self.connection self.hook = DataprocHook() - def setUp(self): + def setup_method(self): self.connection = Connection(extra=json.dumps({"oauth": OAUTH_TOKEN})) self._init_hook() diff --git a/tests/providers/yandex/operators/test_yandexcloud_dataproc.py b/tests/providers/yandex/operators/test_yandexcloud_dataproc.py index 3b3aa69a5d059..8ffe2050a037c 100644 --- a/tests/providers/yandex/operators/test_yandexcloud_dataproc.py +++ b/tests/providers/yandex/operators/test_yandexcloud_dataproc.py @@ -17,7 +17,6 @@ from __future__ import annotations import datetime 
-from unittest import TestCase from unittest.mock import MagicMock, call, patch from airflow.models.dag import DAG @@ -64,8 +63,8 @@ LOG_GROUP_ID = "my_log_group_id" -class DataprocClusterCreateOperatorTest(TestCase): - def setUp(self): +class TestDataprocClusterCreateOperator: + def setup_method(self): dag_id = "test_dag" self.dag = DAG( dag_id, From 054f104b3c1b6f7c8fb726f1e97afe52369612b1 Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Tue, 15 Nov 2022 22:52:11 +0400 Subject: [PATCH 2/7] Fix SFTPHook tests --- tests/providers/sftp/hooks/test_sftp.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/providers/sftp/hooks/test_sftp.py b/tests/providers/sftp/hooks/test_sftp.py index d2855b601fda6..4d7f7bb5623d0 100644 --- a/tests/providers/sftp/hooks/test_sftp.py +++ b/tests/providers/sftp/hooks/test_sftp.py @@ -178,11 +178,11 @@ def test_no_host_key_check_default(self, get_connection): @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_no_host_key_check_enabled(self, get_connection): - connection = Connection(login="login", host="host", extra='{"no_host_key_check": false}') + connection = Connection(login="login", host="host", extra='{"no_host_key_check": true}') get_connection.return_value = connection hook = SFTPHook() - assert hook.no_host_key_check is False + assert hook.no_host_key_check is True @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") def test_no_host_key_check_disabled(self, get_connection): @@ -385,23 +385,27 @@ def test_invalid_ssh_hook(self, mock_get_connection): with pytest.raises(AirflowException, match="ssh_hook must be an instance of SSHHook"): connection = Connection(conn_id="sftp_default", login="root", host="localhost") mock_get_connection.return_value = connection - SFTPHook(ssh_hook="invalid_hook") # type: ignore + with pytest.warns(DeprecationWarning, match=r"Parameter `ssh_hook` is deprecated.*"): + SFTPHook(ssh_hook="invalid_hook") 
@mock.patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_connection") def test_valid_ssh_hook(self, mock_get_connection): connection = Connection(conn_id="sftp_test", login="root", host="localhost") mock_get_connection.return_value = connection - hook = SFTPHook(ssh_hook=SSHHook(ssh_conn_id="sftp_test")) + with pytest.warns(DeprecationWarning, match=r"Parameter `ssh_hook` is deprecated.*"): + hook = SFTPHook(ssh_hook=SSHHook(ssh_conn_id="sftp_test")) assert hook.ssh_conn_id == "sftp_test" assert isinstance(hook.get_conn(), paramiko.SFTPClient) def test_get_suffix_pattern_match(self): output = self.hook.get_file_by_pattern(TMP_PATH, "*.txt") - assert output == TMP_FILE_FOR_TESTS + # In CI files might have different name, so we check that file found rather than actual name + assert output, TMP_FILE_FOR_TESTS def test_get_prefix_pattern_match(self): output = self.hook.get_file_by_pattern(TMP_PATH, "test*") - assert output == TMP_FILE_FOR_TESTS + # In CI files might have different name, so we check that file found rather than actual name + assert output, TMP_FILE_FOR_TESTS def test_get_pattern_not_match(self): output = self.hook.get_file_by_pattern(TMP_PATH, "*.text") @@ -409,15 +413,15 @@ def test_get_pattern_not_match(self): def test_get_several_pattern_match(self): output = self.hook.get_file_by_pattern(TMP_PATH, "*.log") - assert LOG_FILE_FOR_TESTS == output + assert output == LOG_FILE_FOR_TESTS def test_get_first_pattern_match(self): output = self.hook.get_file_by_pattern(TMP_PATH, "test_*.txt") - assert TMP_FILE_FOR_TESTS == output + assert output == TMP_FILE_FOR_TESTS def test_get_middle_pattern_match(self): output = self.hook.get_file_by_pattern(TMP_PATH, "*_file_*.txt") - assert ANOTHER_FILE_FOR_TESTS == output + assert output == ANOTHER_FILE_FOR_TESTS def teardown_method(self): shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS)) From 88ae05f2f64483b45be64379ef57dbfb259d03f2 Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Tue, 15 Nov 2022 23:19:20 
+0400 Subject: [PATCH 3/7] Fix MySqlHook Tests --- tests/providers/mysql/hooks/test_mysql.py | 5 ++--- tests/test_utils/asserts.py | 20 ++++++++++++++++++-- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/tests/providers/mysql/hooks/test_mysql.py b/tests/providers/mysql/hooks/test_mysql.py index 3a2724cd8f406..3fbaddcff46b9 100644 --- a/tests/providers/mysql/hooks/test_mysql.py +++ b/tests/providers/mysql/hooks/test_mysql.py @@ -30,6 +30,7 @@ from airflow.models.dag import DAG from airflow.providers.mysql.hooks.mysql import MySqlHook from airflow.utils import timezone +from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces SSL_DICT = {"cert": "/tmp/client-cert.pem", "ca": "/tmp/server-ca.pem", "key": "/tmp/client-key.pem"} @@ -436,11 +437,9 @@ def test_mysql_hook_test_bulk_dump_mock(self, mock_get_conn, client): tmp_file = "/path/to/output/file" hook.bulk_dump(table, tmp_file) - from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces - assert mock_execute.call_count == 1 query = f""" SELECT * INTO OUTFILE '{tmp_file}' FROM {table} """ - assert_equal_ignore_multiple_spaces(self, mock_execute.call_args[0][0], query) + assert_equal_ignore_multiple_spaces(None, mock_execute.call_args[0][0], query) diff --git a/tests/test_utils/asserts.py b/tests/test_utils/asserts.py index e16123453e21a..9f3eab630b680 100644 --- a/tests/test_utils/asserts.py +++ b/tests/test_utils/asserts.py @@ -19,22 +19,38 @@ import logging import re import traceback +import warnings from collections import Counter from contextlib import contextmanager +from typing import TYPE_CHECKING from sqlalchemy import event # Long import to not create a copy of the reference, but to refer to one place. 
 import airflow.settings
 
+if TYPE_CHECKING:
+    from unittest import TestCase
+
 log = logging.getLogger(__name__)
 
 
-def assert_equal_ignore_multiple_spaces(case, first, second, msg=None):
+def assert_equal_ignore_multiple_spaces(case: TestCase | None, first, second, msg=None):
     def _trim(s):
         return re.sub(r"\s+", " ", s.strip())
 
-    return case.assertEqual(_trim(first), _trim(second), msg)
+    if case:
+        warnings.warn(
+            "Passing `case` has no effect and will be removed in the future. "
+            "Please set it to `None` to avoid this warning.",
+            FutureWarning,
+            stacklevel=3,
+        )
+
+    if not msg:
+        assert _trim(first) == _trim(second)
+    else:
+        assert _trim(first) == _trim(second), msg
 
 
 class CountQueries:

From d294af6bb0a8c09ee94d27f4ff2aa2bdefe30b41 Mon Sep 17 00:00:00 2001
From: Andrey Anshin
Date: Wed, 16 Nov 2022 21:11:19 +0400
Subject: [PATCH 4/7] Fix TestRedisPublishOperator

---
 tests/providers/redis/operators/test_redis_publish.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/providers/redis/operators/test_redis_publish.py b/tests/providers/redis/operators/test_redis_publish.py
index cc1468db9ee39..c5ea8a65bd626 100644
--- a/tests/providers/redis/operators/test_redis_publish.py
+++ b/tests/providers/redis/operators/test_redis_publish.py
@@ -31,7 +31,7 @@
 @pytest.mark.integration("redis")
 class TestRedisPublishOperator:
-    def setUp(self):
+    def setup_method(self):
         args = {"owner": "airflow", "start_date": DEFAULT_DATE}
 
         self.dag = DAG("test_redis_dag_id", default_args=args)

From 802046484666e1773783af920241a85e3fc97f53 Mon Sep 17 00:00:00 2001
From: Andrey Anshin
Date: Wed, 16 Nov 2022 21:32:02 +0400
Subject: [PATCH 5/7] get rid of TestHiveEnvironment

---
 .../providers/common/sql/sensors/test_sql.py  | 11 ++++------7
 .../mysql/transfers/test_presto_to_mysql.py   | 19 ++++++++++++-------
 .../mysql/transfers/test_trino_to_mysql.py    | 19 ++++++++++++-------
 3 files changed, 28 insertions(+), 21 deletions(-)

diff --git 
a/tests/providers/common/sql/sensors/test_sql.py b/tests/providers/common/sql/sensors/test_sql.py index 77665f1c84a12..912478e16d462 100644 --- a/tests/providers/common/sql/sensors/test_sql.py +++ b/tests/providers/common/sql/sensors/test_sql.py @@ -18,7 +18,6 @@ from __future__ import annotations import os -import unittest from unittest import mock import pytest @@ -28,15 +27,13 @@ from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.common.sql.sensors.sql import SqlSensor from airflow.utils.timezone import datetime -from tests.providers.apache.hive import TestHiveEnvironment DEFAULT_DATE = datetime(2015, 1, 1) TEST_DAG_ID = "unit_test_sql_dag" -class TestSqlSensor(TestHiveEnvironment): - def setUp(self): - super().setUp() +class TestSqlSensor: + def setup_method(self): args = {"owner": "airflow", "start_date": DEFAULT_DATE} self.dag = DAG(TEST_DAG_ID, default_args=args) @@ -245,8 +242,8 @@ def test_sql_sensor_postgres_poke_invalid_success(self, mock_hook): op.poke(None) assert "self.success is present, but not callable -> [1]" == str(ctx.value) - @unittest.skipIf( - "AIRFLOW_RUNALL_TESTS" not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set" + @pytest.mark.skipif( + "AIRFLOW_RUNALL_TESTS" not in os.environ, reason="Skipped because AIRFLOW_RUNALL_TESTS is not set" ) def test_sql_sensor_presto(self): op = SqlSensor( diff --git a/tests/providers/mysql/transfers/test_presto_to_mysql.py b/tests/providers/mysql/transfers/test_presto_to_mysql.py index f5d572b7be792..7249cfe4ecad7 100644 --- a/tests/providers/mysql/transfers/test_presto_to_mysql.py +++ b/tests/providers/mysql/transfers/test_presto_to_mysql.py @@ -18,21 +18,26 @@ from __future__ import annotations import os -import unittest +from datetime import datetime from unittest.mock import patch +import pytest + +from airflow.models.dag import DAG from airflow.providers.mysql.transfers.presto_to_mysql import PrestoToMySqlOperator -from tests.providers.apache.hive import 
DEFAULT_DATE, TestHiveEnvironment + +DEFAULT_DATE = datetime(2022, 1, 1) -class TestPrestoToMySqlTransfer(TestHiveEnvironment): - def setUp(self): +class TestPrestoToMySqlTransfer: + def setup_method(self): self.kwargs = dict( sql="sql", mysql_table="mysql_table", task_id="test_presto_to_mysql_transfer", ) - super().setUp() + args = {"owner": "airflow", "start_date": DEFAULT_DATE} + self.dag = DAG("test_presto_to_mysql_transfer", default_args=args) @patch("airflow.providers.mysql.transfers.presto_to_mysql.MySqlHook") @patch("airflow.providers.mysql.transfers.presto_to_mysql.PrestoHook") @@ -57,8 +62,8 @@ def test_execute_with_mysql_preoperator(self, mock_presto_hook, mock_mysql_hook) table=self.kwargs["mysql_table"], rows=mock_presto_hook.return_value.get_records.return_value ) - @unittest.skipIf( - "AIRFLOW_RUNALL_TESTS" not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set" + @pytest.mark.skipif( + "AIRFLOW_RUNALL_TESTS" not in os.environ, reason="Skipped because AIRFLOW_RUNALL_TESTS is not set" ) def test_presto_to_mysql(self): op = PrestoToMySqlOperator( diff --git a/tests/providers/mysql/transfers/test_trino_to_mysql.py b/tests/providers/mysql/transfers/test_trino_to_mysql.py index 6a7f01ad2bf16..390c84729b2b8 100644 --- a/tests/providers/mysql/transfers/test_trino_to_mysql.py +++ b/tests/providers/mysql/transfers/test_trino_to_mysql.py @@ -18,21 +18,26 @@ from __future__ import annotations import os -import unittest +from datetime import datetime from unittest.mock import patch +import pytest + +from airflow.models.dag import DAG from airflow.providers.mysql.transfers.trino_to_mysql import TrinoToMySqlOperator -from tests.providers.apache.hive import DEFAULT_DATE, TestHiveEnvironment + +DEFAULT_DATE = datetime(2022, 1, 1) -class TestTrinoToMySqlTransfer(TestHiveEnvironment): - def setUp(self): +class TestTrinoToMySqlTransfer: + def setup_method(self): self.kwargs = dict( sql="sql", mysql_table="mysql_table", 
task_id="test_trino_to_mysql_transfer", ) - super().setUp() + args = {"owner": "airflow", "start_date": DEFAULT_DATE} + self.dag = DAG("test_trino_to_mysql_transfer", default_args=args) @patch("airflow.providers.mysql.transfers.trino_to_mysql.MySqlHook") @patch("airflow.providers.mysql.transfers.trino_to_mysql.TrinoHook") @@ -57,8 +62,8 @@ def test_execute_with_mysql_preoperator(self, mock_trino_hook, mock_mysql_hook): table=self.kwargs["mysql_table"], rows=mock_trino_hook.return_value.get_records.return_value ) - @unittest.skipIf( - "AIRFLOW_RUNALL_TESTS" not in os.environ, "Skipped because AIRFLOW_RUNALL_TESTS is not set" + @pytest.mark.skipif( + "AIRFLOW_RUNALL_TESTS" not in os.environ, reason="Skipped because AIRFLOW_RUNALL_TESTS is not set" ) def test_trino_to_mysql(self): op = TrinoToMySqlOperator( From 8b989a65c151cf4fc29b945a2a4cf2cc43cc0208 Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Wed, 16 Nov 2022 21:37:30 +0400 Subject: [PATCH 6/7] Remove import unittest --- .../snowflake/hooks/test_snowflake.py | 63 +++++++------------ 1 file changed, 22 insertions(+), 41 deletions(-) diff --git a/tests/providers/snowflake/hooks/test_snowflake.py b/tests/providers/snowflake/hooks/test_snowflake.py index 87c661043a643..f175ec5f087ac 100644 --- a/tests/providers/snowflake/hooks/test_snowflake.py +++ b/tests/providers/snowflake/hooks/test_snowflake.py @@ -18,7 +18,6 @@ from __future__ import annotations import json -import unittest from copy import deepcopy from pathlib import Path from typing import Any @@ -268,9 +267,7 @@ class TestPytestSnowflakeHook: def test_hook_should_support_prepare_basic_conn_params_and_uri( self, connection_kwargs, expected_uri, expected_conn_params ): - with unittest.mock.patch.dict( - "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri() - ): + with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): assert SnowflakeHook(snowflake_conn_id="test_conn").get_uri() == 
expected_uri assert SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() == expected_conn_params @@ -289,9 +286,7 @@ def test_get_conn_params_should_support_private_auth_in_connection( "private_key_content": str(encrypted_temporary_private_key.read_text()), }, } - with unittest.mock.patch.dict( - "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri() - ): + with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() @pytest.mark.parametrize("include_params", [True, False]) @@ -308,7 +303,7 @@ def test_hook_param_beats_extra(self, include_params): session_parameters="session_parameters", ) extras = {k: f"{v}_extra" for k, v in hook_params.items()} - with unittest.mock.patch.dict( + with mock.patch.dict( "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(conn_type="any", extra=json.dumps(extras)).get_uri(), ): @@ -336,7 +331,7 @@ def test_extra_short_beats_long(self, include_unprefixed): role="role", ) extras_prefixed = {f"extra__snowflake__{k}": f"{v}_prefixed" for k, v in extras.items()} - with unittest.mock.patch.dict( + with mock.patch.dict( "os.environ", AIRFLOW_CONN_TEST_CONN=Connection( conn_type="any", @@ -369,9 +364,7 @@ def test_get_conn_params_should_support_private_auth_with_encrypted_key( "private_key_file": str(encrypted_temporary_private_key), }, } - with unittest.mock.patch.dict( - "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri() - ): + with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() def test_get_conn_params_should_support_private_auth_with_unencrypted_key( @@ -389,23 +382,19 @@ def test_get_conn_params_should_support_private_auth_with_unencrypted_key( "private_key_file": str(non_encrypted_temporary_private_key), }, } - 
with unittest.mock.patch.dict( - "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri() - ): + with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() connection_kwargs["password"] = "" - with unittest.mock.patch.dict( - "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri() - ): + with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()): assert "private_key" in SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() connection_kwargs["password"] = _PASSWORD - with unittest.mock.patch.dict( + with mock.patch.dict( "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri() ), pytest.raises(TypeError, match="Password was given but private key is not encrypted."): SnowflakeHook(snowflake_conn_id="test_conn")._get_conn_params() def test_should_add_partner_info(self): - with unittest.mock.patch.dict( + with mock.patch.dict( "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri(), AIRFLOW_SNOWFLAKE_PARTNER="PARTNER_NAME", @@ -416,20 +405,18 @@ def test_should_add_partner_info(self): ) def test_get_conn_should_call_connect(self): - with unittest.mock.patch.dict( + with mock.patch.dict( "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri() - ), unittest.mock.patch("airflow.providers.snowflake.hooks.snowflake.connector") as mock_connector: + ), mock.patch("airflow.providers.snowflake.hooks.snowflake.connector") as mock_connector: hook = SnowflakeHook(snowflake_conn_id="test_conn") conn = hook.get_conn() mock_connector.connect.assert_called_once_with(**hook._get_conn_params()) assert mock_connector.connect.return_value == conn def test_get_sqlalchemy_engine_should_support_pass_auth(self): - with unittest.mock.patch.dict( + with mock.patch.dict( "os.environ", 
AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri() - ), unittest.mock.patch( - "airflow.providers.snowflake.hooks.snowflake.create_engine" - ) as mock_create_engine: + ), mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as mock_create_engine: hook = SnowflakeHook(snowflake_conn_id="test_conn") conn = hook.get_sqlalchemy_engine() mock_create_engine.assert_called_once_with( @@ -442,11 +429,9 @@ def test_get_sqlalchemy_engine_should_support_insecure_mode(self): connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS) connection_kwargs["extra"]["extra__snowflake__insecure_mode"] = "True" - with unittest.mock.patch.dict( + with mock.patch.dict( "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri() - ), unittest.mock.patch( - "airflow.providers.snowflake.hooks.snowflake.create_engine" - ) as mock_create_engine: + ), mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as mock_create_engine: hook = SnowflakeHook(snowflake_conn_id="test_conn") conn = hook.get_sqlalchemy_engine() mock_create_engine.assert_called_once_with( @@ -460,11 +445,9 @@ def test_get_sqlalchemy_engine_should_support_session_parameters(self): connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS) connection_kwargs["extra"]["session_parameters"] = {"TEST_PARAM": "AA", "TEST_PARAM_B": 123} - with unittest.mock.patch.dict( + with mock.patch.dict( "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri() - ), unittest.mock.patch( - "airflow.providers.snowflake.hooks.snowflake.create_engine" - ) as mock_create_engine: + ), mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as mock_create_engine: hook = SnowflakeHook(snowflake_conn_id="test_conn") conn = hook.get_sqlalchemy_engine() mock_create_engine.assert_called_once_with( @@ -479,18 +462,16 @@ def test_get_sqlalchemy_engine_should_support_private_key_auth(self, non_encrypt connection_kwargs["password"] = "" 
connection_kwargs["extra"]["private_key_file"] = str(non_encrypted_temporary_private_key) - with unittest.mock.patch.dict( + with mock.patch.dict( "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri() - ), unittest.mock.patch( - "airflow.providers.snowflake.hooks.snowflake.create_engine" - ) as mock_create_engine: + ), mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as mock_create_engine: hook = SnowflakeHook(snowflake_conn_id="test_conn") conn = hook.get_sqlalchemy_engine() assert "private_key" in mock_create_engine.call_args[1]["connect_args"] assert mock_create_engine.return_value == conn def test_hook_parameters_should_take_precedence(self): - with unittest.mock.patch.dict( + with mock.patch.dict( "os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri() ): hook = SnowflakeHook( @@ -555,7 +536,7 @@ def test_run_storing_query_ids_extra(self, mock_conn, sql, expected_sql, expecte @mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_first") def test_connection_success(self, mock_get_first): - with unittest.mock.patch.dict( + with mock.patch.dict( "os.environ", AIRFLOW_CONN_SNOWFLAKE_DEFAULT=Connection(**BASE_CONNECTION_KWARGS).get_uri() ): hook = SnowflakeHook() @@ -570,7 +551,7 @@ def test_connection_success(self, mock_get_first): side_effect=Exception("Connection Errors"), ) def test_connection_failure(self, mock_get_first): - with unittest.mock.patch.dict( + with mock.patch.dict( "os.environ", AIRFLOW_CONN_SNOWFLAKE_DEFAULT=Connection(**BASE_CONNECTION_KWARGS).get_uri() ): hook = SnowflakeHook() From e56fa90220c955234c05620657d4e3241e6901f2 Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Thu, 17 Nov 2022 12:52:55 +0400 Subject: [PATCH 7/7] change `setup` to `setup_method` --- tests/providers/elasticsearch/hooks/test_elasticsearch.py | 2 +- tests/providers/elasticsearch/log/test_es_task_handler.py | 4 ++-- tests/providers/grpc/hooks/test_grpc.py | 2 +- 
tests/providers/jdbc/hooks/test_jdbc.py | 2 +- tests/providers/qubole/operators/test_qubole_check.py | 2 +- tests/providers/salesforce/hooks/test_salesforce.py | 3 ++- tests/providers/slack/operators/test_slack.py | 6 ++---- 7 files changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/providers/elasticsearch/hooks/test_elasticsearch.py b/tests/providers/elasticsearch/hooks/test_elasticsearch.py index c80ccb890b129..a3bf9bcde26a5 100644 --- a/tests/providers/elasticsearch/hooks/test_elasticsearch.py +++ b/tests/providers/elasticsearch/hooks/test_elasticsearch.py @@ -127,7 +127,7 @@ def search(self, **kwargs): class TestElasticsearchPythonHook: - def setup(self): + def setup_method(self): self.elasticsearch_hook = ElasticsearchPythonHook(hosts=["http://localhost:9200"]) def test_client(self): diff --git a/tests/providers/elasticsearch/log/test_es_task_handler.py b/tests/providers/elasticsearch/log/test_es_task_handler.py index f8a54ab4c99de..bdf274732e555 100644 --- a/tests/providers/elasticsearch/log/test_es_task_handler.py +++ b/tests/providers/elasticsearch/log/test_es_task_handler.py @@ -75,7 +75,7 @@ def ti(self, create_task_instance, create_log_template): clear_db_dags() @elasticmock - def setup(self): + def setup_method(self, method): self.local_log_location = "local/log/location" self.end_of_log_mark = "end_of_log\n" self.write_stdout = False @@ -100,7 +100,7 @@ def setup(self): self.body = {"message": self.test_message, "log_id": self.LOG_ID, "offset": 1} self.es.index(index=self.index_name, doc_type=self.doc_type, body=self.body, id=1) - def teardown(self): + def teardown_method(self): shutil.rmtree(self.local_log_location.split(os.path.sep)[0], ignore_errors=True) def test_client(self): diff --git a/tests/providers/grpc/hooks/test_grpc.py b/tests/providers/grpc/hooks/test_grpc.py index 0813fd4167660..46ae9896e73b1 100644 --- a/tests/providers/grpc/hooks/test_grpc.py +++ b/tests/providers/grpc/hooks/test_grpc.py @@ -61,7 +61,7 @@ def 
stream_call(self, data): class TestGrpcHook: - def setup(self): + def setup_method(self): self.channel_mock = mock.patch("grpc.Channel").start() def custom_conn_func(self, _): diff --git a/tests/providers/jdbc/hooks/test_jdbc.py b/tests/providers/jdbc/hooks/test_jdbc.py index 6a4739e54b327..50913b023769b 100644 --- a/tests/providers/jdbc/hooks/test_jdbc.py +++ b/tests/providers/jdbc/hooks/test_jdbc.py @@ -32,7 +32,7 @@ class TestJdbcHook: - def setup(self): + def setup_method(self): db.merge_conn( Connection( conn_id="jdbc_default", diff --git a/tests/providers/qubole/operators/test_qubole_check.py b/tests/providers/qubole/operators/test_qubole_check.py index 85b5788fda200..74af6a29d1267 100644 --- a/tests/providers/qubole/operators/test_qubole_check.py +++ b/tests/providers/qubole/operators/test_qubole_check.py @@ -48,7 +48,7 @@ ], ) class TestQuboleCheckMixin: - def setup(self): + def setup_method(self): self.task_id = "test_task" def __construct_operator(self, operator_class, **kwargs): diff --git a/tests/providers/salesforce/hooks/test_salesforce.py b/tests/providers/salesforce/hooks/test_salesforce.py index 4b936c32fe1e1..9906d1c4d1428 100644 --- a/tests/providers/salesforce/hooks/test_salesforce.py +++ b/tests/providers/salesforce/hooks/test_salesforce.py @@ -33,9 +33,10 @@ class TestSalesforceHook: - def setup(self): + def setup_method(self): self.salesforce_hook = SalesforceHook(salesforce_conn_id="conn_id") + @staticmethod def _insert_conn_db_entry(conn_id, conn_object): with create_session() as session: session.query(Connection).filter(Connection.conn_id == conn_id).delete() diff --git a/tests/providers/slack/operators/test_slack.py b/tests/providers/slack/operators/test_slack.py index 08ad3312bb6cb..ef40ac4f6c92d 100644 --- a/tests/providers/slack/operators/test_slack.py +++ b/tests/providers/slack/operators/test_slack.py @@ -74,8 +74,7 @@ def test_hook(self, mock_slack_hook_cls, token, conn_id): class TestSlackAPIPostOperator: - 
@pytest.fixture(autouse=True) - def setup(self): + def setup_method(self): self.test_username = "test_username" self.test_channel = "#test_slack_channel" self.test_text = "test_text" @@ -184,8 +183,7 @@ def test_api_call_params_with_default_args(self, mock_hook): class TestSlackAPIFileOperator: - @pytest.fixture(autouse=True) - def setup(self): + def setup_method(self): self.test_username = "test_username" self.test_channel = "#test_slack_channel" self.test_initial_comment = "test text file test_filename.txt"