diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py
index 489b916240eb4..8d6bc5bd3e454 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -38,6 +38,8 @@
     ErrorResponse,
     GetConnection,
     GetVariable,
+    OKResponse,
+    PutVariable,
     VariableResult,
 )
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess
@@ -53,12 +55,12 @@
     from airflow.typing_compat import Self
 
 ToManager = Annotated[
-    Union["DagFileParsingResult", GetConnection, GetVariable],
+    Union["DagFileParsingResult", GetConnection, GetVariable, PutVariable],
     Field(discriminator="type"),
 ]
 
 ToDagProcessor = Annotated[
-    Union["DagFileParseRequest", ConnectionResult, VariableResult, ErrorResponse],
+    Union["DagFileParseRequest", ConnectionResult, VariableResult, OKResponse, ErrorResponse],
     Field(discriminator="type"),
 ]
 
@@ -287,6 +289,9 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: #
                 resp = var_result.model_dump_json(exclude_unset=True).encode()
             else:
                 resp = var.model_dump_json().encode()
+        elif isinstance(msg, PutVariable):
+            self.client.variables.set(msg.key, msg.value, msg.description)
+            resp = OKResponse(ok=True).model_dump_json().encode()
         else:
             log.error("Unhandled request", msg=msg)
             return
diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py
index e0b9168628566..91b75b5aa40d9 100644
--- a/airflow-core/src/airflow/models/variable.py
+++ b/airflow-core/src/airflow/models/variable.py
@@ -36,6 +36,9 @@
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import provide_session
 
+if TYPE_CHECKING:
+    from airflow.secrets import BaseSecretsBackend
+
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
 
