From 08a797427d4716fc92a6fee27efc8d4143eb7b73 Mon Sep 17 00:00:00 2001
From: Jarek Potiuk
Date: Fri, 11 Aug 2023 08:51:15 +0200
Subject: [PATCH] Reuse _run_task_session in mapped render_template_fields

The `render_template_fields` method of a mapped operator needs a database
session to render mapped fields, but it cannot receive the session through the
@provide_session decorator, because the method is overridden in derived classes
and we cannot change its signature without impacting those classes.

So far this was handled by creating a new session in the mapped operator, but
that has the drawback of opening an extra session while one is already created
(render_template_fields is always run in the context of a task run, which
already has a session created in _run_raw_task). It also caused problems in our
tests where two open database sessions accessed the database at the same time,
causing a sqlite exception on concurrent access and a mysql error when
operations ran out of sync - likely when the same object was modified in both
sessions.

This PR changes the approach: rather than creating a new session in the mapped
operator, we retrieve the session stored by _run_raw_task. This is done with a
context manager, and adequate protection has been added to make sure that:

a) the call is made within the context manager
b) the context manager is never initialized twice in the same call stack

After this change, a running task uses fewer resources, and mapped tasks will
no longer always open 2 DB sessions.

Fixes: #33178
---
 airflow/cli/commands/task_command.py | 4 +-
 airflow/models/mappedoperator.py | 15 +-
 airflow/models/taskinstance.py | 167 +++++++++---------
 airflow/utils/task_instance_session.py | 60 +++++++
 tests/decorators/test_python.py | 52 +++---
 tests/models/test_mappedoperator.py | 233 +++++++++++++------------
 tests/models/test_renderedtifields.py | 132 +++++++-------
 tests/models/test_xcom_arg_map.py | 14 +-
 8 files changed, 377 insertions(+), 300 deletions(-)
 create mode 100644 airflow/utils/task_instance_session.py

diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index d6bac9fc7fc4e..8a402edebf97b 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -67,6 +67,7 @@ from airflow.utils.providers_configuration_loader import providers_configuration_loaded from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.state import DagRunState +from airflow.utils.task_instance_session import set_current_task_instance_session if TYPE_CHECKING: from sqlalchemy.orm.session import Session @@ -649,7 +650,8 @@ def task_render(args, dag: DAG | None = None) -> None: ti, _ = _get_ti( task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="memory" ) - ti.render_templates() + with create_session() as session, set_current_task_instance_session(session=session): + ti.render_templates() for attr in task.template_fields: print( textwrap.dedent( diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 37882dcd78dfb..621a21e53ebb0 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -26,7 +26,6 @@ import attr -from airflow import settings from airflow.compat.functools import cache from airflow.exceptions import AirflowException, UnmappableOperator from airflow.models.abstractoperator import ( @@ -54,6 +53,7 @@ from airflow.typing_compat import Literal from airflow.utils.context import
context_update_for_unmapped from airflow.utils.helpers import is_container, prevent_duplicates +from airflow.utils.task_instance_session import get_current_task_instance_session from airflow.utils.types import NOTSET from airflow.utils.xcom import XCOM_RETURN_KEY @@ -720,12 +720,13 @@ def render_template_fields( if not jinja_env: jinja_env = self.get_template_env() - # Ideally we'd like to pass in session as an argument to this function, - # but we can't easily change this function signature since operators - # could override this. We can't use @provide_session since it closes and - # expunges everything, which we don't want to do when we are so "deep" - # in the weeds here. We don't close this session for the same reason. - session = settings.Session() + # We retrieve the session here, stored by _run_raw_task in set_current_task_session + # context manager - we cannot pass the session via @provide_session because the signature + # of render_template_fields is defined by BaseOperator and there are already many subclasses + # overriding it, so changing the signature is not an option. However render_template_fields is + # always executed within "_run_raw_task" so we make sure that _run_raw_task uses the + # set_current_task_session context manager to store the session in the current task. + session = get_current_task_instance_session() mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session) unmapped_task = self.unmap(mapped_kwargs) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index ff1b190a5eab8..82282eb39d984 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -118,6 +118,7 @@ ) from airflow.utils.state import DagRunState, JobState, State, TaskInstanceState from airflow.utils.task_group import MappedTaskGroup +from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.timeout import timeout from airflow.utils.xcom import XCOM_RETURN_KEY @@ -1511,98 +1512,98 @@ def _run_raw_task( count=0, tags={**self.stats_tags, "state": str(state)}, ) + with set_current_task_instance_session(session=session): + self.task = self.task.prepare_for_execution() + context = self.get_template_context(ignore_param_exceptions=False) - self.task = self.task.prepare_for_execution() - context = self.get_template_context(ignore_param_exceptions=False) - - try: - if not mark_success: - self._execute_task_with_callbacks(context, test_mode, session=session) - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session) - self.state = TaskInstanceState.SUCCESS - except TaskDeferred as defer: - # The task has signalled it wants to defer execution based on - # a trigger. - self._defer_task(defer=defer, session=session) - self.log.info( - "Pausing task as DEFERRED. 
dag_id=%s, task_id=%s, execution_date=%s, start_date=%s", - self.dag_id, - self.task_id, - self._date_or_empty("execution_date"), - self._date_or_empty("start_date"), - ) - if not test_mode: - session.add(Log(self.state, self)) - session.merge(self) - session.commit() - return TaskReturnCode.DEFERRED - except AirflowSkipException as e: - # Recording SKIP - # log only if exception has any arguments to prevent log flooding - if e.args: - self.log.info(e) - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session) - self.state = TaskInstanceState.SKIPPED - except AirflowRescheduleException as reschedule_exception: - self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session) - session.commit() - return None - except (AirflowFailException, AirflowSensorTimeout) as e: - # If AirflowFailException is raised, task should not retry. - # If a sensor in reschedule mode reaches timeout, task should not retry. - self.handle_failure(e, test_mode, context, force_fail=True, session=session) - session.commit() - raise - except AirflowException as e: - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session) - # for case when task is marked as success/failed externally - # or dagrun timed out and task is marked as skipped - # current behavior doesn't hit the callbacks - if self.state in State.finished: - self.clear_next_method_args() - session.merge(self) + try: + if not mark_success: + self._execute_task_with_callbacks(context, test_mode, session=session) + if not test_mode: + self.refresh_from_db(lock_for_update=True, session=session) + self.state = TaskInstanceState.SUCCESS + except TaskDeferred as defer: + # The task has signalled it wants to defer execution based on + # a trigger. + self._defer_task(defer=defer, session=session) + self.log.info( + "Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s", + self.dag_id, + self.task_id, + self._date_or_empty("execution_date"), + self._date_or_empty("start_date"), + ) + if not test_mode: + session.add(Log(self.state, self)) + session.merge(self) + session.commit() + return TaskReturnCode.DEFERRED + except AirflowSkipException as e: + # Recording SKIP + # log only if exception has any arguments to prevent log flooding + if e.args: + self.log.info(e) + if not test_mode: + self.refresh_from_db(lock_for_update=True, session=session) + self.state = TaskInstanceState.SKIPPED + except AirflowRescheduleException as reschedule_exception: + self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session) session.commit() return None - else: + except (AirflowFailException, AirflowSensorTimeout) as e: + # If AirflowFailException is raised, task should not retry. + # If a sensor in reschedule mode reaches timeout, task should not retry. 
+ self.handle_failure(e, test_mode, context, force_fail=True, session=session) + session.commit() + raise + except AirflowException as e: + if not test_mode: + self.refresh_from_db(lock_for_update=True, session=session) + # for case when task is marked as success/failed externally + # or dagrun timed out and task is marked as skipped + # current behavior doesn't hit the callbacks + if self.state in State.finished: + self.clear_next_method_args() + session.merge(self) + session.commit() + return None + else: + self.handle_failure(e, test_mode, context, session=session) + session.commit() + raise + except (Exception, KeyboardInterrupt) as e: self.handle_failure(e, test_mode, context, session=session) session.commit() raise - except (Exception, KeyboardInterrupt) as e: - self.handle_failure(e, test_mode, context, session=session) - session.commit() - raise - finally: - Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags) - # Same metric with tagging - Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)}) - - # Recording SKIPPED or SUCCESS - self.clear_next_method_args() - self.end_date = timezone.utcnow() - self._log_state() - self.set_duration() + finally: + Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags) + # Same metric with tagging + Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)}) + + # Recording SKIPPED or SUCCESS + self.clear_next_method_args() + self.end_date = timezone.utcnow() + self._log_state() + self.set_duration() + + # run on_success_callback before db committing + # otherwise, the LocalTaskJob sees the state is changed to `success`, + # but the task_runner is still running, LocalTaskJob then treats the state is set externally! + self._run_finished_callback(self.task.on_success_callback, context, "on_success") - # run on_success_callback before db committing - # otherwise, the LocalTaskJob sees the state is changed to `success`, - # but the task_runner is still running, LocalTaskJob then treats the state is set externally! - self._run_finished_callback(self.task.on_success_callback, context, "on_success") - - if not test_mode: - session.add(Log(self.state, self)) - session.merge(self).task = self.task - if self.state == TaskInstanceState.SUCCESS: - self._register_dataset_changes(session=session) + if not test_mode: + session.add(Log(self.state, self)) + session.merge(self).task = self.task + if self.state == TaskInstanceState.SUCCESS: + self._register_dataset_changes(session=session) - session.commit() - if self.state == TaskInstanceState.SUCCESS: - get_listener_manager().hook.on_task_instance_success( - previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session - ) + session.commit() + if self.state == TaskInstanceState.SUCCESS: + get_listener_manager().hook.on_task_instance_success( + previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session + ) - return None + return None def _register_dataset_changes(self, *, session: Session) -> None: for obj in self.task.outlets or []: diff --git a/airflow/utils/task_instance_session.py b/airflow/utils/task_instance_session.py new file mode 100644 index 0000000000000..9d4dd958347c3 --- /dev/null +++ b/airflow/utils/task_instance_session.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import contextlib +import logging +import traceback +from typing import TYPE_CHECKING + +from airflow.utils.session import create_session + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + +__current_task_instance_session: Session | None = None + +log = logging.getLogger(__name__) + + +def get_current_task_instance_session() -> Session: + global __current_task_instance_session + if not __current_task_instance_session: + log.warning("No task session set for this task. Continuing but this likely causes a resource leak.") + log.warning("Please report this and stacktrace below to https://github.com/apache/airflow/issues") + for filename, line_number, name, line in traceback.extract_stack(): + log.warning('File: "%s", %s , in %s', filename, line_number, name) + if line: + log.warning(" %s", line.strip()) + __current_task_instance_session = create_session() + return __current_task_instance_session + + +@contextlib.contextmanager +def set_current_task_instance_session(session: Session): + global __current_task_instance_session + if __current_task_instance_session: + raise RuntimeError( + "Session already set for this task. " + "You can only have one 'set_current_task_session' context manager active at a time." + ) + __current_task_instance_session = session + try: + yield + finally: + __current_task_instance_session = None diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index b24fcb1707811..1368ddb434830 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -36,6 +36,7 @@ from airflow.utils import timezone from airflow.utils.state import State from airflow.utils.task_group import TaskGroup +from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType from airflow.utils.xcom import XCOM_RETURN_KEY @@ -747,36 +748,37 @@ def test_mapped_render_template_fields(dag_maker, session): def fn(arg1, arg2): ... 
- with dag_maker(session=session): - task1 = BaseOperator(task_id="op1") - mapped = fn.partial(arg2="{{ ti.task_id }}").expand(arg1=task1.output) + with set_current_task_instance_session(session=session): + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = fn.partial(arg2="{{ ti.task_id }}").expand(arg1=task1.output) - dr = dag_maker.create_dagrun() - ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) - - ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session) - - session.add( - TaskMap( - dag_id=dr.dag_id, - task_id=task1.task_id, - run_id=dr.run_id, - map_index=-1, - length=1, - keys=None, + dr = dag_maker.create_dagrun() + ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) + + ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session) + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=1, + keys=None, + ) ) - ) - session.flush() + session.flush() - mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session) - mapped_ti.map_index = 0 + mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session) + mapped_ti.map_index = 0 - assert isinstance(mapped_ti.task, MappedOperator) - mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(mapped_ti.task, BaseOperator) + assert isinstance(mapped_ti.task, MappedOperator) + mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert isinstance(mapped_ti.task, BaseOperator) - assert mapped_ti.task.op_kwargs["arg1"] == "{{ ds }}" - assert mapped_ti.task.op_kwargs["arg2"] == "fn" + assert mapped_ti.task.op_kwargs["arg1"] == "{{ ds }}" + assert mapped_ti.task.op_kwargs["arg2"] == "fn" def test_task_decorator_has_wrapped_attr(): diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 51760b85705ce..9336e8a559003 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -39,6 +39,7 @@ from airflow.operators.python import PythonOperator from airflow.utils.state import TaskInstanceState from airflow.utils.task_group import TaskGroup +from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY from tests.models import DEFAULT_DATE @@ -403,109 +404,114 @@ def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expect def test_mapped_render_template_fields_validating_operator(dag_maker, session): - class MyOperator(BaseOperator): - template_fields = ("partial_template", "map_template", "file_template") - template_ext = (".ext",) - - def __init__( - self, partial_template, partial_static, map_template, map_static, file_template, **kwargs - ): - for value in [partial_template, partial_static, map_template, map_static, file_template]: - assert isinstance(value, str), "value should have been resolved before unmapping" - super().__init__(**kwargs) - self.partial_template = partial_template - self.partial_static = partial_static - self.map_template = map_template - self.map_static = map_static - self.file_template = file_template + with set_current_task_instance_session(session=session): + + class MyOperator(BaseOperator): + template_fields = ("partial_template", "map_template", "file_template") + template_ext = (".ext",) + + def 
__init__( + self, partial_template, partial_static, map_template, map_static, file_template, **kwargs + ): + for value in [partial_template, partial_static, map_template, map_static, file_template]: + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + self.partial_template = partial_template + self.partial_static = partial_static + self.map_template = map_template + self.map_static = map_static + self.file_template = file_template def execute(self, context): pass - with dag_maker(session=session): - task1 = BaseOperator(task_id="op1") - output1 = task1.output - mapped = MyOperator.partial( - task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).expand(map_template=output1, map_static=output1, file_template=["/path/to/file.ext"]) + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + output1 = task1.output + mapped = MyOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ).expand(map_template=output1, map_static=output1, file_template=["/path/to/file.ext"]) + + dr = dag_maker.create_dagrun() + ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) + + ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session) + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=1, + keys=None, + ) + ) + session.flush() - dr = dag_maker.create_dagrun() - ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) + mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) + mapped_ti.map_index = 0 - ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session) + assert isinstance(mapped_ti.task, MappedOperator) + with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch( + "os.path.isfile", return_value=True + ), patch("os.path.getmtime", return_value=0): + mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert isinstance(mapped_ti.task, MyOperator) - session.add( - TaskMap( - dag_id=dr.dag_id, - task_id=task1.task_id, - run_id=dr.run_id, - map_index=-1, - length=1, - keys=None, - ) - ) - session.flush() + assert mapped_ti.task.partial_template == "a", "Should be templated!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" + assert mapped_ti.task.map_template == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.file_template == "loaded data", "Should be templated!" - mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) - mapped_ti.map_index = 0 - assert isinstance(mapped_ti.task, MappedOperator) - with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch( - "os.path.isfile", return_value=True - ), patch("os.path.getmtime", return_value=0): - mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(mapped_ti.task, MyOperator) +def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session): - assert mapped_ti.task.partial_template == "a", "Should be templated!" - assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" - assert mapped_ti.task.map_template == "{{ ds }}", "Should not be templated!" 
- assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" - assert mapped_ti.task.file_template == "loaded data", "Should be templated!" + with set_current_task_instance_session(session=session): + class MyOperator(BaseOperator): + template_fields = ("partial_template", "map_template", "file_template") + template_ext = (".ext",) -def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session): - class MyOperator(BaseOperator): - template_fields = ("partial_template", "map_template", "file_template") - template_ext = (".ext",) - - def __init__( - self, partial_template, partial_static, map_template, map_static, file_template, **kwargs - ): - for value in [partial_template, partial_static, map_template, map_static, file_template]: - assert isinstance(value, str), "value should have been resolved before unmapping" - super().__init__(**kwargs) - self.partial_template = partial_template - self.partial_static = partial_static - self.map_template = map_template - self.map_static = map_static - self.file_template = file_template + def __init__( + self, partial_template, partial_static, map_template, map_static, file_template, **kwargs + ): + for value in [partial_template, partial_static, map_template, map_static, file_template]: + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + self.partial_template = partial_template + self.partial_static = partial_static + self.map_template = map_template + self.map_static = map_static + self.file_template = file_template - def execute(self, context): - pass + def execute(self, context): + pass - with dag_maker(session=session): - mapped = MyOperator.partial( - task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).expand_kwargs( - [{"map_template": "{{ ds }}", "map_static": "{{ ds }}", "file_template": "/path/to/file.ext"}] - ) + with dag_maker(session=session): + mapped = MyOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ).expand_kwargs( + [{"map_template": "{{ ds }}", "map_static": "{{ ds }}", "file_template": "/path/to/file.ext"}] + ) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun() - mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0) + mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0) - assert isinstance(mapped_ti.task, MappedOperator) - with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch( - "os.path.isfile", return_value=True - ), patch("os.path.getmtime", return_value=0): - mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(mapped_ti.task, MyOperator) + assert isinstance(mapped_ti.task, MappedOperator) + with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), patch( + "os.path.isfile", return_value=True + ), patch("os.path.getmtime", return_value=0): + mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert isinstance(mapped_ti.task, MyOperator) - assert mapped_ti.task.partial_template == "a", "Should be templated!" - assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" - assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!" - assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" 
- assert mapped_ti.task.file_template == "loaded data", "Should be templated!" + assert mapped_ti.task.partial_template == "a", "Should be templated!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" + assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!" + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.file_template == "loaded data", "Should be templated!" def test_mapped_render_nested_template_fields(dag_maker, session): @@ -607,35 +613,36 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis ], ) def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected): - with dag_maker(session=session): - task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand_kwargs(task1.output) - - dr = dag_maker.create_dagrun() - ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) - - ti.xcom_push(key=XCOM_RETURN_KEY, value=[{"arg1": "{{ ds }}"}, {"arg1": 2}], session=session) - - session.add( - TaskMap( - dag_id=dr.dag_id, - task_id=task1.task_id, - run_id=dr.run_id, - map_index=-1, - length=2, - keys=None, + with set_current_task_instance_session(session=session): + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand_kwargs(task1.output) + + dr = dag_maker.create_dagrun() + ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) + + ti.xcom_push(key=XCOM_RETURN_KEY, value=[{"arg1": "{{ ds }}"}, {"arg1": 2}], session=session) + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=2, + keys=None, + ) ) - ) - session.flush() - - ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) - ti.refresh_from_task(mapped) - ti.map_index = map_index - assert isinstance(ti.task, MappedOperator) - mapped.render_template_fields(context=ti.get_template_context(session=session)) - assert isinstance(ti.task, MockOperator) - assert ti.task.arg1 == expected - assert ti.task.arg2 == "a" + session.flush() + + ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) + ti.refresh_from_task(mapped) + ti.map_index = map_index + assert isinstance(ti.task, MappedOperator) + mapped.render_template_fields(context=ti.get_template_context(session=session)) + assert isinstance(ti.task, MockOperator) + assert ti.task.arg1 == expected + assert ti.task.arg2 == "a" def test_xcomarg_property_of_mapped_operator(dag_maker): diff --git a/tests/models/test_renderedtifields.py b/tests/models/test_renderedtifields.py index 7a312a4c67c38..12fc108ce0124 100644 --- a/tests/models/test_renderedtifields.py +++ b/tests/models/test_renderedtifields.py @@ -29,6 +29,7 @@ from airflow.models import Variable from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF from airflow.operators.bash import BashOperator +from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.timezone import datetime from tests.test_utils.asserts import assert_queries_count from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_rendered_ti_fields @@ -153,44 +154,46 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field, da ], ) def test_delete_old_records( - self, rtif_num, num_to_keep, remaining_rtifs, 
expected_query_count, dag_maker + self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count, dag_maker, session ): """ Test that old records are deleted from rendered_task_instance_fields table for a given task_id and dag_id. """ - session = settings.Session() - with dag_maker("test_delete_old_records") as dag: - task = BashOperator(task_id="test", bash_command="echo {{ ds }}") - rtif_list = [] - for num in range(rtif_num): - dr = dag_maker.create_dagrun(run_id=str(num), execution_date=dag.start_date + timedelta(days=num)) - ti = dr.task_instances[0] - ti.task = task - rtif_list.append(RTIF(ti)) - - session.add_all(rtif_list) - session.flush() - - result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() - - for rtif in rtif_list: - assert rtif in result - - assert rtif_num == len(result) - - # Verify old records are deleted and only 'num_to_keep' records are kept - # For other DBs,an extra query is fired in RenderedTaskInstanceFields.delete_old_records - expected_query_count_based_on_db = ( - expected_query_count + 1 - if session.bind.dialect.name == "mssql" and expected_query_count != 0 - else expected_query_count - ) + with set_current_task_instance_session(session=session): + with dag_maker("test_delete_old_records") as dag: + task = BashOperator(task_id="test", bash_command="echo {{ ds }}") + rtif_list = [] + for num in range(rtif_num): + dr = dag_maker.create_dagrun( + run_id=str(num), execution_date=dag.start_date + timedelta(days=num) + ) + ti = dr.task_instances[0] + ti.task = task + rtif_list.append(RTIF(ti)) + + session.add_all(rtif_list) + session.flush() + + result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() + + for rtif in rtif_list: + assert rtif in result + + assert rtif_num == len(result) + + # Verify old records are deleted and only 'num_to_keep' records are kept + # For other DBs,an extra query is fired in RenderedTaskInstanceFields.delete_old_records + expected_query_count_based_on_db = ( + expected_query_count + 1 + if session.bind.dialect.name == "mssql" and expected_query_count != 0 + else expected_query_count + ) - with assert_queries_count(expected_query_count_based_on_db): - RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep) - result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() - assert remaining_rtifs == len(result) + with assert_queries_count(expected_query_count_based_on_db): + RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep) + result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() + assert remaining_rtifs == len(result) @pytest.mark.parametrize( "num_runs, num_to_keep, remaining_rtifs, expected_query_count", @@ -207,40 +210,41 @@ def test_delete_old_records_mapped( Test that old records are deleted from rendered_task_instance_fields table for a given task_id and dag_id with mapped tasks. 
""" - with dag_maker("test_delete_old_records", session=session) as dag: - mapped = BashOperator.partial(task_id="mapped").expand(bash_command=["a", "b"]) - for num in range(num_runs): - dr = dag_maker.create_dagrun( - run_id=f"run_{num}", execution_date=dag.start_date + timedelta(days=num) + with set_current_task_instance_session(session=session): + with dag_maker("test_delete_old_records", session=session) as dag: + mapped = BashOperator.partial(task_id="mapped").expand(bash_command=["a", "b"]) + for num in range(num_runs): + dr = dag_maker.create_dagrun( + run_id=f"run_{num}", execution_date=dag.start_date + timedelta(days=num) + ) + + mapped.expand_mapped_task(dr.run_id, session=dag_maker.session) + session.refresh(dr) + for ti in dr.task_instances: + ti.task = dag.get_task(ti.task_id) + session.add(RTIF(ti)) + session.flush() + + result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all() + assert len(result) == num_runs * 2 + + # Verify old records are deleted and only 'num_to_keep' records are kept + # For other DBs,an extra query is fired in RenderedTaskInstanceFields.delete_old_records + expected_query_count_based_on_db = ( + expected_query_count + 1 + if session.bind.dialect.name == "mssql" and expected_query_count != 0 + else expected_query_count ) - mapped.expand_mapped_task(dr.run_id, session=dag_maker.session) - session.refresh(dr) - for ti in dr.task_instances: - ti.task = dag.get_task(ti.task_id) - session.add(RTIF(ti)) - session.flush() - - result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all() - assert len(result) == num_runs * 2 - - # Verify old records are deleted and only 'num_to_keep' records are kept - # For other DBs,an extra query is fired in RenderedTaskInstanceFields.delete_old_records - expected_query_count_based_on_db = ( - expected_query_count + 1 - if session.bind.dialect.name == "mssql" and expected_query_count != 0 - else expected_query_count - ) - - with assert_queries_count(expected_query_count_based_on_db): - RTIF.delete_old_records( - task_id=mapped.task_id, dag_id=dr.dag_id, num_to_keep=num_to_keep, session=session - ) - result = session.query(RTIF).filter_by(dag_id=dag.dag_id, task_id=mapped.task_id).all() - rtif_num_runs = Counter(rtif.run_id for rtif in result) - assert len(rtif_num_runs) == remaining_rtifs - # Check that we have _all_ the data for each row - assert len(result) == remaining_rtifs * 2 + with assert_queries_count(expected_query_count_based_on_db): + RTIF.delete_old_records( + task_id=mapped.task_id, dag_id=dr.dag_id, num_to_keep=num_to_keep, session=session + ) + result = session.query(RTIF).filter_by(dag_id=dag.dag_id, task_id=mapped.task_id).all() + rtif_num_runs = Counter(rtif.run_id for rtif in result) + assert len(rtif_num_runs) == remaining_rtifs + # Check that we have _all_ the data for each row + assert len(result) == remaining_rtifs * 2 def test_write(self, dag_maker): """ diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py index 76d52f4769520..5003010297ad1 100644 --- a/tests/models/test_xcom_arg_map.py +++ b/tests/models/test_xcom_arg_map.py @@ -41,7 +41,7 @@ def pull(value): # The function passed to "map" is *NOT* a task. assert set(dag.task_dict) == {"push", "pull"} - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) # Run "push". 
decision = dr.task_instance_scheduling_decisions(session=session) @@ -79,7 +79,7 @@ def c_to_none(v): pull.expand(value=push().map(c_to_none)) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) # Run "push". decision = dr.task_instance_scheduling_decisions(session=session) @@ -113,7 +113,7 @@ def c_to_none(v): pull.expand_kwargs(push().map(c_to_none)) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) # Run "push". decision = dr.task_instance_scheduling_decisions(session=session) @@ -158,7 +158,7 @@ def does_not_work_with_c(v): pull.expand_kwargs(push().map(does_not_work_with_c)) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) # The "push" task should not fail. decision = dr.task_instance_scheduling_decisions(session=session) @@ -211,7 +211,7 @@ def skip_c(v): collect(value=forward.expand_kwargs(push().map(skip_c))) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) # Run "push". decision = dr.task_instance_scheduling_decisions(session=session) @@ -246,7 +246,7 @@ def pull(value): converted = push().map(lambda v: v * 2).map(lambda v: {"value": v}) pull.expand_kwargs(converted) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) # Run "push". decision = dr.task_instance_scheduling_decisions(session=session) @@ -289,7 +289,7 @@ def convert_zipped(zipped): pull.expand(value=combined.map(convert_zipped)) - dr = dag_maker.create_dagrun() + dr = dag_maker.create_dagrun(session=session) # Run "push_letters" and "push_numbers". decision = dr.task_instance_scheduling_decisions(session=session)
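
The same pattern can be used outside of _run_raw_task, as the task_render change above now does: open a session, publish it through the new context manager, and let code deeper in the call stack (such as MappedOperator.render_template_fields) pick it up via get_current_task_instance_session(). A minimal usage sketch follows; the `ti` name is illustrative only (an already-fetched TaskInstance) and is not part of this patch.

    from airflow.utils.session import create_session
    from airflow.utils.task_instance_session import (
        get_current_task_instance_session,
        set_current_task_instance_session,
    )

    with create_session() as session, set_current_task_instance_session(session=session):
        # Code called from here reuses the stored session instead of opening a
        # second one, e.g. mapped-operator template rendering.
        ti.render_templates()
        assert get_current_task_instance_session() is session

    # On exit the stored session is cleared again; entering
    # set_current_task_instance_session twice in the same call stack raises
    # RuntimeError by design.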