From 0af97c82cf259d2d913d5cfae60cbb159dbd5f22 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Mon, 24 Mar 2025 14:17:31 +0800 Subject: [PATCH 1/6] AIP-72: Port Variable.set From TaskSDK to Models --- airflow-core/src/airflow/models/variable.py | 28 +++++++++++++++++-- .../unit/dag_processing/test_processor.py | 27 ++++++++++++++++++ .../src/airflow/sdk/definitions/variable.py | 6 ++++ .../src/airflow/sdk/execution_time/context.py | 18 ++++++++++++ 4 files changed, 77 insertions(+), 2 deletions(-) diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index e0b9168628566..743571f608f1a 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -36,6 +36,9 @@ from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import provide_session +if TYPE_CHECKING: + from airflow.secrets import BaseSecretsBackend + if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -199,6 +202,25 @@ def set( :param serialize_json: Serialize the value to a JSON string :param session: Session """ + if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + warnings.warn( + "Using Variable.get from `airflow.models` is deprecated. Please use `from airflow.sdk import" + "Variable` instead", + DeprecationWarning, + stacklevel=1, + ) + from airflow.sdk import Variable as TaskSDKVariable + from airflow.sdk.execution_time.supervisor import SECRETS_BACKEND + + # check if the secret exists in the custom secrets' backend. + # passing the secrets backend initialized on the worker side + Variable.check_for_write_conflict(key=key, secrets_backends=SECRETS_BACKEND) + return TaskSDKVariable.set( + key=key, + value=value, + description=description, + serialize_json=serialize_json, + ) # check if the secret exists in the custom secrets' backend. 
Variable.check_for_write_conflict(key=key) if serialize_json: @@ -261,7 +283,9 @@ def rotate_fernet_key(self): self._val = fernet.rotate(self._val.encode("utf-8")).decode() @staticmethod - def check_for_write_conflict(key: str) -> None: + def check_for_write_conflict( + key: str, secrets_backends: list[BaseSecretsBackend] = ensure_secrets_loaded() + ) -> None: """ Log a warning if a variable exists outside the metastore. @@ -271,7 +295,7 @@ def check_for_write_conflict(key: str) -> None: :param key: Variable Key """ - for secrets_backend in ensure_secrets_loaded(): + for secrets_backend in secrets_backends: if not isinstance(secrets_backend, MetastoreBackend): try: var_val = secrets_backend.get_variable(key=key) diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index ca9670f81ae4a..7cd48a5485b92 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -181,6 +181,33 @@ def dag_in_a_fn(): if result.import_errors: assert "VARIABLE_NOT_FOUND" in next(iter(result.import_errors.values())) + def test_top_level_variable_set( + self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch + ): + logger_filehandle = MagicMock() + + def dag_in_a_fn(): + from airflow.sdk import DAG, Variable + + Variable.set("myvar", "123") + with DAG(f"test_{Variable.get('myvar')}"): + ... 
+ + path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path) + + monkeypatch.setenv("AIRFLOW_VAR_MYVAR", "abc") + proc = DagFileProcessorProcess.start( + id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle + ) + + while not proc.is_ready: + proc._service_subprocess(0.1) + + result = proc.parsing_result + assert result is not None + assert result.import_errors == {} + assert result.serialized_dags[0].dag_id == "test_123" + def test_top_level_connection_access(self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch): logger_filehandle = MagicMock() diff --git a/task-sdk/src/airflow/sdk/definitions/variable.py b/task-sdk/src/airflow/sdk/definitions/variable.py index 87b0ee29fab50..4f2bd5bc1637c 100644 --- a/task-sdk/src/airflow/sdk/definitions/variable.py +++ b/task-sdk/src/airflow/sdk/definitions/variable.py @@ -55,3 +55,9 @@ def get(cls, key: str, default: Any = NOTSET, deserialize_json: bool = False): if e.error.error == ErrorType.VARIABLE_NOT_FOUND and default is not NOTSET: return default raise + + @classmethod + def set(cls, key: str, value: Any, description: str | None = None, serialize_json: bool = False): + from airflow.sdk.execution_time.context import _set_variable + + _set_variable(key, value, description=description, serialize_json=serialize_json) diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index a9c576f18aea7..c30b37bfef68c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -209,6 +209,24 @@ def _get_variable(key: str, deserialize_json: bool) -> Any: return variable.value +def _set_variable(key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None: + from airflow.sdk.execution_time.comms import ErrorResponse, PutVariable + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + if serialize_json: + import json + + value 
= json.dumps(value, indent=2) + else: + value = str(value) + + SUPERVISOR_COMMS.send_request(log=log, msg=PutVariable(key=key, value=value, description=description)) + msg = SUPERVISOR_COMMS.get_message() + if isinstance(msg, ErrorResponse): + raise AirflowRuntimeError(msg) + return + + class ConnectionAccessor: """Wrapper to access Connection entries in template.""" From cb9b17a7169044ddbcb9f1432211b2291c4a8899 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Mon, 24 Mar 2025 22:58:57 +0800 Subject: [PATCH 2/6] Refactor secrets_backends parameter to fix test --- airflow-core/src/airflow/models/variable.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index 743571f608f1a..8b513870ebe59 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -215,12 +215,13 @@ def set( # check if the secret exists in the custom secrets' backend. # passing the secrets backend initialized on the worker side Variable.check_for_write_conflict(key=key, secrets_backends=SECRETS_BACKEND) - return TaskSDKVariable.set( + TaskSDKVariable.set( key=key, value=value, description=description, serialize_json=serialize_json, ) + return # check if the secret exists in the custom secrets' backend. Variable.check_for_write_conflict(key=key) if serialize_json: @@ -283,9 +284,7 @@ def rotate_fernet_key(self): self._val = fernet.rotate(self._val.encode("utf-8")).decode() @staticmethod - def check_for_write_conflict( - key: str, secrets_backends: list[BaseSecretsBackend] = ensure_secrets_loaded() - ) -> None: + def check_for_write_conflict(key: str, secrets_backends: list[BaseSecretsBackend] | None = None) -> None: """ Log a warning if a variable exists outside the metastore. 
@@ -295,6 +294,8 @@ def check_for_write_conflict( :param key: Variable Key """ + if secrets_backends is None: + secrets_backends = ensure_secrets_loaded() for secrets_backend in secrets_backends: if not isinstance(secrets_backend, MetastoreBackend): try: From c3019017789716f4b277ed4f385a0e77ce8d7011 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 25 Mar 2025 15:01:22 +0800 Subject: [PATCH 3/6] Fix: DagFileProcessorProcess should handle PutVariable request --- airflow-core/src/airflow/dag_processing/processor.py | 9 +++++++-- airflow-core/tests/unit/dag_processing/test_processor.py | 5 +---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 489b916240eb4..8d6bc5bd3e454 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -38,6 +38,8 @@ ErrorResponse, GetConnection, GetVariable, + OKResponse, + PutVariable, VariableResult, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess @@ -53,12 +55,12 @@ from airflow.typing_compat import Self ToManager = Annotated[ - Union["DagFileParsingResult", GetConnection, GetVariable], + Union["DagFileParsingResult", GetConnection, GetVariable, PutVariable], Field(discriminator="type"), ] ToDagProcessor = Annotated[ - Union["DagFileParseRequest", ConnectionResult, VariableResult, ErrorResponse], + Union["DagFileParseRequest", ConnectionResult, VariableResult, OKResponse, ErrorResponse], Field(discriminator="type"), ] @@ -287,6 +289,9 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # resp = var_result.model_dump_json(exclude_unset=True).encode() else: resp = var.model_dump_json().encode() + elif isinstance(msg, PutVariable): + self.client.variables.set(msg.key, msg.value, msg.description) + resp = OKResponse(ok=True).model_dump_json().encode() else: log.error("Unhandled request", 
msg=msg) return diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 7cd48a5485b92..e2c41f315af11 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -181,9 +181,7 @@ def dag_in_a_fn(): if result.import_errors: assert "VARIABLE_NOT_FOUND" in next(iter(result.import_errors.values())) - def test_top_level_variable_set( - self, spy_agency: SpyAgency, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch - ): + def test_top_level_variable_set(self, spy_agency: SpyAgency, tmp_path: pathlib.Path): logger_filehandle = MagicMock() def dag_in_a_fn(): @@ -195,7 +193,6 @@ def dag_in_a_fn(): path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path) - monkeypatch.setenv("AIRFLOW_VAR_MYVAR", "abc") proc = DagFileProcessorProcess.start( id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle ) From bc076b92c5475c20f31c7fd8a0b67dc9aba4251a Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 25 Mar 2025 15:43:56 +0800 Subject: [PATCH 4/6] Test: airflow.sdk.Variable.set --- .../task_sdk/definitions/test_variables.py | 31 +++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py b/task-sdk/tests/task_sdk/definitions/test_variables.py index 6560bdee90381..96a45d0514555 100644 --- a/task-sdk/tests/task_sdk/definitions/test_variables.py +++ b/task-sdk/tests/task_sdk/definitions/test_variables.py @@ -17,13 +17,15 @@ # under the License. 
from __future__ import annotations +import json from unittest import mock import pytest from airflow.configuration import initialize_secrets_backends from airflow.sdk import Variable -from airflow.sdk.execution_time.comms import VariableResult +from airflow.sdk.execution_time.comms import PutVariable, VariableResult +from airflow.sdk.execution_time.supervisor import initialize_secrets_backend_on_workers from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS from tests_common.test_utils.config import conf_vars @@ -43,7 +45,7 @@ class TestVariables: True, '{"key": "value", "number": 42, "flag": true}', {"key": "value", "number": 42, "flag": True}, - id="deser-object-value", + id="deserialize-object-value", ), ], ) @@ -55,6 +57,31 @@ def test_var_get(self, deserialize_json, value, expected_value, mock_supervisor_ assert var is not None assert var == expected_value + @pytest.mark.parametrize( + "deserialize_json, value, expected_value", + [ + pytest.param( + False, + "my_value", + "my_value", + id="simple-value", + ), + pytest.param( + True, + {"key": "value", "number": 42, "flag": True}, + json.dumps({"key": "value", "number": 42, "flag": True}, indent=2), + id="deserialize-object-value", + ), + ], + ) + def test_var_set(self, deserialize_json, value, expected_value, mock_supervisor_comms): + # Act + Variable.set(key="my_key", value=value, serialize_json=deserialize_json) + # Assert + mock_supervisor_comms.send_request.assert_called_once_with( + log=mock.ANY, msg=PutVariable(key="my_key", value=expected_value, description=None) + ) + class TestVariableFromSecrets: def test_var_get_from_secrets_found(self, mock_supervisor_comms, tmp_path): From 3ba45df85e6090be5e54d8aca5cc687eaa4fc462 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Wed, 9 Apr 2025 15:50:57 +0800 Subject: [PATCH 5/6] Fix static check after rebasing --- airflow-core/src/airflow/models/variable.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git 
a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index 8b513870ebe59..91b75b5aa40d9 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -210,11 +210,10 @@ def set( stacklevel=1, ) from airflow.sdk import Variable as TaskSDKVariable - from airflow.sdk.execution_time.supervisor import SECRETS_BACKEND # check if the secret exists in the custom secrets' backend. # passing the secrets backend initialized on the worker side - Variable.check_for_write_conflict(key=key, secrets_backends=SECRETS_BACKEND) + Variable.check_for_write_conflict(key=key) TaskSDKVariable.set( key=key, value=value, From 11cf86f72a70e34e3cc74570c071d3b961ff72c9 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Wed, 9 Apr 2025 16:31:58 +0800 Subject: [PATCH 6/6] Fix test_variables static check --- task-sdk/tests/task_sdk/definitions/test_variables.py | 1 - 1 file changed, 1 deletion(-) diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py b/task-sdk/tests/task_sdk/definitions/test_variables.py index 96a45d0514555..5189ef616ff4f 100644 --- a/task-sdk/tests/task_sdk/definitions/test_variables.py +++ b/task-sdk/tests/task_sdk/definitions/test_variables.py @@ -25,7 +25,6 @@ from airflow.configuration import initialize_secrets_backends from airflow.sdk import Variable from airflow.sdk.execution_time.comms import PutVariable, VariableResult -from airflow.sdk.execution_time.supervisor import initialize_secrets_backend_on_workers from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS from tests_common.test_utils.config import conf_vars