Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,13 @@ def setdefault(cls, key, default, description=None, deserialize_json=False):
:param description: Default value to set Description of the Variable
:param deserialize_json: Store this as a JSON encoded value in the DB
and un-encode it when retrieving a value
:param session: Session
:return: Mixed
"""
obj = Variable.get(key, default_var=None, deserialize_json=deserialize_json)
if obj is None:
if default is not None:
Variable.set(key, default, description=description, serialize_json=deserialize_json)
Variable.set(key=key, value=default, description=description, serialize_json=deserialize_json)
return default
else:
raise ValueError("Default Value must be set")
Expand Down Expand Up @@ -170,9 +171,10 @@ def set(
:param value: Value to set for the Variable
:param description: Description of the Variable
:param serialize_json: Serialize the value to a JSON string
:param session: Session
"""
# check if the secret exists in the custom secrets' backend.
Variable.check_for_write_conflict(key)
Variable.check_for_write_conflict(key=key)
if serialize_json:
stored_value = json.dumps(value, indent=2)
else:
Expand Down Expand Up @@ -201,16 +203,19 @@ def update(
:param key: Variable Key
:param value: Value to set for the Variable
:param serialize_json: Serialize the value to a JSON string
:param session: Session
"""
Variable.check_for_write_conflict(key)
Variable.check_for_write_conflict(key=key)

if Variable.get_variable_from_secrets(key=key) is None:
raise KeyError(f"Variable {key} does not exist")
obj = session.scalar(select(Variable).where(Variable.key == key))
if obj is None:
raise AttributeError(f"Variable {key} does not exist in the Database and cannot be updated.")

Variable.set(key, value, description=obj.description, serialize_json=serialize_json)
Variable.set(
key=key, value=value, description=obj.description, serialize_json=serialize_json, session=session
)

@staticmethod
@provide_session
Expand Down
40 changes: 22 additions & 18 deletions tests/models/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def setup_test_cases(self):
db.clear_db_variables()
crypto._fernet = None

@conf_vars({("core", "fernet_key"): ""})
@conf_vars({("core", "fernet_key"): "", ("core", "unit_test_mode"): "True"})
def test_variable_no_encryption(self, session):
"""
Test variables without encryption
Expand Down Expand Up @@ -100,12 +100,13 @@ def test_variable_set_get_round_trip(self):
Variable.set("tested_var_set_id", "Monday morning breakfast")
assert "Monday morning breakfast" == Variable.get("tested_var_set_id")

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
def test_variable_set_with_env_variable(self, caplog, session):
caplog.set_level(logging.WARNING, logger=variable.log.name)
Variable.set(key="key", value="db-value", session=session)
with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="env-value"):
# setting value while shadowed by an env variable will generate a warning
Variable.set("key", "new-db-value")
Variable.set(key="key", value="new-db-value", session=session)
# value set above is not returned because the env variable value takes priority
assert "env-value" == Variable.get("key")
# invalidate the cache to re-evaluate value
Expand All @@ -120,6 +121,7 @@ def test_variable_set_with_env_variable(self, caplog, session):
"EnvironmentVariablesBackend"
)

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
@mock.patch("airflow.models.variable.ensure_secrets_loaded")
def test_variable_set_with_extra_secret_backend(self, mock_ensure_secrets, caplog, session):
caplog.set_level(logging.WARNING, logger=variable.log.name)
Expand All @@ -137,11 +139,11 @@ def test_variable_set_with_extra_secret_backend(self, mock_ensure_secrets, caplo
"will be updated, but to read it you have to delete the conflicting variable from "
"MockSecretsBackend"
)
Variable.delete("key")
Variable.delete(key="key", session=session)

def test_variable_set_get_round_trip_json(self):
value = {"a": 17, "b": 47}
Variable.set("tested_var_set_id", value, serialize_json=True)
Variable.set(key="tested_var_set_id", value=value, serialize_json=True)
assert value == Variable.get("tested_var_set_id", deserialize_json=True)

