diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 07a06f683f6fb..dd36d5d31e3e1 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -405,6 +405,11 @@ def __init__(
 
         self.dag = dag or DagContext.get_current_dag()
 
+        # The subdag parameter is only set for SubDagOperator.
+        # Default it to None here because other operators do not have that field.
+        from airflow.models.dag import DAG
+        self.subdag: Optional[DAG] = None
+
         self._log = logging.getLogger("airflow.task.operators")
 
         # Lineage
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 17bd85920d846..58a2a529a70ce 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -244,7 +244,7 @@ def __init__(
         self._description = description
         # set file location to caller source path
         self.fileloc = sys._getframe().f_back.f_code.co_filename
-        self.task_dict = dict()  # type: Dict[str, BaseOperator]
+        self.task_dict: Dict[str, BaseOperator] = dict()
 
         # set timezone from start_date
         if start_date and start_date.tzinfo:
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 7df03d9de7171..9e0f946ea83a1 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -159,7 +159,7 @@ def get_dag(self, dag_id, from_file_only=False):
         # Needs to load from file for a store_serialized_dags dagbag.
         enforce_from_file = False
         if self.store_serialized_dags and dag is not None:
-            from airflow.serialization.serialized_dag import SerializedDAG
+            from airflow.serialization.serialized_objects import SerializedDAG
             enforce_from_file = isinstance(dag, SerializedDAG)
 
         # If the dag corresponding to root_dag_id is absent or expired
diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index f2d8426f275a5..dfc42338fb723 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -29,7 +29,7 @@
 
 from airflow import DAG
 from airflow.models.base import ID_LEN, Base
-from airflow.serialization.serialized_dag import SerializedDAG
+from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.settings import json
 from airflow.utils import db, timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
diff --git a/airflow/serialization/__init__.py b/airflow/serialization/__init__.py
new file mode 100644
index 0000000000000..c6d1a147f91f6
--- /dev/null
+++ b/airflow/serialization/__init__.py
@@ -0,0 +1,19 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""DAG serialization."""
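One behavioural note on the `baseoperator.py` hunk above: every operator now initialises `subdag` to `None`, not just `SubDagOperator`, so callers can probe the attribute without `hasattr()` guards. A minimal sketch of the effect, assuming a vanilla DAG (the ids are illustrative):

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator

with DAG(dag_id="subdag_default_demo", start_date=datetime(2019, 1, 1)) as dag:
    task = DummyOperator(task_id="noop")

# BaseOperator.__init__ now sets the attribute unconditionally,
# so this holds for any operator, not only SubDagOperator.
assert task.subdag is None
```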
+ +"""DAG serialization.""" diff --git a/airflow/serialization/serialized_baseoperator.py b/airflow/serialization/serialized_baseoperator.py deleted file mode 100644 index a0464483382e9..0000000000000 --- a/airflow/serialization/serialized_baseoperator.py +++ /dev/null @@ -1,126 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""Operator serialization with JSON.""" -from inspect import signature - -from airflow.models.baseoperator import BaseOperator -from airflow.serialization.base_serialization import BaseSerialization - - -class SerializedBaseOperator(BaseOperator, BaseSerialization): - """A JSON serializable representation of operator. - - All operators are casted to SerializedBaseOperator after deserialization. - Class specific attributes used by UI are move to object attributes. - """ - - _decorated_fields = {'executor_config', } - - _CONSTRUCTOR_PARAMS = { - k: v for k, v in signature(BaseOperator).parameters.items() - if v.default is not v.empty and v.default is not None - } - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # task_type is used by UI to display the correct class type, because UI only - # receives BaseOperator from deserialized DAGs. - self._task_type = 'BaseOperator' - # Move class attributes into object attributes. - self.ui_color = BaseOperator.ui_color - self.ui_fgcolor = BaseOperator.ui_fgcolor - self.template_fields = BaseOperator.template_fields - # subdag parameter is only set for SubDagOperator. - # Setting it to None by default as other Operators do not have that field - self.subdag = None - self.operator_extra_links = BaseOperator.operator_extra_links - - @property - def task_type(self) -> str: - # Overwrites task_type of BaseOperator to use _task_type instead of - # __class__.__name__. - return self._task_type - - @task_type.setter - def task_type(self, task_type: str): - self._task_type = task_type - - @classmethod - def serialize_operator(cls, op: BaseOperator) -> dict: - """Serializes operator into a JSON object. - """ - serialize_op = cls.serialize_to_json(op, cls._decorated_fields) - serialize_op['_task_type'] = op.__class__.__name__ - serialize_op['_task_module'] = op.__class__.__module__ - return serialize_op - - @classmethod - def deserialize_operator(cls, encoded_op: dict) -> BaseOperator: - """Deserializes an operator from a JSON object. 
- """ - from airflow.serialization.serialized_dag import SerializedDAG - from airflow.plugins_manager import operator_extra_links - - op = SerializedBaseOperator(task_id=encoded_op['task_id']) - - # Extra Operator Links - op_extra_links_from_plugin = {} - - for ope in operator_extra_links: - for operator in ope.operators: - if operator.__name__ == encoded_op["_task_type"] and \ - operator.__module__ == encoded_op["_task_module"]: - op_extra_links_from_plugin.update({ope.name: ope}) - - setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values())) - - for k, v in encoded_op.items(): - - if k == "_downstream_task_ids": - v = set(v) - elif k == "subdag": - v = SerializedDAG.deserialize_dag(v) - elif k in {"retry_delay", "execution_timeout"}: - v = cls._deserialize_timedelta(v) - elif k.endswith("_date"): - v = cls._deserialize_datetime(v) - elif k in cls._decorated_fields or k not in op.get_serialized_fields(): - v = cls._deserialize(v) - # else use v as it is - - setattr(op, k, v) - - for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys(): - setattr(op, k, None) - - return op - - @classmethod - def _is_excluded(cls, var, attrname, op): - if var is not None and op.has_dag() and attrname.endswith("_date"): - # If this date is the same as the matching field in the dag, then - # don't store it again at the task level. - dag_date = getattr(op.dag, attrname, None) - if var is dag_date or var == dag_date: - return True - if attrname in {"executor_config", "params"} and not var: - # Don't store empty executor config or params dicts. - return True - return super()._is_excluded(var, attrname, op) diff --git a/airflow/serialization/serialized_dag.py b/airflow/serialization/serialized_dag.py deleted file mode 100644 index 0176226981ad6..0000000000000 --- a/airflow/serialization/serialized_dag.py +++ /dev/null @@ -1,140 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -"""DAG serialization with JSON.""" -from inspect import signature - -from airflow.models.dag import DAG -from airflow.serialization.base_serialization import BaseSerialization -from airflow.serialization.json_schema import load_dag_schema - - -class SerializedDAG(DAG, BaseSerialization): - """ - A JSON serializable representation of DAG. - - A stringified DAG can only be used in the scope of scheduler and webserver, because fields - that are not serializable, such as functions and customer defined classes, are casted to - strings. - - Compared with SimpleDAG: SerializedDAG contains all information for webserver. - Compared with DagPickle: DagPickle contains all information for worker, but some DAGs are - not pickle-able. SerializedDAG works for all DAGs. 
- """ - - _decorated_fields = {'schedule_interval', 'default_args'} - - @staticmethod - def __get_constructor_defaults(): # pylint: disable=no-method-argument - param_to_attr = { - 'concurrency': '_concurrency', - 'description': '_description', - 'default_view': '_default_view', - 'access_control': '_access_control', - } - return { - param_to_attr.get(k, k): v for k, v in signature(DAG).parameters.items() - if v.default is not v.empty and v.default is not None - } - _CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__() # type: ignore - del __get_constructor_defaults - - _json_schema = load_dag_schema() - - @classmethod - def serialize_dag(cls, dag: DAG) -> dict: - """Serializes a DAG into a JSON object. - """ - serialize_dag = cls.serialize_to_json(dag, cls._decorated_fields) - - serialize_dag["tasks"] = [cls._serialize(task) for _, task in dag.task_dict.items()] - return serialize_dag - - @classmethod - def deserialize_dag(cls, encoded_dag: dict) -> "SerializedDAG": - """Deserializes a DAG from a JSON object. - """ - from airflow.serialization.serialized_baseoperator import SerializedBaseOperator - - dag = SerializedDAG(dag_id=encoded_dag['_dag_id']) - - for k, v in encoded_dag.items(): - if k == "_downstream_task_ids": - v = set(v) - elif k == "tasks": - v = { - task["task_id"]: SerializedBaseOperator.deserialize_operator(task) for task in v - } - k = "task_dict" - elif k == "timezone": - v = cls._deserialize_timezone(v) - elif k in {"retry_delay", "execution_timeout"}: - v = cls._deserialize_timedelta(v) - elif k.endswith("_date"): - v = cls._deserialize_datetime(v) - elif k in cls._decorated_fields: - v = cls._deserialize(v) - # else use v as it is - - setattr(dag, k, v) - - keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys() - for k in keys_to_set_none: - setattr(dag, k, None) - - setattr(dag, 'full_filepath', dag.fileloc) - for task in dag.task_dict.values(): - task.dag = dag - serializable_task: SerializedBaseOperator = task - - for date_attr in ["start_date", "end_date"]: - if getattr(serializable_task, date_attr) is None: - setattr(serializable_task, date_attr, getattr(dag, date_attr)) - - if serializable_task.subdag is not None: - setattr(serializable_task.subdag, 'parent_dag', dag) - serializable_task.subdag.is_subdag = True - - for task_id in serializable_task.downstream_task_ids: - # Bypass set_upstream etc here - it does more than we want - # noinspection PyProtectedMember - dag.task_dict[task_id]._upstream_task_ids.add(task_id) # pylint: disable=protected-access - - return dag - - @classmethod - def to_dict(cls, var) -> dict: - """Stringifies DAGs and operators contained by var and returns a dict of var. - """ - json_dict = { - "__version": cls.SERIALIZER_VERSION, - "dag": cls.serialize_dag(var) - } - - # Validate Serialized DAG with Json Schema. 
diff --git a/airflow/serialization/base_serialization.py b/airflow/serialization/serialized_objects.py
similarity index 55%
rename from airflow/serialization/base_serialization.py
rename to airflow/serialization/serialized_objects.py
index 56850fc6ea671..5f07021288ef7 100644
--- a/airflow/serialization/base_serialization.py
+++ b/airflow/serialization/serialized_objects.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -17,32 +15,24 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""Utils for DAG serialization with JSON."""
-
+"""Serialized DAG and BaseOperator"""
 import datetime
 import enum
 import logging
-from inspect import Parameter
+from inspect import Parameter, signature
 from typing import Dict, Optional, Set, Union
 
 import pendulum
 from dateutil import relativedelta
 
-from airflow.exceptions import AirflowException
-from airflow.models import DAG
+from airflow import DAG, AirflowException, LoggingMixin
+from airflow.models import Connection
 from airflow.models.baseoperator import BaseOperator
-from airflow.models.connection import Connection
 from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
-from airflow.serialization.json_schema import Validator
+from airflow.serialization.json_schema import Validator, load_dag_schema
 from airflow.settings import json
-from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.www.utils import get_python_source
 
-LOG = LoggingMixin().log
-
-# Serialization failure returns 'failed'.
-FAILED = 'serialization_failed'
-
 
 class BaseSerialization:
     """BaseSerialization provides utils for serialization."""
@@ -58,9 +48,9 @@ class BaseSerialization:
     # FIXME: not needed if _included_fields of DAG and operator are customized.
     _excluded_types = (logging.Logger, Connection, type)
 
-    _json_schema = None  # type: Optional[Validator]
+    _json_schema: Optional[Validator] = None
 
-    _CONSTRUCTOR_PARAMS = {}  # type: Dict[str, Parameter]
+    _CONSTRUCTOR_PARAMS: Dict[str, Parameter] = {}
 
     SERIALIZER_VERSION = 1
 
@@ -154,8 +144,6 @@ def _serialize(cls, var):  # pylint: disable=too-many-return-statements
         (3) Operator has a special field CLASS to record the original class name
             for displaying in UI.
         """
-        from airflow.serialization.serialized_dag import SerializedDAG
-        from airflow.serialization.serialized_baseoperator import SerializedBaseOperator
         try:
             if cls._is_primitive(var):
                 # enum.IntEnum is an int instance, it causes json dumps error so we use its value.
@@ -207,8 +195,6 @@ def _serialize(cls, var):  # pylint: disable=too-many-return-statements
     @classmethod
     def _deserialize(cls, encoded_var):  # pylint: disable=too-many-return-statements
         """Helper function of depth first search for deserialization."""
-        from airflow.serialization.serialized_dag import SerializedDAG
-        from airflow.serialization.serialized_baseoperator import SerializedBaseOperator
         # JSON primitives (except for dict) are not encoded.
         if cls._is_primitive(encoded_var):
             return encoded_var
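For orientation between these hunks: `_serialize` and `_deserialize` exchange non-primitive values wrapped in a small type envelope. A sketch of that envelope, under the assumption that `Encoding.TYPE` and `Encoding.VAR` map to `"__type"` and `"__var"` (as in `airflow/serialization/enums.py`) and that datetimes travel as POSIX timestamps, which is what the `_deserialize_datetime` calls in the classes below expect:

```python
import datetime

# What a serialized datetime plausibly looks like on the wire (assumption, see above).
encoded = {"__type": "datetime", "__var": 1577836800.0}

# Primitive str/int/float/bool values skip the envelope entirely.
decoded = datetime.datetime.utcfromtimestamp(encoded["__var"])
print(decoded)  # 2020-01-01 00:00:00
```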
@@ -265,3 +251,218 @@ def _value_is_hardcoded_default(cls, attrname, value):
         if attrname in cls._CONSTRUCTOR_PARAMS and cls._CONSTRUCTOR_PARAMS[attrname].default is value:
             return True
         return False
+
+
+class SerializedBaseOperator(BaseOperator, BaseSerialization):
+    """A JSON serializable representation of operator.
+
+    All operators are cast to SerializedBaseOperator after deserialization.
+    Class-specific attributes used by the UI are moved to object attributes.
+    """
+
+    _decorated_fields = {'executor_config', }
+
+    _CONSTRUCTOR_PARAMS = {
+        k: v for k, v in signature(BaseOperator).parameters.items()
+        if v.default is not v.empty and v.default is not None
+    }
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # task_type is used by the UI to display the correct class type, because the UI
+        # only receives BaseOperator from deserialized DAGs.
+        self._task_type = 'BaseOperator'
+        # Move class attributes into object attributes.
+        self.ui_color = BaseOperator.ui_color
+        self.ui_fgcolor = BaseOperator.ui_fgcolor
+        self.template_fields = BaseOperator.template_fields
+        self.operator_extra_links = BaseOperator.operator_extra_links
+
+    @property
+    def task_type(self) -> str:
+        # Overwrites task_type of BaseOperator to use _task_type instead of
+        # __class__.__name__.
+        return self._task_type
+
+    @task_type.setter
+    def task_type(self, task_type: str):
+        self._task_type = task_type
+
+    @classmethod
+    def serialize_operator(cls, op: BaseOperator) -> dict:
+        """Serializes operator into a JSON object.
+        """
+        serialize_op = cls.serialize_to_json(op, cls._decorated_fields)
+        serialize_op['_task_type'] = op.__class__.__name__
+        serialize_op['_task_module'] = op.__class__.__module__
+        return serialize_op
+
+    @classmethod
+    def deserialize_operator(cls, encoded_op: dict) -> BaseOperator:
+        """Deserializes an operator from a JSON object.
+        """
+        from airflow.plugins_manager import operator_extra_links
+
+        op = SerializedBaseOperator(task_id=encoded_op['task_id'])
+
+        # Extra operator links registered by plugins, matched on task type and module.
+        op_extra_links_from_plugin = {}
+
+        for ope in operator_extra_links:
+            for operator in ope.operators:
+                if operator.__name__ == encoded_op["_task_type"] and \
+                        operator.__module__ == encoded_op["_task_module"]:
+                    op_extra_links_from_plugin.update({ope.name: ope})
+
+        setattr(op, "operator_extra_links", list(op_extra_links_from_plugin.values()))
+
+        for k, v in encoded_op.items():
+
+            if k == "_downstream_task_ids":
+                v = set(v)
+            elif k == "subdag":
+                v = SerializedDAG.deserialize_dag(v)
+            elif k in {"retry_delay", "execution_timeout"}:
+                v = cls._deserialize_timedelta(v)
+            elif k.endswith("_date"):
+                v = cls._deserialize_datetime(v)
+            elif k in cls._decorated_fields or k not in op.get_serialized_fields():
+                v = cls._deserialize(v)
+            # else use v as it is
+
+            setattr(op, k, v)
+
+        for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys():
+            setattr(op, k, None)
+
+        return op
+
+    @classmethod
+    def _is_excluded(cls, var, attrname, op):
+        if var is not None and op.has_dag() and attrname.endswith("_date"):
+            # If this date is the same as the matching field in the dag, then
+            # don't store it again at the task level.
+            dag_date = getattr(op.dag, attrname, None)
+            if var is dag_date or var == dag_date:
+                return True
+        if attrname in {"executor_config", "params"} and not var:
+            # Don't store empty executor config or params dicts.
+            return True
+        return super()._is_excluded(var, attrname, op)
+
+
+class SerializedDAG(DAG, BaseSerialization):
+    """
+    A JSON serializable representation of DAG.
+
+    A stringified DAG can only be used in the scope of scheduler and webserver, because fields
+    that are not serializable, such as functions and user-defined classes, are cast to
+    strings.
+
+    Compared with SimpleDAG: SerializedDAG contains all information for webserver.
+    Compared with DagPickle: DagPickle contains all information for worker, but some DAGs are
+    not pickle-able. SerializedDAG works for all DAGs.
+    """
+
+    _decorated_fields = {'schedule_interval', 'default_args'}
+
+    @staticmethod
+    def __get_constructor_defaults():  # pylint: disable=no-method-argument
+        param_to_attr = {
+            'concurrency': '_concurrency',
+            'description': '_description',
+            'default_view': '_default_view',
+            'access_control': '_access_control',
+        }
+        return {
+            param_to_attr.get(k, k): v for k, v in signature(DAG).parameters.items()
+            if v.default is not v.empty and v.default is not None
+        }
+    _CONSTRUCTOR_PARAMS = __get_constructor_defaults.__func__()  # type: ignore
+    del __get_constructor_defaults
+
+    _json_schema = load_dag_schema()
+
+    @classmethod
+    def serialize_dag(cls, dag: DAG) -> dict:
+        """Serializes a DAG into a JSON object.
+        """
+        serialize_dag = cls.serialize_to_json(dag, cls._decorated_fields)
+
+        serialize_dag["tasks"] = [cls._serialize(task) for _, task in dag.task_dict.items()]
+        return serialize_dag
+
+    @classmethod
+    def deserialize_dag(cls, encoded_dag: dict) -> 'SerializedDAG':
+        """Deserializes a DAG from a JSON object.
+        """
+        dag = SerializedDAG(dag_id=encoded_dag['_dag_id'])
+
+        for k, v in encoded_dag.items():
+            if k == "_downstream_task_ids":
+                v = set(v)
+            elif k == "tasks":
+                v = {
+                    task["task_id"]: SerializedBaseOperator.deserialize_operator(task) for task in v
+                }
+                k = "task_dict"
+            elif k == "timezone":
+                v = cls._deserialize_timezone(v)
+            elif k in {"retry_delay", "execution_timeout"}:
+                v = cls._deserialize_timedelta(v)
+            elif k.endswith("_date"):
+                v = cls._deserialize_datetime(v)
+            elif k in cls._decorated_fields:
+                v = cls._deserialize(v)
+            # else use v as it is
+
+            setattr(dag, k, v)
+
+        keys_to_set_none = dag.get_serialized_fields() - encoded_dag.keys() - cls._CONSTRUCTOR_PARAMS.keys()
+        for k in keys_to_set_none:
+            setattr(dag, k, None)
+
+        setattr(dag, 'full_filepath', dag.fileloc)
+        for task in dag.task_dict.values():
+            task.dag = dag
+            serializable_task: BaseOperator = task
+
+            for date_attr in ["start_date", "end_date"]:
+                if getattr(serializable_task, date_attr) is None:
+                    setattr(serializable_task, date_attr, getattr(dag, date_attr))
+
+            if serializable_task.subdag is not None:
+                setattr(serializable_task.subdag, 'parent_dag', dag)
+                serializable_task.subdag.is_subdag = True
+
+            for task_id in serializable_task.downstream_task_ids:
+                # Bypass set_upstream etc here - it does more than we want
+                # noinspection PyProtectedMember
+                dag.task_dict[task_id]._upstream_task_ids.add(task.task_id)  # pylint: disable=protected-access
+
+        return dag
+
+    @classmethod
+    def to_dict(cls, var) -> dict:
+        """Stringifies DAGs and operators contained by var and returns a dict of var.
+        """
+        json_dict = {
+            "__version": cls.SERIALIZER_VERSION,
+            "dag": cls.serialize_dag(var)
+        }
+
+        # Validate the serialized DAG against the JSON schema. Raises on mismatch.
+        cls.validate_schema(json_dict)
+        return json_dict
+
+    @classmethod
+    def from_dict(cls, serialized_obj: dict) -> 'SerializedDAG':
+        """Deserializes a python dict into the DAG and operators it contains."""
+        ver = serialized_obj.get('__version', '')
+        if ver != cls.SERIALIZER_VERSION:
+            raise ValueError("Unsure how to deserialize version {!r}".format(ver))
+        return cls.deserialize_dag(serialized_obj['dag'])
+
+
+LOG = LoggingMixin().log
+FAILED = 'serialization_failed'
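With the consolidated module in place, a minimal round trip through its API looks like this. A sketch only: the DAG and operator are illustrative, and `to_dict` validates the payload against the JSON schema loaded by `load_dag_schema`:

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.serialization.serialized_objects import SerializedDAG

with DAG(dag_id="roundtrip_demo", start_date=datetime(2019, 1, 1)) as dag:
    DummyOperator(task_id="noop")

payload = SerializedDAG.to_dict(dag)         # JSON-safe dict, schema-validated
restored = SerializedDAG.from_dict(payload)  # rebuilt SerializedDAG instance

assert restored.dag_id == dag.dag_id
assert "noop" in restored.task_dict
```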
diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py
index 009e91960c922..e58d7a48075f4 100644
--- a/tests/models/test_serialized_dag.py
+++ b/tests/models/test_serialized_dag.py
@@ -24,7 +24,7 @@
 
 from airflow import example_dags as example_dags_module
 from airflow.models import DagBag
 from airflow.models.serialized_dag import SerializedDagModel as SDM
-from airflow.serialization.serialized_dag import SerializedDAG
+from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.utils import db
 
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index df005a3a18b0f..89d4fb18fcbd6 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -35,8 +35,7 @@
 from airflow.models.baseoperator import BaseOperator
 from airflow.operators.bash_operator import BashOperator
 from airflow.operators.subdag_operator import SubDagOperator
-from airflow.serialization.serialized_baseoperator import SerializedBaseOperator
-from airflow.serialization.serialized_dag import SerializedDAG
+from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
 from airflow.utils.tests import CustomBaseOperator, GoogleLink
 
 serialized_simple_dag_ground_truth = {