Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions airflow/api/common/trigger_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,19 @@
import json
from typing import TYPE_CHECKING

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import DagNotFound, DagRunAlreadyExists
from airflow.models import DagBag, DagModel, DagRun
from airflow.utils import timezone
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType

if TYPE_CHECKING:
from datetime import datetime

from sqlalchemy.orm.session import Session


def _trigger_dag(
dag_id: str,
Expand Down Expand Up @@ -103,12 +107,15 @@ def _trigger_dag(
return dag_runs


@internal_api_call
@provide_session
def trigger_dag(
dag_id: str,
run_id: str | None = None,
conf: dict | str | None = None,
execution_date: datetime | None = None,
replace_microseconds: bool = True,
session: Session = NEW_SESSION,
) -> DagRun | None:
"""
Triggers execution of DAG specified by dag_id.
Expand All @@ -118,6 +125,7 @@ def trigger_dag(
:param conf: configuration
:param execution_date: date of execution
:param replace_microseconds: whether microseconds should be zeroed
:param session: Unused. Only added in compatibility with database isolation mode
:return: first dag run triggered - even if more than one Dag Runs were triggered or None
"""
dag_model = DagModel.get_current(dag_id)
Expand Down
2 changes: 2 additions & 0 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@

@functools.lru_cache
def initialize_method_map() -> dict[str, Callable]:
from airflow.api.common.trigger_dag import trigger_dag
from airflow.cli.commands.task_command import _get_ti_db_access
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.dag_processing.processor import DagFileProcessor
Expand Down Expand Up @@ -92,6 +93,7 @@ def initialize_method_map() -> dict[str, Callable]:
_add_log,
_xcom_pull,
_record_task_map_for_downstreams,
trigger_dag,
DagCode.remove_deleted_code,
DagModel.deactivate_deleted_dags,
DagModel.get_paused_dag_ids,
Expand Down
22 changes: 22 additions & 0 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,28 @@ def __init__(self, dag_run: DagRun, execution_date: datetime.datetime, run_id: s
f"A DAG Run already exists for DAG {dag_run.dag_id} at {execution_date} with run id {run_id}"
)
self.dag_run = dag_run
self.execution_date = execution_date
self.run_id = run_id

def serialize(self):
cls = self.__class__
# Note the DagRun object will be detached here and fails serialization, we need to create a new one
from airflow.models import DagRun

dag_run = DagRun(
state=self.dag_run.state,
dag_id=self.dag_run.dag_id,
run_id=self.dag_run.run_id,
external_trigger=self.dag_run.external_trigger,
run_type=self.dag_run.run_type,
execution_date=self.dag_run.execution_date,
)
dag_run.id = self.dag_run.id
return (
f"{cls.__module__}.{cls.__name__}",
(),
{"dag_run": dag_run, "execution_date": self.execution_date, "run_id": self.run_id},
)


class DagFileExists(AirflowBadRequest):
Expand Down
4 changes: 4 additions & 0 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
TaskInstanceKey,
clear_task_instances,
)
from airflow.models.tasklog import LogTemplate
from airflow.secrets.local_filesystem import LocalFilesystemBackend
from airflow.security import permissions
from airflow.settings import json
Expand Down Expand Up @@ -338,6 +339,9 @@ def _create_orm_dagrun(
creating_job_id=creating_job_id,
data_interval=data_interval,
)
# Load defaults into the following two fields to ensure result can be serialized detached
run.log_template_id = int(session.scalar(select(func.max(LogTemplate.__table__.c.id))))
run.consumed_dataset_events = []
session.add(run)
session.flush()
run.dag = dag
Expand Down
14 changes: 14 additions & 0 deletions airflow/operators/trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from sqlalchemy.orm.exc import NoResultFound

from airflow.api.common.trigger_dag import trigger_dag
from airflow.api_internal.internal_api_call import InternalApiConfig
from airflow.configuration import conf
from airflow.exceptions import (
AirflowException,
Expand Down Expand Up @@ -83,6 +84,8 @@ class TriggerDagRunOperator(BaseOperator):
"""
Triggers a DAG run for a specified DAG ID.

Note that if database isolation mode is enabled, not all features are supported.

:param trigger_dag_id: The ``dag_id`` of the DAG to trigger (templated).
:param trigger_run_id: The run ID to use for the triggered DAG run (templated).
If not provided, a run ID will be automatically generated.
Expand Down Expand Up @@ -174,6 +177,14 @@ def __init__(
self.logical_date = logical_date

def execute(self, context: Context):
if InternalApiConfig.get_use_internal_api():
if self.reset_dag_run:
raise AirflowException("Parameter reset_dag_run=True is broken with Database Isolation Mode.")
if self.wait_for_completion:
raise AirflowException(
"Parameter wait_for_completion=True is broken with Database Isolation Mode."
)

if isinstance(self.logical_date, datetime.datetime):
parsed_logical_date = self.logical_date
elif isinstance(self.logical_date, str):
Expand Down Expand Up @@ -210,6 +221,7 @@ def execute(self, context: Context):
if dag_model is None:
raise DagNotFound(f"Dag id {self.trigger_dag_id} not found in DagModel")

# Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
dag = dag_bag.get_dag(self.trigger_dag_id)
dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date)
Expand Down Expand Up @@ -250,6 +262,7 @@ def execute(self, context: Context):
)
time.sleep(self.poke_interval)

# Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
dag_run.refresh_from_db()
state = dag_run.state
if state in self.failed_states:
Expand All @@ -263,6 +276,7 @@ def execute_complete(self, context: Context, session: Session, event: tuple[str,
# This logical_date is parsed from the return trigger event
provided_logical_date = event[1]["execution_dates"][0]
try:
# Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
dag_run = session.execute(
select(DagRun).where(
DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_logical_date
Expand Down
7 changes: 6 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,7 +1447,12 @@ def get_custom_dep() -> list[DagDependency]:

@classmethod
def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
if var is not None and op.has_dag() and attrname.endswith("_date"):
if (
var is not None
and op.has_dag()
and op.dag.__class__ is not AttributeRemoved
and attrname.endswith("_date")
):
# If this date is the same as the matching field in the dag, then
# don't store it again at the task level.
dag_date = getattr(op.dag, attrname, None)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3293,7 +3293,7 @@ def test_count_number_queries(self, tasks_count):
dag = DAG("test_dagrun_query_count", start_date=DEFAULT_DATE)
for i in range(tasks_count):
EmptyOperator(task_id=f"dummy_task_{i}", owner="test", dag=dag)
with assert_queries_count(2):
with assert_queries_count(3):
dag.create_dagrun(
run_id="test_dagrun_query_count",
state=State.RUNNING,
Expand Down
Loading