diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi
index 089e453d02b43..e6d247918395c 100644
--- a/airflow/decorators/__init__.pyi
+++ b/airflow/decorators/__init__.pyi
@@ -379,8 +379,9 @@ class TaskDecoratorCollection:
         self,
         *,
         multiple_outputs: bool | None = None,
-        use_dill: bool = False,  # Added by _DockerDecoratedOperator.
         python_command: str = "python3",
+        serializer: Literal["pickle", "cloudpickle", "dill"] | None = None,
+        use_dill: bool = False,  # Added by _DockerDecoratedOperator.
         # 'command', 'retrieve_output', and 'retrieve_output_path' are filled by
         # _DockerDecoratedOperator.
         image: str,
@@ -432,8 +433,17 @@ class TaskDecoratorCollection:
         :param multiple_outputs: If set, function return value will be unrolled to multiple XCom values.
             Dict will unroll to XCom values with keys as XCom keys. Defaults to False.
-        :param use_dill: Whether to use dill or pickle for serialization
         :param python_command: Python command for executing functions, Default: python3
+        :param serializer: Which serializer to use for serializing the args and result. It can be one of the following:
+
+            - ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library.
+            - ``"cloudpickle"``: Use cloudpickle to serialize more complex types;
+              this requires including cloudpickle in your requirements.
+            - ``"dill"``: Use dill to serialize more complex types;
+              this requires including dill in your requirements.
+        :param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize
+            the args and result (pickle is default). This allows more complex types
+            but requires you to include dill in your requirements.
         :param image: Docker image from which to create the container. If image tag is omitted, "latest"
             will be used.
         :param api_version: Remote API version. Set to ``auto`` to automatically
diff --git a/airflow/providers/docker/decorators/docker.py b/airflow/providers/docker/decorators/docker.py
index d851c98aca5d9..9812e5fc57488 100644
--- a/airflow/providers/docker/decorators/docker.py
+++ b/airflow/providers/docker/decorators/docker.py
@@ -18,13 +18,12 @@
 
 import base64
 import os
-import pickle
+import warnings
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Callable, Sequence
-
-import dill
+from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence
 
 from airflow.decorators.base import DecoratedOperator, task_decorator_factory
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.docker.operators.docker import DockerOperator
 from airflow.utils.python_virtualenv import write_python_script
@@ -32,6 +31,47 @@
     from airflow.decorators.base import TaskDecorator
     from airflow.utils.context import Context
 
+    Serializer = Literal["pickle", "dill", "cloudpickle"]
+
+try:
+    from airflow.operators.python import _SERIALIZERS
+except ImportError:
+    import logging
+
+    import lazy_object_proxy
+
+    log = logging.getLogger(__name__)
+
+    def _load_pickle():
+        import pickle
+
+        return pickle
+
+    def _load_dill():
+        try:
+            import dill
+        except ModuleNotFoundError:
+            log.error("Unable to import `dill` module. Please make sure that it is installed.")
+            raise
+        return dill
+
+    def _load_cloudpickle():
+        try:
+            import cloudpickle
+        except ModuleNotFoundError:
+            log.error(
+                "Unable to import `cloudpickle` module. "
" + "Please install it with: pip install 'apache-airflow[cloudpickle]'" + ) + raise + return cloudpickle + + _SERIALIZERS: dict[Serializer, Any] = { # type: ignore[no-redef] + "pickle": lazy_object_proxy.Proxy(_load_pickle), + "dill": lazy_object_proxy.Proxy(_load_dill), + "cloudpickle": lazy_object_proxy.Proxy(_load_cloudpickle), + } + def _generate_decode_command(env_var, file, python_command): # We don't need `f.close()` as the interpreter is about to exit anyway @@ -53,7 +93,6 @@ class _DockerDecoratedOperator(DecoratedOperator, DockerOperator): :param python_callable: A reference to an object that is callable :param python: Python binary name to use - :param use_dill: Whether dill should be used to serialize the callable :param expect_airflow: whether to expect airflow to be installed in the docker environment. if this one is specified, the script to run callable will attempt to load Airflow macros. :param op_kwargs: a dictionary of keyword arguments that will get unpacked @@ -63,6 +102,16 @@ class _DockerDecoratedOperator(DecoratedOperator, DockerOperator): :param multiple_outputs: if set, function return value will be unrolled to multiple XCom values. Dict will unroll to xcom values with keys as keys. Defaults to False. + :param serializer: Which serializer use to serialize the args and result. It can be one of the following: + + - ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library. + - ``"cloudpickle"``: Use cloudpickle for serialize more complex types, + this requires to include cloudpickle in your requirements. + - ``"dill"``: Use dill for serialize more complex types, + this requires to include dill in your requirements. + :param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize + the args and result (pickle is default). This allows more complex types + but requires you to include dill in your requirements. """ custom_operator_name = "@task.docker" @@ -74,12 +123,35 @@ def __init__( use_dill=False, python_command="python3", expect_airflow: bool = True, + serializer: Serializer | None = None, **kwargs, ) -> None: + if use_dill: + warnings.warn( + "`use_dill` is deprecated and will be removed in a future version. " + "Please provide serializer='dill' instead.", + AirflowProviderDeprecationWarning, + stacklevel=3, + ) + if serializer: + raise AirflowException( + "Both 'use_dill' and 'serializer' parameters are set. Please set only one of them" + ) + serializer = "dill" + serializer = serializer or "pickle" + if serializer not in _SERIALIZERS: + msg = ( + f"Unsupported serializer {serializer!r}. 
" + f"Expected one of {', '.join(map(repr, _SERIALIZERS))}" + ) + raise AirflowException(msg) + command = "placeholder command" self.python_command = python_command self.expect_airflow = expect_airflow - self.use_dill = use_dill + self.use_dill = serializer == "dill" + self.serializer: Serializer = serializer + super().__init__( command=command, retrieve_output=True, retrieve_output_path="/tmp/script.out", **kwargs ) @@ -128,9 +200,7 @@ def execute(self, context: Context): @property def pickling_library(self): - if self.use_dill: - return dill - return pickle + return _SERIALIZERS[self.serializer] def docker_task( diff --git a/tests/providers/docker/decorators/test_docker.py b/tests/providers/docker/decorators/test_docker.py index 93db9f211b4db..42b1a514a3a9c 100644 --- a/tests/providers/docker/decorators/test_docker.py +++ b/tests/providers/docker/decorators/test_docker.py @@ -17,12 +17,13 @@ from __future__ import annotations import logging +from importlib.util import find_spec from io import StringIO as StringBuffer import pytest from airflow.decorators import setup, task, teardown -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.models import TaskInstance from airflow.models.dag import DAG from airflow.utils import timezone @@ -32,6 +33,10 @@ DEFAULT_DATE = timezone.datetime(2021, 9, 1) +DILL_INSTALLED = find_spec("dill") is not None +DILL_MARKER = pytest.mark.skipif(not DILL_INSTALLED, reason="`dill` is not installed") +CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None +CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, reason="`cloudpickle` is not installed") class TestDockerDecorator: @@ -207,13 +212,21 @@ def f(): assert teardown_task.is_teardown assert teardown_task.on_failure_fail_dagrun is on_failure_fail_dagrun - @pytest.mark.parametrize("use_dill", [True, False]) - def test_deepcopy_with_python_operator(self, dag_maker, use_dill): + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("pickle", id="pickle"), + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + pytest.param(None, id="default"), + ], + ) + def test_deepcopy_with_python_operator(self, dag_maker, serializer): import copy from airflow.providers.docker.decorators.docker import _DockerDecoratedOperator - @task.docker(image="python:3.9-slim", auto_remove="force", use_dill=use_dill) + @task.docker(image="python:3.9-slim", auto_remove="force", serializer=serializer) def f(): import logging @@ -247,6 +260,7 @@ def g(): assert isinstance(clone_of_docker_operator, _DockerDecoratedOperator) assert some_task.command == clone_of_docker_operator.command assert some_task.expect_airflow == clone_of_docker_operator.expect_airflow + assert some_task.serializer == clone_of_docker_operator.serializer assert some_task.use_dill == clone_of_docker_operator.use_dill assert some_task.pickling_library is clone_of_docker_operator.pickling_library @@ -317,3 +331,98 @@ def f(): assert 'with open(sys.argv[4], "w") as file:' not in log_content last_line_of_docker_operator_log = log_content.splitlines()[-1] assert "ValueError: This task is expected to fail" in last_line_of_docker_operator_log + + @pytest.mark.parametrize( + "serializer", + [ + pytest.param("pickle", id="pickle"), + pytest.param("dill", marks=DILL_MARKER, id="dill"), + pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"), + ], + ) + def 
+    def test_ambiguous_serializer(self, dag_maker, serializer):
+        @task.docker(image="python:3.9-slim", auto_remove="force", use_dill=True, serializer=serializer)
+        def f():
+            pass
+
+        with dag_maker():
+            with pytest.warns(
+                AirflowProviderDeprecationWarning, match="`use_dill` is deprecated and will be removed"
+            ):
+                with pytest.raises(
+                    AirflowException, match="Both 'use_dill' and 'serializer' parameters are set"
+                ):
+                    f()
+
+    def test_invalid_serializer(self, dag_maker):
+        @task.docker(image="python:3.9-slim", auto_remove="force", serializer="airflow")
+        def f():
+            """Body never runs; the invalid serializer is rejected at instantiation."""
+            ...
+
+        with dag_maker():
+            with pytest.raises(AirflowException, match="Unsupported serializer 'airflow'"):
+                f()
+
+    @pytest.mark.parametrize(
+        "serializer",
+        [
+            pytest.param(
+                "dill",
+                marks=pytest.mark.skipif(
+                    DILL_INSTALLED, reason="For this test case `dill` shouldn't be installed"
+                ),
+                id="dill",
+            ),
+            pytest.param(
+                "cloudpickle",
+                marks=pytest.mark.skipif(
+                    CLOUDPICKLE_INSTALLED, reason="For this test case `cloudpickle` shouldn't be installed"
+                ),
+                id="cloudpickle",
+            ),
+        ],
+    )
+    def test_advanced_serializer_not_installed(self, dag_maker, serializer, caplog):
+        """Check that an error is raised if dill/cloudpickle is not installed."""
+
+        @task.docker(image="python:3.9-slim", auto_remove="force", serializer=serializer)
+        def f(): ...
+
+        with dag_maker():
+            with pytest.raises(ModuleNotFoundError):
+                f()
+        assert f"Unable to import `{serializer}` module." in caplog.text
+
+    @CLOUDPICKLE_MARKER
+    def test_add_cloudpickle(self, dag_maker):
+        @task.docker(image="python:3.9-slim", auto_remove="force", serializer="cloudpickle")
+        def f():
+            """Ensure cloudpickle is correctly installed."""
+            import cloudpickle  # noqa: F401
+
+        with dag_maker():
+            f()
+
+    @DILL_MARKER
+    def test_add_dill(self, dag_maker):
+        @task.docker(image="python:3.9-slim", auto_remove="force", serializer="dill")
+        def f():
+            """Ensure dill is correctly installed."""
+            import dill  # noqa: F401
+
+        with dag_maker():
+            f()
+
+    @DILL_MARKER
+    def test_add_dill_use_dill(self, dag_maker):
+        @task.docker(image="python:3.9-slim", auto_remove="force", use_dill=True)
+        def f():
+            """Ensure dill is correctly installed."""
+            import dill  # noqa: F401
+
+        with dag_maker():
+            with pytest.warns(
+                AirflowProviderDeprecationWarning, match="`use_dill` is deprecated and will be removed"
+            ):
+                f()
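For context, a minimal TaskFlow sketch of the new `serializer` parameter and the deprecated `use_dill` spelling. This is illustrative only, not part of the patch: the DAG id, dates, and task bodies are hypothetical, and cloudpickle must be importable both where the DAG is parsed and inside the image.

from __future__ import annotations

import pendulum

from airflow.decorators import dag, task


@dag(schedule=None, start_date=pendulum.datetime(2024, 1, 1), catchup=False)
def docker_serializer_example():
    # Default behaviour: op_args/op_kwargs and the return value are shipped
    # to and from the container with stdlib pickle.
    @task.docker(image="python:3.9-slim", auto_remove="force")
    def add(x: int, y: int) -> int:
        return x + y

    # Opt in to cloudpickle for more complex argument/result types.
    @task.docker(image="python:3.9-slim", auto_remove="force", serializer="cloudpickle")
    def double(x: int) -> int:
        return x * 2

    # Deprecated spelling: emits AirflowProviderDeprecationWarning and behaves
    # like serializer="dill".
    @task.docker(image="python:3.9-slim", auto_remove="force", use_dill=True)
    def square(x: int) -> int:
        return x * x

    square(double(add(1, 2)))


docker_serializer_example()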