diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index 5cd50a3165b7e..765a94712ca0e 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -17,7 +17,7 @@ # under the License. import warnings -from typing import TYPE_CHECKING, Iterable, Union +from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union from airflow.models.taskinstance import TaskInstance from airflow.utils import timezone @@ -26,6 +26,7 @@ from airflow.utils.state import State if TYPE_CHECKING: + from pendulum import DateTime from sqlalchemy import Session from airflow.models import DagRun @@ -66,9 +67,9 @@ def _set_state_to_skipped(self, dag_run: "DagRun", tasks: "Iterable[BaseOperator def skip( self, dag_run: "DagRun", - execution_date: "timezone.DateTime", - tasks: "Iterable[BaseOperator]", - session: "Session" = None, + execution_date: "DateTime", + tasks: Sequence["BaseOperator"], + session: "Session", ): """ Sets tasks instances to skipped from the same dag run. @@ -114,11 +115,7 @@ def skip( session.commit() # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available. - try: - task_id = self.task_id - except AttributeError: - task_id = None - + task_id: Optional[str] = getattr(self, "task_id", None) if task_id is not None: from airflow.models.xcom import XCom diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index 99c2b9aca5b2b..4bb9689e7dda6 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -16,10 +16,11 @@ # specific language governing permissions and limitations # under the License. 
+import datetime import json import logging import pickle -from typing import Any, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, Union, cast, overload import pendulum from sqlalchemy import Column, LargeBinary, String @@ -79,14 +80,60 @@ def init_on_load(self): def __repr__(self): return f'' + @overload @classmethod - @provide_session - def set(cls, key, value, task_id, dag_id, execution_date=None, run_id=None, session=None): + def set( + cls, + key: str, + value: Any, + *, + dag_id: str, + task_id: str, + run_id: str, + session: Optional[Session] = None, + ) -> None: + """Store an XCom value. + + A deprecated form of this function accepts ``execution_date`` instead of + ``run_id``. The two arguments are mutually exclusive. + + :param key: Key to store the XCom. + :param value: XCom value to store. + :param dag_id: DAG ID. + :param task_id: Task ID. + :param run_id: DAG run ID for the task. + :param session: Database session. If not given, a new session will be + created for this function. + :type session: sqlalchemy.orm.session.Session """ - Store an XCom value. 
- :return: None - """ + @overload + @classmethod + def set( + cls, + key: str, + value: Any, + task_id: str, + dag_id: str, + execution_date: datetime.datetime, + session: Optional[Session] = None, + ) -> None: + """:sphinx-autoapi-skip:""" + + @classmethod + @provide_session + def set( + cls, + key: str, + value: Any, + task_id: str, + dag_id: str, + execution_date: Optional[datetime.datetime] = None, + session: Session = None, + *, + run_id: Optional[str] = None, + ) -> None: + """:sphinx-autoapi-skip:""" if not (execution_date is None) ^ (run_id is None): raise ValueError("Exactly one of execution_date or run_id must be passed") @@ -94,70 +141,95 @@ def set(cls, key, value, task_id, dag_id, execution_date=None, run_id=None, sess from airflow.models.dagrun import DagRun dag_run = session.query(DagRun).filter_by(dag_id=dag_id, run_id=run_id).one() - execution_date = dag_run.execution_date - value = XCom.serialize_value(value) - - # remove any duplicate XComs + # Remove duplicate XComs and insert a new one. session.query(cls).filter( - cls.key == key, cls.execution_date == execution_date, cls.task_id == task_id, cls.dag_id == dag_id + cls.key == key, + cls.execution_date == execution_date, + cls.task_id == task_id, + cls.dag_id == dag_id, ).delete() - + new = cast(Any, cls)( # Work around Mypy complaining model not defining '__init__'. + key=key, + value=cls.serialize_value(value), + execution_date=execution_date, + task_id=task_id, + dag_id=dag_id, + ) + session.add(new) session.flush() - # insert new XCom - session.add(XCom(key=key, value=value, execution_date=execution_date, task_id=task_id, dag_id=dag_id)) + @overload + @classmethod + def get_one( + cls, + *, + run_id: str, + key: Optional[str] = None, + task_id: Optional[str] = None, + dag_id: Optional[str] = None, + include_prior_dates: bool = False, + session: Optional[Session] = None, + ) -> Optional[Any]: + """Retrieve an XCom value, optionally meeting certain criteria. 
+ + This method returns "full" XCom values (i.e. uses ``deserialize_value`` + from the XCom backend). Use :meth:`get_many` if you want the "shortened" + value via ``orm_deserialize_value``. + + If there are no results, *None* is returned. + + A deprecated form of this function accepts ``execution_date`` instead of + ``run_id``. The two arguments are mutually exclusive. + + :param run_id: DAG run ID for the task. + :param key: A key for the XCom. If provided, only XCom with matching + keys will be returned. Pass *None* (default) to remove the filter. + :param task_id: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param dag_id: Only pull XCom from this DAG. If *None* (default), the + DAG of the calling task is used. + :param include_prior_dates: If *False* (default), only XCom from the + specified DAG run is returned. If *True*, the latest matching XCom is + returned regardless of the run it belongs to. + :param session: Database session. If not given, a new session will be + created for this function. + :type session: sqlalchemy.orm.session.Session + """ - session.flush() + @overload + @classmethod + def get_one( + cls, + execution_date: pendulum.DateTime, + key: Optional[str] = None, + task_id: Optional[str] = None, + dag_id: Optional[str] = None, + include_prior_dates: bool = False, + session: Optional[Session] = None, + ) -> Optional[Any]: + """:sphinx-autoapi-skip:""" @classmethod @provide_session def get_one( cls, execution_date: Optional[pendulum.DateTime] = None, - run_id: Optional[str] = None, key: Optional[str] = None, task_id: Optional[Union[str, Iterable[str]]] = None, dag_id: Optional[Union[str, Iterable[str]]] = None, include_prior_dates: bool = False, session: Session = None, + *, + run_id: Optional[str] = None, ) -> Optional[Any]: - """ - Retrieve an XCom value, optionally meeting certain criteria. Returns None - of there are no results. 
- - ``run_id`` and ``execution_date`` are mutually exclusive. - - This method returns "full" XCom values (i.e. it uses ``deserialize_value`` from the XCom backend). - Please use :meth:`get_many` if you want the "shortened" value via ``orm_deserialize_value`` - - :param execution_date: Execution date for the task - :type execution_date: pendulum.datetime - :param run_id: Dag run id for the task - :type run_id: str - :param key: A key for the XCom. If provided, only XComs with matching - keys will be returned. To remove the filter, pass key=None. - :type key: str - :param task_id: Only XComs from task with matching id will be - pulled. Can pass None to remove the filter. - :type task_id: str - :param dag_id: If provided, only pulls XCom from this DAG. - If None (default), the DAG of the calling task is used. - :type dag_id: str - :param include_prior_dates: If False, only XCom from the current - execution_date are returned. If True, XCom from previous dates - are returned as well. - :type include_prior_dates: bool - :param session: database session - :type session: sqlalchemy.orm.session.Session - """ + """:sphinx-autoapi-skip:""" if not (execution_date is None) ^ (run_id is None): raise ValueError("Exactly one of execution_date or run_id must be passed") - result = ( - cls.get_many( - execution_date=execution_date, + if run_id is not None: + query = cls.get_many( run_id=run_id, key=key, task_ids=task_id, @@ -165,58 +237,88 @@ def get_one( include_prior_dates=include_prior_dates, session=session, ) - .with_entities(cls.value) - .first() - ) + elif execution_date is not None: + query = cls.get_many( + execution_date=execution_date, + key=key, + task_ids=task_id, + dag_ids=dag_id, + include_prior_dates=include_prior_dates, + session=session, + ) + else: + raise RuntimeError("Should not happen?") + + result = query.with_entities(cls.value).first() if result: return cls.deserialize_value(result) return None + @overload + @classmethod + def get_many( + cls, + *, + run_id: 
str, + key: Optional[str] = None, + task_ids: Union[str, Iterable[str], None] = None, + dag_ids: Union[str, Iterable[str], None] = None, + include_prior_dates: bool = False, + limit: Optional[int] = None, + session: Optional[Session] = None, + ) -> Query: + """Composes a query to get one or more XCom entries. + + This function returns an SQLAlchemy query of full XCom objects. If you + just want one stored value, use :meth:`get_one` instead. + + A deprecated form of this function accepts ``execution_date`` instead of + ``run_id``. The two arguments are mutually exclusive. + + :param run_id: DAG run ID for the task. + :param key: A key for the XComs. If provided, only XComs with matching + keys will be returned. Pass *None* (default) to remove the filter. + :param task_ids: Only XComs from tasks with matching IDs will be pulled. + Pass *None* (default) to remove the filter. + :param dag_ids: Only pulls XComs from this DAG. If *None* (default), the + DAG of the calling task is used. + :param include_prior_dates: If *False* (default), only XComs from the + specified DAG run are returned. If *True*, all matching XComs are + returned regardless of the run they belong to. + :param session: Database session. If not given, a new session will be + created for this function. 
+ :type session: sqlalchemy.orm.session.Session + """ + + @overload + @classmethod + def get_many( + cls, + execution_date: pendulum.DateTime, + key: Optional[str] = None, + task_ids: Union[str, Iterable[str], None] = None, + dag_ids: Union[str, Iterable[str], None] = None, + include_prior_dates: bool = False, + limit: Optional[int] = None, + session: Optional[Session] = None, + ) -> Query: + """:sphinx-autoapi-skip:""" + @classmethod @provide_session def get_many( cls, execution_date: Optional[pendulum.DateTime] = None, - run_id: Optional[str] = None, key: Optional[str] = None, task_ids: Optional[Union[str, Iterable[str]]] = None, dag_ids: Optional[Union[str, Iterable[str]]] = None, include_prior_dates: bool = False, limit: Optional[int] = None, session: Session = None, + *, + run_id: Optional[str] = None, ) -> Query: - """ - Composes a query to get one or more values from the xcom table. - - ``run_id`` and ``execution_date`` are mutually exclusive. - - This function returns an SQLAlchemy query of full XCom objects. If you just want one stored value then - use :meth:`get_one`. - - :param execution_date: Execution date for the task - :type execution_date: pendulum.datetime - :param run_id: Dag run id for the task - :type run_id: str - :param key: A key for the XCom. If provided, only XComs with matching - keys will be returned. To remove the filter, pass key=None. - :type key: str - :param task_ids: Only XComs from tasks with matching ids will be - pulled. Can pass None to remove the filter. - :type task_ids: str or iterable of strings (representing task_ids) - :param dag_ids: If provided, only pulls XComs from this DAG. - If None (default), the DAG of the calling task is used. - :type dag_ids: str - :param include_prior_dates: If False, only XComs from the current - execution_date are returned. If True, XComs from previous dates - are returned as well. - :type include_prior_dates: bool - :param limit: If required, limit the number of returned objects. 
- XCom objects can be quite big and you might want to limit the - number of rows. - :type limit: int - :param session: database session - :type session: sqlalchemy.orm.session.Session - """ + """:sphinx-autoapi-skip:""" if not (execution_date is None) ^ (run_id is None): raise ValueError("Exactly one of execution_date or run_id must be passed") @@ -262,8 +364,8 @@ def get_many( @classmethod @provide_session - def delete(cls, xcoms, session=None): - """Delete Xcom""" + def delete(cls, xcoms: Union["XCom", Iterable["XCom"]], session: Session) -> None: + """Delete one or multiple XCom entries.""" if isinstance(xcoms, XCom): xcoms = [xcoms] for xcom in xcoms: @@ -272,37 +374,49 @@ def delete(cls, xcoms, session=None): session.delete(xcom) session.commit() + @overload + @classmethod + def clear(cls, *, dag_id: str, task_id: str, run_id: str, session: Optional[Session] = None) -> None: + """Clear all XCom data from the database for the given task instance. + + A deprecated form of this function accepts ``execution_date`` instead of + ``run_id``. The two arguments are mutually exclusive. + + :param dag_id: ID of DAG to clear the XCom for. + :param task_id: ID of task to clear the XCom for. + :param run_id: ID of DAG run to clear the XCom for. + :param session: Database session. If not given, a new session will be + created for this function. 
+ :type session: sqlalchemy.orm.session.Session + """ + + @overload + @classmethod + def clear( + cls, + execution_date: pendulum.DateTime, + dag_id: str, + task_id: str, + session: Optional[Session] = None, + ) -> None: + """:sphinx-autoapi-skip:""" + @classmethod @provide_session def clear( cls, execution_date: Optional[pendulum.DateTime] = None, - dag_id: str = None, - task_id: str = None, - run_id: str = None, + dag_id: Optional[str] = None, + task_id: Optional[str] = None, + run_id: Optional[str] = None, session: Session = None, ) -> None: - """ - Clears all XCom data from the database for the task instance - - ``run_id`` and ``execution_date`` are mutually exclusive. - - :param execution_date: Execution date for the task - :type execution_date: pendulum.datetime or None - :param dag_id: ID of DAG to clear the XCom for. - :type dag_id: str - :param task_id: Only XComs from task with matching id will be cleared. - :type task_id: str - :param run_id: Dag run id for the task - :type run_id: str or None - :param session: database session - :type session: sqlalchemy.orm.session.Session - """ + """:sphinx-autoapi-skip:""" # Given the historic order of this function (execution_date was first argument) to add a new optional # param we need to add default values for everything :( - if not dag_id: + if dag_id is None: raise TypeError("clear() missing required argument: dag_id") - if not task_id: + if task_id is None: raise TypeError("clear() missing required argument: task_id") if not (execution_date is None) ^ (run_id is None): @@ -364,7 +478,7 @@ def orm_deserialize_value(self) -> Any: return BaseXCom.deserialize_value(self) -def resolve_xcom_backend(): +def resolve_xcom_backend() -> Type[BaseXCom]: """Resolves custom XCom class""" clazz = conf.getimport("core", "xcom_backend", fallback=f"airflow.models.xcom.{BaseXCom.__name__}") if clazz: @@ -376,4 +490,7 @@ def resolve_xcom_backend(): return BaseXCom -XCom = resolve_xcom_backend() +if TYPE_CHECKING: + XCom = 
BaseXCom # Hack to avoid Mypy "Variable 'XCom' is not valid as a type". +else: + XCom = resolve_xcom_backend() diff --git a/docs/exts/docs_build/run_patched_sphinx.py b/docs/exts/docs_build/run_patched_sphinx.py index d93a84fccf517..e9bc2090c9c4f 100755 --- a/docs/exts/docs_build/run_patched_sphinx.py +++ b/docs/exts/docs_build/run_patched_sphinx.py @@ -16,6 +16,12 @@ # specific language governing permissions and limitations # under the License. +"""Hacks to patch up Sphinx-AutoAPI before running Sphinx. + +Unfortunately we have a problem updating to a newer version of Sphinx-AutoAPI, +and have to use v1.0.0, so monkeypatching is used as the last resort. +""" + import os import sys @@ -29,9 +35,26 @@ default_file_mapping, default_ignore_patterns, ) +from autoapi.mappers.python.objects import PythonPythonMapper from sphinx.cmd.build import main +def new_python_python_mapper_display_getter(self: PythonPythonMapper) -> bool: + """Patched getter to apply our special skip magic. + + If the docstring contains a ``:sphinx-autoapi-skip:`` token, don't display this. + """ + if ":sphinx-autoapi-skip:" in self.docstring.split(): + return False + return old_python_python_mapper_property.__get__(self, PythonPythonMapper) + + +# HACK: sphinx-autoapi 1.0.0 is way too old to understand various modern Python +# magic such as typing.overload, so we apply magic to tell it when to skip. +old_python_python_mapper_property = PythonPythonMapper.display +PythonPythonMapper.display = property(new_python_python_mapper_display_getter) + + def run_autoapi(app): """Load AutoAPI data from the filesystem.""" if not app.config.autoapi_dirs: @@ -95,10 +118,8 @@ def run_autoapi(app): sphinx_mapper_obj.output_rst(root=normalized_root, source_suffix=out_suffix) -# HACK: sphinx-auto map did not correctly use the confdir attribute instead of srcdir when specifying the -# directory to contain the generated files. 
-# Unfortunately we have a problem updating to a newer version of this library and we have to use -# sphinx-autoapi v1.0.0, so I am monkeypatching this library to fix this one problem. +# HACK: sphinx-autoapi does not correctly use the confdir attribute instead of +# srcdir when specifying the directory to contain the generated files. autoapi.extension.run_autoapi = run_autoapi sys.exit(main(sys.argv[1:]))