@@ -199,6 +202,25 @@ def set(
         :param serialize_json: Serialize the value to a JSON string
         :param session: Session
         """
+        if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
+            warnings.warn(
+                "Using Variable.set from `airflow.models` is deprecated. Please use `from airflow.sdk import"
+                " Variable` instead",
+                DeprecationWarning,
+                stacklevel=1,
+            )
+            from airflow.sdk import Variable as TaskSDKVariable
+
+            # check if the secret exists in the custom secrets' backend.
+            # no secrets_backends passed, so this falls back to ensure_secrets_loaded()
+            Variable.check_for_write_conflict(key=key)
+            TaskSDKVariable.set(
+                key=key,
+                value=value,
+                description=description,
+                serialize_json=serialize_json,
+            )
+            return
         # check if the secret exists in the custom secrets' backend.
         Variable.check_for_write_conflict(key=key)
         if serialize_json:
@@ -261,7 +283,7 @@ def rotate_fernet_key(self):
             self._val = fernet.rotate(self._val.encode("utf-8")).decode()
 
     @staticmethod
-    def check_for_write_conflict(key: str) -> None:
+    def check_for_write_conflict(key: str, secrets_backends: list[BaseSecretsBackend] | None = None) -> None:
         """
         Log a warning if a variable exists outside the metastore.
 
@@ -271,7 +293,9 @@ def check_for_write_conflict(key: str) -> None:
 
         :param key: Variable Key
         """
-        for secrets_backend in ensure_secrets_loaded():
+        if secrets_backends is None:
+            secrets_backends = ensure_secrets_loaded()
+        for secrets_backend in secrets_backends:
             if not isinstance(secrets_backend, MetastoreBackend):
                 try:
                     var_val = secrets_backend.get_variable(key=key)
diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py
index ca9670f81ae4a..e2c41f315af11 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -181,6 +181,30 @@ def dag_in_a_fn():
         if result.import_errors:
             assert "VARIABLE_NOT_FOUND" in next(iter(result.import_errors.values()))
 
+    def test_top_level_variable_set(self, spy_agency: SpyAgency, tmp_path: pathlib.Path):
+        logger_filehandle = MagicMock()
+
+        def dag_in_a_fn():
+            from airflow.sdk import DAG, Variable
+
+            Variable.set("myvar", "123")
+            with DAG(f"test_{Variable.get('myvar')}"):
+                ...
+
+        path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path)
+
+        proc = DagFileProcessorProcess.start(
+            id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
+        )
+
+        while not proc.is_ready:
+            proc._service_subprocess(0.1)
+
+        result = proc.parsing_result
+        assert result is not None
+        assert result.import_errors == {}
+        assert result.serialized_dags[0].dag_id == "test_123"
+
     def test_top_level_connection_access(self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch):
         logger_filehandle = MagicMock()
 
diff --git a/task-sdk/src/airflow/sdk/definitions/variable.py b/task-sdk/src/airflow/sdk/definitions/variable.py
index 87b0ee29fab50..4f2bd5bc1637c 100644
--- a/task-sdk/src/airflow/sdk/definitions/variable.py
+++ b/task-sdk/src/airflow/sdk/definitions/variable.py
@@ -55,3 +55,9 @@ def get(cls, key: str, default: Any = NOTSET, deserialize_json: bool = False):
             if e.error.error == ErrorType.VARIABLE_NOT_FOUND and default is not NOTSET:
                 return default
             raise
+
+    @classmethod
+    def set(cls, key: str, value: Any, description: str | None = None, serialize_json: bool = False):
+        from airflow.sdk.execution_time.context import _set_variable
+
+        _set_variable(key, value, description=description, serialize_json=serialize_json)
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index a9c576f18aea7..c30b37bfef68c 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -209,6 +209,24 @@ def _get_variable(key: str, deserialize_json: bool) -> Any:
     return variable.value
 
 
+def _set_variable(key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None:
+    from airflow.sdk.execution_time.comms import ErrorResponse, PutVariable
+    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+    if serialize_json:
+        import json
+
+        value = json.dumps(value, indent=2)
+    else:
+        value = str(value)
+
+    SUPERVISOR_COMMS.send_request(log=log, msg=PutVariable(key=key, value=value, description=description))
+    msg = SUPERVISOR_COMMS.get_message()
+    if isinstance(msg, ErrorResponse):
+        raise AirflowRuntimeError(msg)
+    return
+
+
 class ConnectionAccessor:
     """Wrapper to access Connection entries in template."""
 
diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py b/task-sdk/tests/task_sdk/definitions/test_variables.py
index 6560bdee90381..5189ef616ff4f 100644
--- a/task-sdk/tests/task_sdk/definitions/test_variables.py
+++ b/task-sdk/tests/task_sdk/definitions/test_variables.py
@@ -17,13 +17,14 @@
 # under the License.
 from __future__ import annotations
 
+import json
 from unittest import mock
 
 import pytest
 
 from airflow.configuration import initialize_secrets_backends
 from airflow.sdk import Variable
-from airflow.sdk.execution_time.comms import VariableResult
+from airflow.sdk.execution_time.comms import PutVariable, VariableResult
 from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS
 
 from tests_common.test_utils.config import conf_vars
@@ -43,7 +44,7 @@ class TestVariables:
                 True,
                 '{"key": "value", "number": 42, "flag": true}',
                 {"key": "value", "number": 42, "flag": True},
-                id="deser-object-value",
+                id="deserialize-object-value",
             ),
         ],
     )
@@ -55,6 +56,31 @@ def test_var_get(self, deserialize_json, value, expected_value, mock_supervisor_
         assert var is not None
         assert var == expected_value
 
+    @pytest.mark.parametrize(
+        "serialize_json, value, expected_value",
+        [
+            pytest.param(
+                False,
+                "my_value",
+                "my_value",
+                id="simple-value",
+            ),
+            pytest.param(
+                True,
+                {"key": "value", "number": 42, "flag": True},
+                json.dumps({"key": "value", "number": 42, "flag": True}, indent=2),
+                id="serialize-object-value",
+            ),
+        ],
+    )
+    def test_var_set(self, serialize_json, value, expected_value, mock_supervisor_comms):
+        # Act
+        Variable.set(key="my_key", value=value, serialize_json=serialize_json)
+        # Assert
+        mock_supervisor_comms.send_request.assert_called_once_with(
+            log=mock.ANY, msg=PutVariable(key="my_key", value=expected_value, description=None)
+        )
+
 
 class TestVariableFromSecrets:
     def test_var_get_from_secrets_found(self, mock_supervisor_comms, tmp_path):