Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
ErrorResponse,
GetConnection,
GetVariable,
OKResponse,
PutVariable,
VariableResult,
)
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
Expand All @@ -53,12 +55,12 @@
from airflow.typing_compat import Self

ToManager = Annotated[
Union["DagFileParsingResult", GetConnection, GetVariable],
Union["DagFileParsingResult", GetConnection, GetVariable, PutVariable],
Field(discriminator="type"),
]

ToDagProcessor = Annotated[
Union["DagFileParseRequest", ConnectionResult, VariableResult, ErrorResponse],
Union["DagFileParseRequest", ConnectionResult, VariableResult, OKResponse, ErrorResponse],
Field(discriminator="type"),
]

Expand Down Expand Up @@ -287,6 +289,9 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: #
resp = var_result.model_dump_json(exclude_unset=True).encode()
else:
resp = var.model_dump_json().encode()
elif isinstance(msg, PutVariable):
self.client.variables.set(msg.key, msg.value, msg.description)
resp = OKResponse(ok=True).model_dump_json().encode()
else:
log.error("Unhandled request", msg=msg)
return
Expand Down
28 changes: 26 additions & 2 deletions airflow-core/src/airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session

if TYPE_CHECKING:
from airflow.secrets import BaseSecretsBackend

if TYPE_CHECKING:
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -199,6 +202,25 @@ def set(
:param serialize_json: Serialize the value to a JSON string
:param session: Session
"""
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
warnings.warn(
"Using Variable.get from `airflow.models` is deprecated. Please use `from airflow.sdk import"
"Variable` instead",
DeprecationWarning,
stacklevel=1,
)
from airflow.sdk import Variable as TaskSDKVariable

# check if the secret exists in the custom secrets' backend.
# passing the secrets backend initialized on the worker side
Variable.check_for_write_conflict(key=key)
TaskSDKVariable.set(
key=key,
value=value,
description=description,
serialize_json=serialize_json,
)
return
# check if the secret exists in the custom secrets' backend.
Variable.check_for_write_conflict(key=key)
if serialize_json:
Expand Down Expand Up @@ -261,7 +283,7 @@ def rotate_fernet_key(self):
self._val = fernet.rotate(self._val.encode("utf-8")).decode()

@staticmethod
def check_for_write_conflict(key: str) -> None:
def check_for_write_conflict(key: str, secrets_backends: list[BaseSecretsBackend] | None = None) -> None:
"""
Log a warning if a variable exists outside the metastore.

Expand All @@ -271,7 +293,9 @@ def check_for_write_conflict(key: str) -> None:

:param key: Variable Key
"""
for secrets_backend in ensure_secrets_loaded():
if secrets_backends is None:
secrets_backends = ensure_secrets_loaded()
for secrets_backend in secrets_backends:
if not isinstance(secrets_backend, MetastoreBackend):
try:
var_val = secrets_backend.get_variable(key=key)
Expand Down
24 changes: 24 additions & 0 deletions airflow-core/tests/unit/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,30 @@ def dag_in_a_fn():
if result.import_errors:
assert "VARIABLE_NOT_FOUND" in next(iter(result.import_errors.values()))

def test_top_level_variable_set(self, spy_agency: SpyAgency, tmp_path: pathlib.Path):
logger_filehandle = MagicMock()

def dag_in_a_fn():
from airflow.sdk import DAG, Variable

Variable.set("myvar", "123")
with DAG(f"test_{Variable.get('myvar')}"):
...

path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path)

proc = DagFileProcessorProcess.start(
id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
)

while not proc.is_ready:
proc._service_subprocess(0.1)

result = proc.parsing_result
assert result is not None
assert result.import_errors == {}
assert result.serialized_dags[0].dag_id == "test_123"

def test_top_level_connection_access(self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch):
logger_filehandle = MagicMock()

Expand Down
6 changes: 6 additions & 0 deletions task-sdk/src/airflow/sdk/definitions/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,9 @@ def get(cls, key: str, default: Any = NOTSET, deserialize_json: bool = False):
if e.error.error == ErrorType.VARIABLE_NOT_FOUND and default is not NOTSET:
return default
raise

@classmethod
def set(cls, key: str, value: Any, description: str | None = None, serialize_json: bool = False):
from airflow.sdk.execution_time.context import _set_variable

_set_variable(key, value, description=description, serialize_json=serialize_json)
18 changes: 18 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,24 @@ def _get_variable(key: str, deserialize_json: bool) -> Any:
return variable.value


def _set_variable(key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None:
from airflow.sdk.execution_time.comms import ErrorResponse, PutVariable
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

if serialize_json:
import json

value = json.dumps(value, indent=2)
else:
value = str(value)

SUPERVISOR_COMMS.send_request(log=log, msg=PutVariable(key=key, value=value, description=description))
msg = SUPERVISOR_COMMS.get_message()
if isinstance(msg, ErrorResponse):
raise AirflowRuntimeError(msg)
return


Comment on lines +212 to +229
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thats right we need to do that as well

class ConnectionAccessor:
"""Wrapper to access Connection entries in template."""

Expand Down
30 changes: 28 additions & 2 deletions task-sdk/tests/task_sdk/definitions/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
# under the License.
from __future__ import annotations

import json
from unittest import mock

import pytest

from airflow.configuration import initialize_secrets_backends
from airflow.sdk import Variable
from airflow.sdk.execution_time.comms import VariableResult
from airflow.sdk.execution_time.comms import PutVariable, VariableResult
from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS

from tests_common.test_utils.config import conf_vars
Expand All @@ -43,7 +44,7 @@ class TestVariables:
True,
'{"key": "value", "number": 42, "flag": true}',
{"key": "value", "number": 42, "flag": True},
id="deser-object-value",
id="deserialize-object-value",
),
],
)
Expand All @@ -55,6 +56,31 @@ def test_var_get(self, deserialize_json, value, expected_value, mock_supervisor_
assert var is not None
assert var == expected_value

@pytest.mark.parametrize(
"deserialize_json, value, expected_value",
[
pytest.param(
False,
"my_value",
"my_value",
id="simple-value",
),
pytest.param(
True,
{"key": "value", "number": 42, "flag": True},
json.dumps({"key": "value", "number": 42, "flag": True}, indent=2),
id="deserialize-object-value",
),
],
)
def test_var_set(self, deserialize_json, value, expected_value, mock_supervisor_comms):
# Act
Variable.set(key="my_key", value=value, serialize_json=deserialize_json)
# Assert
mock_supervisor_comms.send_request.assert_called_once_with(
log=mock.ANY, msg=PutVariable(key="my_key", value=expected_value, description=None)
)


class TestVariableFromSecrets:
def test_var_get_from_secrets_found(self, mock_supervisor_comms, tmp_path):
Expand Down