From e5bee2bad7dda0af58291c0342f6d37ba81ca577 Mon Sep 17 00:00:00 2001 From: Amogh Date: Mon, 30 Dec 2024 15:34:55 +0530 Subject: [PATCH] Use existing mock_supervisor_comms fixture for tests --- .../tests/execution_time/test_task_runner.py | 50 ++++++++----------- 1 file changed, 22 insertions(+), 28 deletions(-) diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 7cbe6e649b6c1..582e5a19e1fb3 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -254,7 +254,7 @@ def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context, mock_sup ) -def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context): +def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context, mock_supervisor_comms): """Test running a basic task that raises a base exception which should send fail_with_retry state.""" from airflow.providers.standard.operators.python import PythonOperator @@ -281,21 +281,18 @@ def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context): instant = timezone.datetime(2024, 12, 3, 10, 0) time_machine.move_to(instant, tick=False) - with mock.patch( - "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True - ) as mock_supervisor_comms: - run(ti, log=mock.MagicMock()) + run(ti, log=mock.MagicMock()) - mock_supervisor_comms.send_request.assert_called_once_with( - msg=TaskState( - state=TerminalTIState.FAILED, - end_date=instant, - ), - log=mock.ANY, - ) + mock_supervisor_comms.send_request.assert_called_once_with( + msg=TaskState( + state=TerminalTIState.FAILED, + end_date=instant, + ), + log=mock.ANY, + ) -def test_startup_basic_templated_dag(mocked_parse, make_ti_context): +def test_startup_basic_templated_dag(mocked_parse, make_ti_context, mock_supervisor_comms): """Test running a DAG with templated task.""" from airflow.providers.standard.operators.bash import BashOperator @@ -314,22 +311,19 @@ def test_startup_basic_templated_dag(mocked_parse, make_ti_context): ) mocked_parse(what, "basic_templated_dag", task) - with mock.patch( - "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True - ) as mock_supervisor_comms: - mock_supervisor_comms.get_message.return_value = what - startup() + mock_supervisor_comms.get_message.return_value = what + startup() - mock_supervisor_comms.send_request.assert_called_once_with( - msg=SetRenderedFields( - rendered_fields={ - "bash_command": "echo 'Logical date is {{ logical_date }}'", - "cwd": None, - "env": None, - } - ), - log=mock.ANY, - ) + mock_supervisor_comms.send_request.assert_called_once_with( + msg=SetRenderedFields( + rendered_fields={ + "bash_command": "echo 'Logical date is {{ logical_date }}'", + "cwd": None, + "env": None, + } + ), + log=mock.ANY, + ) @pytest.mark.parametrize(