def test_variable_update(self, session):
Expand Down Expand Up @@ -184,9 +186,9 @@ def test_get_non_existing_var_should_raise_key_error(self):
with pytest.raises(KeyError):
Variable.get("thisIdDoesNotExist")

def test_update_non_existing_var_should_raise_key_error(self):
def test_update_non_existing_var_should_raise_key_error(self, session):
with pytest.raises(KeyError):
Variable.update("thisIdDoesNotExist", "value")
Variable.update(key="thisIdDoesNotExist", value="value", session=session)

def test_get_non_existing_var_with_none_default_should_return_none(self):
assert Variable.get("thisIdDoesNotExist", default_var=None) is None
Expand All @@ -197,42 +199,45 @@ def test_get_non_existing_var_should_not_deserialize_json_default(self):
"thisIdDoesNotExist", default_var=default_value, deserialize_json=True
)

def test_variable_setdefault_round_trip(self):
@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
def test_variable_setdefault_round_trip(self, session):
key = "tested_var_setdefault_1_id"
value = "Monday morning breakfast in Paris"
Variable.setdefault(key, value)
Variable.setdefault(key=key, default=value)
assert value == Variable.get(key)

def test_variable_setdefault_round_trip_json(self):
@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
def test_variable_setdefault_round_trip_json(self, session):
key = "tested_var_setdefault_2_id"
value = {"city": "Paris", "Happiness": True}
Variable.setdefault(key, value, deserialize_json=True)
Variable.setdefault(key=key, default=value, deserialize_json=True)
assert value == Variable.get(key, deserialize_json=True)

@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode
def test_variable_setdefault_existing_json(self, session):
key = "tested_var_setdefault_2_id"
value = {"city": "Paris", "Happiness": True}
Variable.set(key=key, value=value, serialize_json=True, session=session)
val = Variable.setdefault(key, value, deserialize_json=True)
val = Variable.setdefault(key=key, default=value, deserialize_json=True)
# Check the returned value, and the stored value are handled correctly.
assert value == val
assert value == Variable.get(key, deserialize_json=True)

def test_variable_delete(self):
def test_variable_delete(self, session):
key = "tested_var_delete"
value = "to be deleted"

# No-op if the variable doesn't exist
Variable.delete(key)
Variable.delete(key=key, session=session)
with pytest.raises(KeyError):
Variable.get(key)

# Set the variable
Variable.set(key, value)
Variable.set(key=key, value=value, session=session)
assert value == Variable.get(key)

# Delete the variable
Variable.delete(key)
Variable.delete(key=key, session=session)
with pytest.raises(KeyError):
Variable.get(key)

Expand Down Expand Up @@ -276,15 +281,15 @@ def test_caching_caches(self, mock_ensure_secrets: mock.Mock):
mock_backend.get_variable.assert_called_once() # second call was not made because of cache
assert first == second

def test_cache_invalidation_on_set(self):
def test_cache_invalidation_on_set(self, session):
with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="from_env"):
a = Variable.get("key") # value is saved in cache
with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="from_env_two"):
b = Variable.get("key") # value from cache is used
assert a == b

# setting a new value invalidates the cache
Variable.set("key", "new_value")
Variable.set(key="key", value="new_value", session=session)

c = Variable.get("key") # cache should not be used

Expand Down Expand Up @@ -312,7 +317,6 @@ def test_masking_only_secret_values(variable_value, deserialize_json, expected_m
)
session.add(var)
session.flush()

# Make sure we re-load it, not just get the cached object back
session.expunge(var)
_secrets_masker().patterns = set()
Expand Down
1 change: 1 addition & 0 deletions tests/providers/cncf/kubernetes/operators/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def test_templates(self, create_task_instance_of_operator, session):
cmds="{{ dag.dag_id }}",
image="{{ dag.dag_id }}",
annotations={"dag-id": "{{ dag.dag_id }}"},
session=session,
)

session.add(ti)
Expand Down