diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index d6bac9fc7fc4e..8a402edebf97b 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -67,6 +67,7 @@ from airflow.utils.providers_configuration_loader import providers_configuration_loaded from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.state import DagRunState +from airflow.utils.task_instance_session import set_current_task_instance_session if TYPE_CHECKING: from sqlalchemy.orm.session import Session @@ -649,7 +650,8 @@ def task_render(args, dag: DAG | None = None) -> None: ti, _ = _get_ti( task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="memory" ) - ti.render_templates() + with create_session() as session, set_current_task_instance_session(session=session): + ti.render_templates() for attr in task.template_fields: print( textwrap.dedent( diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 37882dcd78dfb..621a21e53ebb0 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -26,7 +26,6 @@ import attr -from airflow import settings from airflow.compat.functools import cache from airflow.exceptions import AirflowException, UnmappableOperator from airflow.models.abstractoperator import ( @@ -54,6 +53,7 @@ from airflow.typing_compat import Literal from airflow.utils.context import context_update_for_unmapped from airflow.utils.helpers import is_container, prevent_duplicates +from airflow.utils.task_instance_session import get_current_task_instance_session from airflow.utils.types import NOTSET from airflow.utils.xcom import XCOM_RETURN_KEY @@ -720,12 +720,13 @@ def render_template_fields( if not jinja_env: jinja_env = self.get_template_env() - # Ideally we'd like to pass in session as an argument to this function, - # but we can't easily change this function signature since operators - # could override this. We can't use @provide_session since it closes and - # expunges everything, which we don't want to do when we are so "deep" - # in the weeds here. We don't close this session for the same reason. - session = settings.Session() + # We retrieve the session here, stored by _run_raw_task in the set_current_task_instance_session + # context manager. We cannot pass the session via @provide_session because the signature of + # render_template_fields is defined by BaseOperator and there are already many subclasses + # overriding it, so changing the signature is not an option. However, render_template_fields is + # always executed within _run_raw_task, so we make sure that _run_raw_task uses the + # set_current_task_instance_session context manager to store the session for the current task instance. 
+ session = get_current_task_instance_session() mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session) unmapped_task = self.unmap(mapped_kwargs) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index ff1b190a5eab8..82282eb39d984 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -118,6 +118,7 @@ ) from airflow.utils.state import DagRunState, JobState, State, TaskInstanceState from airflow.utils.task_group import MappedTaskGroup +from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.timeout import timeout from airflow.utils.xcom import XCOM_RETURN_KEY @@ -1511,98 +1512,98 @@ def _run_raw_task( count=0, tags={**self.stats_tags, "state": str(state)}, ) + with set_current_task_instance_session(session=session): + self.task = self.task.prepare_for_execution() + context = self.get_template_context(ignore_param_exceptions=False) - self.task = self.task.prepare_for_execution() - context = self.get_template_context(ignore_param_exceptions=False) - - try: - if not mark_success: - self._execute_task_with_callbacks(context, test_mode, session=session) - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session) - self.state = TaskInstanceState.SUCCESS - except TaskDeferred as defer: - # The task has signalled it wants to defer execution based on - # a trigger. - self._defer_task(defer=defer, session=session) - self.log.info( - "Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s", - self.dag_id, - self.task_id, - self._date_or_empty("execution_date"), - self._date_or_empty("start_date"), - ) - if not test_mode: - session.add(Log(self.state, self)) - session.merge(self) - session.commit() - return TaskReturnCode.DEFERRED - except AirflowSkipException as e: - # Recording SKIP - # log only if exception has any arguments to prevent log flooding - if e.args: - self.log.info(e) - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session) - self.state = TaskInstanceState.SKIPPED - except AirflowRescheduleException as reschedule_exception: - self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session) - session.commit() - return None - except (AirflowFailException, AirflowSensorTimeout) as e: - # If AirflowFailException is raised, task should not retry. - # If a sensor in reschedule mode reaches timeout, task should not retry. - self.handle_failure(e, test_mode, context, force_fail=True, session=session) - session.commit() - raise - except AirflowException as e: - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session) - # for case when task is marked as success/failed externally - # or dagrun timed out and task is marked as skipped - # current behavior doesn't hit the callbacks - if self.state in State.finished: - self.clear_next_method_args() - session.merge(self) + try: + if not mark_success: + self._execute_task_with_callbacks(context, test_mode, session=session) + if not test_mode: + self.refresh_from_db(lock_for_update=True, session=session) + self.state = TaskInstanceState.SUCCESS + except TaskDeferred as defer: + # The task has signalled it wants to defer execution based on + # a trigger. + self._defer_task(defer=defer, session=session) + self.log.info( + "Pausing task as DEFERRED. 
dag_id=%s, task_id=%s, execution_date=%s, start_date=%s", + self.dag_id, + self.task_id, + self._date_or_empty("execution_date"), + self._date_or_empty("start_date"), + ) + if not test_mode: + session.add(Log(self.state, self)) + session.merge(self) + session.commit() + return TaskReturnCode.DEFERRED + except AirflowSkipException as e: + # Recording SKIP + # log only if exception has any arguments to prevent log flooding + if e.args: + self.log.info(e) + if not test_mode: + self.refresh_from_db(lock_for_update=True, session=session) + self.state = TaskInstanceState.SKIPPED + except AirflowRescheduleException as reschedule_exception: + self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session) session.commit() return None - else: + except (AirflowFailException, AirflowSensorTimeout) as e: + # If AirflowFailException is raised, task should not retry. + # If a sensor in reschedule mode reaches timeout, task should not retry. + self.handle_failure(e, test_mode, context, force_fail=True, session=session) + session.commit() + raise + except AirflowException as e: + if not test_mode: + self.refresh_from_db(lock_for_update=True, session=session) + # for case when task is marked as success/failed externally + # or dagrun timed out and task is marked as skipped + # current behavior doesn't hit the callbacks + if self.state in State.finished: + self.clear_next_method_args() + session.merge(self) + session.commit() + return None + else: + self.handle_failure(e, test_mode, context, session=session) + session.commit() + raise + except (Exception, KeyboardInterrupt) as e: self.handle_failure(e, test_mode, context, session=session) session.commit() raise - except (Exception, KeyboardInterrupt) as e: - self.handle_failure(e, test_mode, context, session=session) - session.commit() - raise - finally: - Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags) - # Same metric with tagging - Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)}) - - # Recording SKIPPED or SUCCESS - self.clear_next_method_args() - self.end_date = timezone.utcnow() - self._log_state() - self.set_duration() + finally: + Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags) + # Same metric with tagging + Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)}) + + # Recording SKIPPED or SUCCESS + self.clear_next_method_args() + self.end_date = timezone.utcnow() + self._log_state() + self.set_duration() + + # run on_success_callback before db committing + # otherwise, the LocalTaskJob sees the state is changed to `success`, + # but the task_runner is still running, LocalTaskJob then treats the state is set externally! + self._run_finished_callback(self.task.on_success_callback, context, "on_success") - # run on_success_callback before db committing - # otherwise, the LocalTaskJob sees the state is changed to `success`, - # but the task_runner is still running, LocalTaskJob then treats the state is set externally! 
- self._run_finished_callback(self.task.on_success_callback, context, "on_success") - - if not test_mode: - session.add(Log(self.state, self)) - session.merge(self).task = self.task - if self.state == TaskInstanceState.SUCCESS: - self._register_dataset_changes(session=session) + if not test_mode: + session.add(Log(self.state, self)) + session.merge(self).task = self.task + if self.state == TaskInstanceState.SUCCESS: + self._register_dataset_changes(session=session) - session.commit() - if self.state == TaskInstanceState.SUCCESS: - get_listener_manager().hook.on_task_instance_success( - previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session - ) + session.commit() + if self.state == TaskInstanceState.SUCCESS: + get_listener_manager().hook.on_task_instance_success( + previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session + ) - return None + return None def _register_dataset_changes(self, *, session: Session) -> None: for obj in self.task.outlets or []: diff --git a/airflow/utils/task_instance_session.py b/airflow/utils/task_instance_session.py new file mode 100644 index 0000000000000..9d4dd958347c3 --- /dev/null +++ b/airflow/utils/task_instance_session.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import contextlib +import logging +import traceback +from typing import TYPE_CHECKING + +from airflow.utils.session import create_session + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + +__current_task_instance_session: Session | None = None + +log = logging.getLogger(__name__) + + +def get_current_task_instance_session() -> Session: + global __current_task_instance_session + if not __current_task_instance_session: + log.warning("No task session set for this task. Continuing, but this likely causes a resource leak.") + log.warning("Please report this and the stack trace below to https://github.com/apache/airflow/issues") + for filename, line_number, name, line in traceback.extract_stack(): + log.warning('File: "%s", line %s, in %s', filename, line_number, name) + if line: + log.warning(" %s", line.strip()) + __current_task_instance_session = create_session() + return __current_task_instance_session + + +@contextlib.contextmanager +def set_current_task_instance_session(session: Session): + global __current_task_instance_session + if __current_task_instance_session: + raise RuntimeError( + "Session already set for this task. " + "You can only have one 'set_current_task_instance_session' context manager active at a time." 
+ ) + __current_task_instance_session = session + try: + yield + finally: + __current_task_instance_session = None diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index b24fcb1707811..1368ddb434830 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -36,6 +36,7 @@ from airflow.utils import timezone from airflow.utils.state import State from airflow.utils.task_group import TaskGroup +from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType from airflow.utils.xcom import XCOM_RETURN_KEY @@ -747,36 +748,37 @@ def test_mapped_render_template_fields(dag_maker, session): def fn(arg1, arg2): ... - with dag_maker(session=session): - task1 = BaseOperator(task_id="op1") - mapped = fn.partial(arg2="{{ ti.task_id }}").expand(arg1=task1.output) + with set_current_task_instance_session(session=session): + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = fn.partial(arg2="{{ ti.task_id }}").expand(arg1=task1.output) - dr = dag_maker.create_dagrun() - ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) - - ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session) - - session.add( - TaskMap( - dag_id=dr.dag_id, - task_id=task1.task_id, - run_id=dr.run_id, - map_index=-1, - length=1, - keys=None, + dr = dag_maker.create_dagrun() + ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) + + ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session) + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=1, + keys=None, + ) ) - ) - session.flush() + session.flush() - mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session) - mapped_ti.map_index = 0 + mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session) + mapped_ti.map_index = 0 - assert isinstance(mapped_ti.task, MappedOperator) - mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(mapped_ti.task, BaseOperator) + assert isinstance(mapped_ti.task, MappedOperator) + mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert isinstance(mapped_ti.task, BaseOperator) - assert mapped_ti.task.op_kwargs["arg1"] == "{{ ds }}" - assert mapped_ti.task.op_kwargs["arg2"] == "fn" + assert mapped_ti.task.op_kwargs["arg1"] == "{{ ds }}" + assert mapped_ti.task.op_kwargs["arg2"] == "fn" def test_task_decorator_has_wrapped_attr(): diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 51760b85705ce..9336e8a559003 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -39,6 +39,7 @@ from airflow.operators.python import PythonOperator from airflow.utils.state import TaskInstanceState from airflow.utils.task_group import TaskGroup +from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY from tests.models import DEFAULT_DATE @@ -403,109 +404,114 @@ def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expect def test_mapped_render_template_fields_validating_operator(dag_maker, session): - class MyOperator(BaseOperator): - template_fields = 
("partial_template", "map_template", "file_template") - template_ext = (".ext",) - - def __init__( - self, partial_template, partial_static, map_template, map_static, file_template, **kwargs - ): - for value in [partial_template, partial_static, map_template, map_static, file_template]: - assert isinstance(value, str), "value should have been resolved before unmapping" - super().__init__(**kwargs) - self.partial_template = partial_template - self.partial_static = partial_static - self.map_template = map_template - self.map_static = map_static - self.file_template = file_template + with set_current_task_instance_session(session=session): + + class MyOperator(BaseOperator): + template_fields = ("partial_template", "map_template", "file_template") + template_ext = (".ext",) + + def __init__( + self, partial_template, partial_static, map_template, map_static, file_template, **kwargs + ): + for value in [partial_template, partial_static, map_template, map_static, file_template]: + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + self.partial_template = partial_template + self.partial_static = partial_static + self.map_template = map_template + self.map_static = map_static + self.file_template = file_template def execute(self, context): pass - with dag_maker(session=session): - task1 = BaseOperator(task_id="op1") - output1 = task1.output - mapped = MyOperator.partial( - task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).expand(map_template=output1, map_static=output1, file_template=["/path/to/file.ext"]) + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + output1 = task1.output + mapped = MyOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ).expand(map_template=output1, map_static=output1, file_template=["/path/to/file.ext"]) + + dr = dag_maker.create_dagrun() + ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) + + ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session) + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=1, + keys=None, + ) + ) + session.flush() - dr = dag_maker.create_dagrun() - ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) + mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) + mapped_ti.map_index = 0 - ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session) + assert isinstance(mapped_ti.task, MappedOperator) + with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch( + "os.path.isfile", return_value=True + ), patch("os.path.getmtime", return_value=0): + mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert isinstance(mapped_ti.task, MyOperator) - session.add( - TaskMap( - dag_id=dr.dag_id, - task_id=task1.task_id, - run_id=dr.run_id, - map_index=-1, - length=1, - keys=None, - ) - ) - session.flush() + assert mapped_ti.task.partial_template == "a", "Should be templated!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" + assert mapped_ti.task.map_template == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.file_template == "loaded data", "Should be templated!" 
- mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) - mapped_ti.map_index = 0 - assert isinstance(mapped_ti.task, MappedOperator) - with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch( - "os.path.isfile", return_value=True - ), patch("os.path.getmtime", return_value=0): - mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(mapped_ti.task, MyOperator) +def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session): - assert mapped_ti.task.partial_template == "a", "Should be templated!" - assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" - assert mapped_ti.task.map_template == "{{ ds }}", "Should not be templated!" - assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" - assert mapped_ti.task.file_template == "loaded data", "Should be templated!" + with set_current_task_instance_session(session=session): + class MyOperator(BaseOperator): + template_fields = ("partial_template", "map_template", "file_template") + template_ext = (".ext",) -def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session): - class MyOperator(BaseOperator): - template_fields = ("partial_template", "map_template", "file_template") - template_ext = (".ext",) - - def __init__( - self, partial_template, partial_static, map_template, map_static, file_template, **kwargs - ): - for value in [partial_template, partial_static, map_template, map_static, file_template]: - assert isinstance(value, str), "value should have been resolved before unmapping" - super().__init__(**kwargs) - self.partial_template = partial_template - self.partial_static = partial_static - self.map_template = map_template - self.map_static = map_static - self.file_template = file_template + def __init__( + self, partial_template, partial_static, map_template, map_static, file_template, **kwargs + ): + for value in [partial_template, partial_static, map_template, map_static, file_template]: + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + self.partial_template = partial_template + self.partial_static = partial_static + self.map_template = map_template + self.map_static = map_static + self.file_template = file_template - def execute(self, context): - pass + def execute(self, context): + pass - with dag_maker(session=session): - mapped = MyOperator.partial( - task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).expand_kwargs( - [{"map_template": "{{ ds }}", "map_static": "{{ ds }}", "file_template": "/path/to/file.ext"}] - ) + with dag_maker(session=session): + mapped = MyOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ).expand_kwargs( + [{"map_template": "{{ ds }}", "map_static": "{{ ds }}", "file_template": "/path/to/file.ext"}] + ) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun() - mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0) + mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0) - assert isinstance(mapped_ti.task, MappedOperator) - with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch( - "os.path.isfile", return_value=True - ), patch("os.path.getmtime", return_value=0): - 
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(mapped_ti.task, MyOperator) + assert isinstance(mapped_ti.task, MappedOperator) + with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch( + "os.path.isfile", return_value=True + ), patch("os.path.getmtime", return_value=0): + mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert isinstance(mapped_ti.task, MyOperator) - assert mapped_ti.task.partial_template == "a", "Should be templated!" - assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" - assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!" - assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" - assert mapped_ti.task.file_template == "loaded data", "Should be templated!" + assert mapped_ti.task.partial_template == "a", "Should be templated!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" + assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!" + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.file_template == "loaded data", "Should be templated!" def test_mapped_render_nested_template_fields(dag_maker, session): @@ -607,35 +613,36 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis ], ) def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected): - with dag_maker(session=session): - task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand_kwargs(task1.output) - - dr = dag_maker.create_dagrun() - ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) - - ti.xcom_push(key=XCOM_RETURN_KEY, value=[{"arg1": "{{ ds }}"}, {"arg1": 2}], session=session) - - session.add( - TaskMap( - dag_id=dr.dag_id, - task_id=task1.task_id, - run_id=dr.run_id, - map_index=-1, - length=2, - keys=None, + with set_current_task_instance_session(session=session): + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand_kwargs(task1.output) + + dr = dag_maker.create_dagrun() + ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) + + ti.xcom_push(key=XCOM_RETURN_KEY, value=[{"arg1": "{{ ds }}"}, {"arg1": 2}], session=session) + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=2, + keys=None, + ) ) - ) - session.flush() - - ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) - ti.refresh_from_task(mapped) - ti.map_index = map_index - assert isinstance(ti.task, MappedOperator) - mapped.render_template_fields(context=ti.get_template_context(session=session)) - assert isinstance(ti.task, MockOperator) - assert ti.task.arg1 == expected - assert ti.task.arg2 == "a" + session.flush() + + ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) + ti.refresh_from_task(mapped) + ti.map_index = map_index + assert isinstance(ti.task, MappedOperator) + mapped.render_template_fields(context=ti.get_template_context(session=session)) + assert isinstance(ti.task, MockOperator) + assert ti.task.arg1 == expected + assert ti.task.arg2 == "a" def test_xcomarg_property_of_mapped_operator(dag_maker): diff --git a/tests/models/test_renderedtifields.py 
b/tests/models/test_renderedtifields.py index 7a312a4c67c38..12fc108ce0124 100644 --- a/tests/models/test_renderedtifields.py +++ b/tests/models/test_renderedtifields.py @@ -29,6 +29,7 @@ from airflow.models import Variable from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF from airflow.operators.bash import BashOperator +from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.timezone import datetime from tests.test_utils.asserts import assert_queries_count from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_rendered_ti_fields @@ -153,44 +154,46 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field, da ], ) def test_delete_old_records( - self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count, dag_maker + self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count, dag_maker, session ): """ Test that old records are deleted from rendered_task_instance_fields table for a given task_id and dag_id. """ - session = settings.Session() - with dag_maker("test_delete_old_records") as dag: - task = BashOperator(task_id="test", bash_command="echo {{ ds }}") - rtif_list = [] - for num in range(rtif_num): - dr = dag_maker.create_dagrun(run_id=str(num), execution_date=dag.start_date + timedelta(days=num)) - ti = dr.task_instances[0] - ti.task = task - rtif_list.append(RTIF(ti)) - - session.add_all(rtif_list) - session.flush() - - result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() - - for rtif in rtif_list: - assert rtif in result - - assert rtif_num == len(result) - - # Verify old records are deleted and only 'num_to_keep' records are kept - # For other DBs,an extra query is fired in RenderedTaskInstanceFields.delete_old_records - expected_query_count_based_on_db = ( - expected_query_count + 1 - if session.bind.dialect.name == "mssql" and expected_query_count != 0 - else expected_query_count - ) + with set_current_task_instance_session(session=session): + with dag_maker("test_delete_old_records") as dag: + task = BashOperator(task_id="test", bash_command="echo {{ ds }}") + rtif_list = [] + for num in range(rtif_num): + dr = dag_maker.create_dagrun( + run_id=str(num), execution_date=dag.start_date + timedelta(days=num) + ) + ti = dr.task_instances[0] + ti.task = task + rtif_list.append(RTIF(ti)) + + session.add_all(rtif_list) + session.flush() + + result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() + + for rtif in rtif_list: + assert rtif in result + + assert rtif_num == len(result) + + # Verify old records are deleted and only 'num_to_keep' records are kept + # For other DBs,an extra query is fired in RenderedTaskInstanceFields.delete_old_records + expected_query_count_based_on_db = ( + expected_query_count + 1 + if session.bind.dialect.name == "mssql" and expected_query_count != 0 + else expected_query_count + ) - with assert_queries_count(expected_query_count_based_on_db): - RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep) - result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() - assert remaining_rtifs == len(result) + with assert_queries_count(expected_query_count_based_on_db): + RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep) + result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() + assert 
remaining_rtifs == len(result) @pytest.mark.parametrize( "num_runs, num_to_keep, remaining_rtifs, expected_query_count", @@ -207,40 +210,41 @@ def test_delete_old_records_mapped( Test that old records are deleted from rendered_task_instance_fields table for a given task_id and dag_id with mapped tasks. """ - with dag_maker("test_delete_old_records", session=session) as dag: - mapped = BashOperator.partial(task_id="mapped").expand(bash_command=["a", "b"]) - for num in range(num_runs): - dr = dag_maker.create_dagrun( - run_id=f"run_{num}", execution_date=dag.start_date + timedelta(days=num) + with set_current_task_instance_session(session=session): + with dag_maker("test_delete_old_records", session=session) as dag: + mapped = BashOperator.partial(task_id="mapped").expand(bash_command=["a", "b"]) + for num in range(num_runs): + dr = dag_maker.create_dagrun( + run_id=f"run_{num}", execution_date=dag.start_date + timedelta(days=num) + ) + + mapped.expand_mapped_task(dr.run_id, session=dag_maker.session) + session.refresh(dr) + for ti in dr.task_instances: + ti.task = dag.get_task(ti.task_id) + session.add(RTIF(ti)) + session.flush() + + result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all() + assert len(result) == num_runs * 2 + + # Verify old records are deleted and only 'num_to_keep' records are kept + # For other DBs,an extra query is fired in RenderedTaskInstanceFields.delete_old_records + expected_query_count_based_on_db = ( + expected_query_count + 1 + if session.bind.dialect.name == "mssql" and expected_query_count != 0 + else expected_query_count ) - mapped.expand_mapped_task(dr.run_id, session=dag_maker.session) - session.refresh(dr) - for ti in dr.task_instances: - ti.task = dag.get_task(ti.task_id) - session.add(RTIF(ti)) - session.flush() - - result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all() - assert len(result) == num_runs * 2 - - # Verify old records are deleted and only 'num_to_keep' records are kept - # For other DBs,an extra query is fired in RenderedTaskInstanceFields.delete_old_records - expected_query_count_based_on_db = ( - expected_query_count + 1 - if session.bind.dialect.name == "mssql" and expected_query_count != 0 - else expected_query_count - ) - - with assert_queries_count(expected_query_count_based_on_db): - RTIF.delete_old_records( - task_id=mapped.task_id, dag_id=dr.dag_id, num_to_keep=num_to_keep, session=session - ) - result = session.query(RTIF).filter_by(dag_id=dag.dag_id, task_id=mapped.task_id).all() - rtif_num_runs = Counter(rtif.run_id for rtif in result) - assert len(rtif_num_runs) == remaining_rtifs - # Check that we have _all_ the data for each row - assert len(result) == remaining_rtifs * 2 + with assert_queries_count(expected_query_count_based_on_db): + RTIF.delete_old_records( + task_id=mapped.task_id, dag_id=dr.dag_id, num_to_keep=num_to_keep, session=session + ) + result = session.query(RTIF).filter_by(dag_id=dag.dag_id, task_id=mapped.task_id).all() + rtif_num_runs = Counter(rtif.run_id for rtif in result) + assert len(rtif_num_runs) == remaining_rtifs + # Check that we have _all_ the data for each row + assert len(result) == remaining_rtifs * 2 def test_write(self, dag_maker): """ diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py index 76d52f4769520..5003010297ad1 100644 --- a/tests/models/test_xcom_arg_map.py +++ b/tests/models/test_xcom_arg_map.py @@ -41,7 +41,7 @@ def pull(value): # The function passed to "map" is *NOT* a task. 
assert set(dag.task_dict) == {"push", "pull"} - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) # Run "push". decision = dr.task_instance_scheduling_decisions(session=session) @@ -79,7 +79,7 @@ def c_to_none(v): pull.expand(value=push().map(c_to_none)) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) # Run "push". decision = dr.task_instance_scheduling_decisions(session=session) @@ -113,7 +113,7 @@ def c_to_none(v): pull.expand_kwargs(push().map(c_to_none)) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) # Run "push". decision = dr.task_instance_scheduling_decisions(session=session) @@ -158,7 +158,7 @@ def does_not_work_with_c(v): pull.expand_kwargs(push().map(does_not_work_with_c)) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) # The "push" task should not fail. decision = dr.task_instance_scheduling_decisions(session=session) @@ -211,7 +211,7 @@ def skip_c(v): collect(value=forward.expand_kwargs(push().map(skip_c))) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) # Run "push". decision = dr.task_instance_scheduling_decisions(session=session) @@ -246,7 +246,7 @@ def pull(value): converted = push().map(lambda v: v * 2).map(lambda v: {"value": v}) pull.expand_kwargs(converted) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) # Run "push". decision = dr.task_instance_scheduling_decisions(session=session) @@ -289,7 +289,7 @@ def convert_zipped(zipped): pull.expand(value=combined.map(convert_zipped)) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) # Run "push_letters" and "push_numbers". decision = dr.task_instance_scheduling_decisions(session=session)
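
Usage note (illustrative only, not part of the patch): the two helpers added in airflow/utils/task_instance_session.py are meant to pair up so that an outer caller stores a session with set_current_task_instance_session and code deeper in the call stack (such as MappedOperator.render_template_fields above) fetches the same session with get_current_task_instance_session instead of opening its own. A minimal sketch follows; the function name deep_in_the_stack is a hypothetical stand-in for such code.

from airflow.utils.session import create_session
from airflow.utils.task_instance_session import (
    get_current_task_instance_session,
    set_current_task_instance_session,
)


def deep_in_the_stack():
    # Hypothetical stand-in for code whose signature cannot take a session
    # argument (e.g. render_template_fields); it re-uses the session stored
    # by the enclosing set_current_task_instance_session context manager.
    return get_current_task_instance_session()


with create_session() as session:
    with set_current_task_instance_session(session=session):
        # Inside the context manager the stored session is returned as-is.
        assert deep_in_the_stack() is session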