Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions airflow-core/src/airflow/api/common/mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,13 @@ def set_state(
if not tasks:
return []

task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks}
task_dags = {
(dag.dag_id if dag else None): dag
for dag in (task[0].dag if isinstance(task, tuple) else task.dag for task in tasks)
}
if len(task_dags) > 1:
raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
dag = next(iter(task_dags))
dag = next(iter(task_dags.values()))
if dag is None:
raise ValueError("Received tasks with no DAG")
if not run_id:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def _get_upstream_map_indexes(
if (upstream_mapped_group := upstream_task.get_closest_mapped_task_group()) is None:
# regular tasks or non-mapped task groups
map_indexes = None
elif task.get_closest_mapped_task_group() == upstream_mapped_group:
elif task.get_closest_mapped_task_group() is upstream_mapped_group:
# tasks in the same mapped task group hierarchy
map_indexes = ti.map_index
else:
Expand Down
18 changes: 11 additions & 7 deletions airflow-core/src/airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
from airflow.models.serialized_dag import SerializedDagModel
from airflow.sdk.definitions.dag import DAG, _run_task
from airflow.sdk.definitions.param import ParamsDict
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.serialized_objects import DagSerialization
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
from airflow.utils import cli as cli_utils
Expand Down Expand Up @@ -384,7 +385,7 @@ def task_test(args, dag: DAG | None = None) -> None:

if dag:
sdk_dag = dag
scheduler_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
scheduler_dag = DagSerialization.from_dict(DagSerialization.to_dict(dag))
else:
sdk_dag = get_bagged_dag(args.bundle_name, args.dag_id)
scheduler_dag = get_db_dag(args.bundle_name, args.dag_id)
Expand Down Expand Up @@ -429,19 +430,22 @@ def task_test(args, dag: DAG | None = None) -> None:
@providers_configuration_loaded
def task_render(args, dag: DAG | None = None) -> None:
"""Render and displays templated fields for a given task."""
if not dag:
dag = get_bagged_dag(args.bundle_name, args.dag_id)
serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
if dag:
sdk_dag = dag
scheduler_dag = DagSerialization.from_dict(DagSerialization.to_dict(dag))
else:
sdk_dag = get_bagged_dag(args.bundle_name, args.dag_id)
scheduler_dag = get_db_dag(args.bundle_name, args.dag_id)
ti, _ = _get_ti(
serialized_dag.get_task(task_id=args.task_id),
scheduler_dag.get_task(task_id=args.task_id),
args.map_index,
logical_date_or_run_id=args.logical_date_or_run_id,
create_if_necessary="memory",
)

with create_session() as session:
context = ti.get_template_context(session=session)
task = dag.get_task(args.task_id)
task = sdk_dag.get_task(args.task_id)
# TODO (GH-52141): After sdk separation, ti.get_template_context() would
# contain serialized operators, but we need the real operators for
# rendering. This does not make sense and eventually we should rewrite
Expand Down
4 changes: 2 additions & 2 deletions airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
)
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, _send_error_email_notification
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
from airflow.serialization.serialized_objects import DagSerialization, LazyDeserializedDAG
from airflow.utils.file import iter_airflow_imports
from airflow.utils.state import TaskInstanceState

Expand Down Expand Up @@ -239,7 +239,7 @@ def _serialize_dags(
serialized_dags = []
for dag in bag.dags.values():
try:
data = SerializedDAG.to_dict(dag)
data = DagSerialization.to_dict(dag)
serialized_dags.append(LazyDeserializedDAG(data=data, last_loaded=dag.last_loaded))
except Exception:
log.exception("Failed to serialize DAG: %s", dag.fileloc)
Expand Down
9 changes: 6 additions & 3 deletions airflow-core/src/airflow/models/serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from airflow.models.dagrun import DagRun
from airflow.serialization.dag_dependency import DagDependency
from airflow.serialization.definitions.assets import SerializedAssetUniqueKey as UKey
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
from airflow.serialization.serialized_objects import DagSerialization
from airflow.settings import COMPRESS_SERIALIZED_DAGS, json
from airflow.utils.hashlib_wrapper import md5
from airflow.utils.session import NEW_SESSION, provide_session
Expand All @@ -56,6 +56,9 @@
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.elements import ColumnElement

from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.serialized_objects import LazyDeserializedDAG


log = logging.getLogger(__name__)

Expand Down Expand Up @@ -568,14 +571,14 @@ def data(self) -> dict | None:
@property
def dag(self) -> SerializedDAG:
"""The DAG deserialized from the ``data`` column."""
SerializedDAG._load_operator_extra_links = self.load_op_links
DagSerialization._load_operator_extra_links = self.load_op_links
if isinstance(self.data, dict):
data = self.data
elif isinstance(self.data, str):
data = json.loads(self.data)
else:
raise ValueError("invalid or missing serialized DAG data")
return SerializedDAG.from_dict(data)
return DagSerialization.from_dict(data)

@classmethod
@provide_session
Expand Down
8 changes: 5 additions & 3 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,9 @@
from airflow.models.dag import DagModel
from airflow.models.dagrun import DagRun
from airflow.models.mappedoperator import MappedOperator
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.utils.context import Context

Operator: TypeAlias = MappedOperator | SerializedBaseOperator
Expand Down Expand Up @@ -1480,7 +1481,8 @@ def run(
"""Run TaskInstance (only kept for tests)."""
# This method is only used in ti.run and dag.test and task.test.
# So doing the s10n/de-s10n dance to operate on the Serialized task for the scheduler dep check part.
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.serialized_objects import DagSerialization

original_task = self.task
if TYPE_CHECKING:
Expand All @@ -1489,7 +1491,7 @@ def run(

# We don't set up all tests well...
if not isinstance(original_task.dag, SerializedDAG):
serialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(original_task.dag))
serialized_dag = DagSerialization.from_dict(DagSerialization.to_dict(original_task.dag))
self.task = serialized_dag.get_task(original_task.task_id)

res = self.check_and_change_state_before_execution(
Expand Down
Loading
Loading