From 2252a6b49a45c9070bc84eaac15b0f12eb8a8047 Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Mon, 12 Aug 2024 22:50:14 +0200 Subject: [PATCH 1/2] Fix tests/models/test_variable.py for database isolation mode --- .../endpoints/rpc_api_endpoint.py | 9 +-- airflow/api_internal/internal_api_call.py | 2 +- airflow/models/variable.py | 66 ++++++++++++++++++- airflow/serialization/enums.py | 1 + airflow/serialization/serialized_objects.py | 16 ++++- tests/models/test_variable.py | 8 ++- 6 files changed, 90 insertions(+), 12 deletions(-) diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index ad65157ef9415..dcbaf9f00c7e3 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -126,9 +126,9 @@ def initialize_method_map() -> dict[str, Callable]: # XCom.get_many, # Not supported because it returns query XCom.clear, XCom.set, - Variable.set, - Variable.update, - Variable.delete, + Variable._set, + Variable._update, + Variable._delete, DAG.fetch_callback, DAG.fetch_dagrun, DagRun.fetch_task_instances, @@ -237,7 +237,8 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse: response = json.dumps(output_json) if output_json is not None else None log.info("Sending response: %s", response) return Response(response=response, headers={"Content-Type": "application/json"}) - except AirflowException as e: # In case of AirflowException transport the exception class back to caller + # In case of AirflowException or other selective known types transport the exception class back to caller + except (KeyError, AttributeError, AirflowException) as e: exception_json = BaseSerialization.serialize(e, use_pydantic_models=True) response = json.dumps(exception_json) log.info("Sending exception response: %s", response) diff --git a/airflow/api_internal/internal_api_call.py b/airflow/api_internal/internal_api_call.py index fc0945b3c0fe0..8838377877bec 100644 --- a/airflow/api_internal/internal_api_call.py +++ b/airflow/api_internal/internal_api_call.py @@ -159,7 +159,7 @@ def wrapper(*args, **kwargs): if result is None or result == b"": return None result = BaseSerialization.deserialize(json.loads(result), use_pydantic_models=True) - if isinstance(result, AirflowException): + if isinstance(result, (KeyError, AttributeError, AirflowException)): raise result return result diff --git a/airflow/models/variable.py b/airflow/models/variable.py index 63b71303bc803..563cac46e8c84 100644 --- a/airflow/models/variable.py +++ b/airflow/models/variable.py @@ -154,7 +154,6 @@ def get( @staticmethod @provide_session - @internal_api_call def set( key: str, value: Any, @@ -167,6 +166,35 @@ def set( This operation overwrites an existing variable. + :param key: Variable Key + :param value: Value to set for the Variable + :param description: Description of the Variable + :param serialize_json: Serialize the value to a JSON string + :param session: Session + """ + Variable._set( + key=key, value=value, description=description, serialize_json=serialize_json, session=session + ) + # invalidate key in cache for faster propagation + # we cannot save the value set because it's possible that it's shadowed by a custom backend + # (see call to check_for_write_conflict above) + SecretCache.invalidate_variable(key) + + @staticmethod + @provide_session + @internal_api_call + def _set( + key: str, + value: Any, + description: str | None = None, + serialize_json: bool = False, + session: Session = None, + ) -> None: + """ + Set a value for an Airflow Variable with a given Key. + + This operation overwrites an existing variable. + :param key: Variable Key :param value: Value to set for the Variable :param description: Description of the Variable @@ -190,7 +218,6 @@ def set( @staticmethod @provide_session - @internal_api_call def update( key: str, value: Any, @@ -200,6 +227,27 @@ def update( """ Update a given Airflow Variable with the Provided value. + :param key: Variable Key + :param value: Value to set for the Variable + :param serialize_json: Serialize the value to a JSON string + :param session: Session + """ + Variable._update(key=key, value=value, serialize_json=serialize_json, session=session) + # We need to invalidate the cache for internal API cases on the client side + SecretCache.invalidate_variable(key) + + @staticmethod + @provide_session + @internal_api_call + def _update( + key: str, + value: Any, + serialize_json: bool = False, + session: Session = None, + ) -> None: + """ + Update a given Airflow Variable with the Provided value. + :param key: Variable Key :param value: Value to set for the Variable :param serialize_json: Serialize the value to a JSON string @@ -219,11 +267,23 @@ def update( @staticmethod @provide_session - @internal_api_call def delete(key: str, session: Session = None) -> int: """ Delete an Airflow Variable for a given key. + :param key: Variable Keys + """ + rows = Variable._delete(key=key, session=session) + SecretCache.invalidate_variable(key) + return rows + + @staticmethod + @provide_session + @internal_api_call + def _delete(key: str, session: Session = None) -> int: + """ + Delete an Airflow Variable for a given key. + :param key: Variable Keys """ rows = session.execute(delete(Variable).where(Variable.key == key)).rowcount diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py index a5bd5e3646e83..f216ce7316103 100644 --- a/airflow/serialization/enums.py +++ b/airflow/serialization/enums.py @@ -46,6 +46,7 @@ class DagAttributeTypes(str, Enum): RELATIVEDELTA = "relativedelta" BASE_TRIGGER = "base_trigger" AIRFLOW_EXC_SER = "airflow_exc_ser" + BASE_EXC_SER = "base_exc_ser" DICT = "dict" SET = "set" TUPLE = "tuple" diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index d110271c3da08..a3886aa49acef 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -692,6 +692,15 @@ def serialize( ), type_=DAT.AIRFLOW_EXC_SER, ) + elif isinstance(var, (KeyError, AttributeError)): + return cls._encode( + cls.serialize( + {"exc_cls_name": var.__class__.__name__, "args": [var.args], "kwargs": {}}, + use_pydantic_models=use_pydantic_models, + strict=strict, + ), + type_=DAT.BASE_EXC_SER, + ) elif isinstance(var, BaseTrigger): return cls._encode( cls.serialize(var.serialize(), use_pydantic_models=use_pydantic_models, strict=strict), @@ -834,13 +843,16 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any: return decode_timezone(var) elif type_ == DAT.RELATIVEDELTA: return decode_relativedelta(var) - elif type_ == DAT.AIRFLOW_EXC_SER: + elif type_ == DAT.AIRFLOW_EXC_SER or type_ == DAT.BASE_EXC_SER: deser = cls.deserialize(var, use_pydantic_models=use_pydantic_models) exc_cls_name = deser["exc_cls_name"] args = deser["args"] kwargs = deser["kwargs"] del deser - exc_cls = import_string(exc_cls_name) + if type_ == DAT.AIRFLOW_EXC_SER: + exc_cls = import_string(exc_cls_name) + else: + exc_cls = import_string(f"builtins.{exc_cls_name}") return exc_cls(*args, **kwargs) elif type_ == DAT.BASE_TRIGGER: tr_cls_name, kwargs = cls.deserialize(var, use_pydantic_models=use_pydantic_models) diff --git a/tests/models/test_variable.py b/tests/models/test_variable.py index 3ec2691e5af95..e9509b4569db4 100644 --- a/tests/models/test_variable.py +++ b/tests/models/test_variable.py @@ -47,6 +47,7 @@ def setup_test_cases(self): db.clear_db_variables() crypto._fernet = None + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other fernet @conf_vars({("core", "fernet_key"): "", ("core", "unit_test_mode"): "True"}) def test_variable_no_encryption(self, session): """ @@ -60,6 +61,7 @@ def test_variable_no_encryption(self, session): # should mask anything. That logic is tested in test_secrets_masker.py self.mask_secret.assert_called_once_with("value", "key") + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other fernet @conf_vars({("core", "fernet_key"): Fernet.generate_key().decode()}) def test_variable_with_encryption(self, session): """ @@ -70,6 +72,7 @@ def test_variable_with_encryption(self, session): assert test_var.is_encrypted assert test_var.val == "value" + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other fernet @pytest.mark.parametrize("test_value", ["value", ""]) def test_var_with_encryption_rotate_fernet_key(self, test_value, session): """ @@ -152,6 +155,7 @@ def test_variable_update(self, session): Variable.update(key="test_key", value="value2", session=session) assert "value2" == Variable.get("test_key") + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, API server has other ENV def test_variable_update_fails_on_non_metastore_variable(self, session): with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="env-value"): with pytest.raises(AttributeError): @@ -281,6 +285,7 @@ def test_caching_caches(self, mock_ensure_secrets: mock.Mock): mock_backend.get_variable.assert_called_once() # second call was not made because of cache assert first == second + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode, internal API has other env def test_cache_invalidation_on_set(self, session): with mock.patch.dict("os.environ", AIRFLOW_VAR_KEY="from_env"): a = Variable.get("key") # value is saved in cache @@ -316,7 +321,7 @@ def test_masking_only_secret_values(variable_value, deserialize_json, expected_m val=variable_value, ) session.add(var) - session.flush() + session.commit() # Make sure we re-load it, not just get the cached object back session.expunge(var) _secrets_masker().patterns = set() @@ -326,5 +331,4 @@ def test_masking_only_secret_values(variable_value, deserialize_json, expected_m for expected_masked_value in expected_masked_values: assert expected_masked_value in _secrets_masker().patterns finally: - session.rollback() db.clear_db_variables() From 3b657c1655b4a8c8edaf23daf07a73233a5be6d3 Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Tue, 13 Aug 2024 00:38:55 +0200 Subject: [PATCH 2/2] Review feedback --- airflow/api_internal/endpoints/rpc_api_endpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index dcbaf9f00c7e3..e4a5069b29bcc 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -237,7 +237,7 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse: response = json.dumps(output_json) if output_json is not None else None log.info("Sending response: %s", response) return Response(response=response, headers={"Content-Type": "application/json"}) - # In case of AirflowException or other selective known types transport the exception class back to caller + # In case of AirflowException or other selective known types, transport the exception class back to caller except (KeyError, AttributeError, AirflowException) as e: exception_json = BaseSerialization.serialize(e, use_pydantic_models=True) response = json.dumps(exception_json)