From c99b6427fbd66af00cf8b95b3ac97787cacdf7ab Mon Sep 17 00:00:00 2001 From: Mateusz Henc Date: Wed, 19 Apr 2023 16:46:25 +0200 Subject: [PATCH 1/2] Migrate fetching variables and connections from database to Internal API --- .../endpoints/rpc_api_endpoint.py | 3 + airflow/models/connection.py | 4 ++ airflow/secrets/metastore.py | 23 ++++-- airflow/serialization/enums.py | 1 + airflow/serialization/serialized_objects.py | 2 + .../endpoints/test_rpc_api_endpoint.py | 70 ++++++++----------- 6 files changed, 58 insertions(+), 45 deletions(-) diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index b6ac604c05393..242def9d9e3f9 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -37,6 +37,7 @@ def _initialize_map() -> dict[str, Callable]: from airflow.models import Trigger, Variable, XCom from airflow.models.dag import DagModel from airflow.models.dagwarning import DagWarning + from airflow.secrets.metastore import MetastoreBackend functions: list[Callable] = [ DagFileProcessor.update_import_errors, @@ -46,6 +47,8 @@ def _initialize_map() -> dict[str, Callable]: DagModel.get_paused_dag_ids, DagFileProcessorManager.clear_nonexistent_import_errors, DagWarning.purge_inactive_dag_warnings, + MetastoreBackend._fetch_connection, + MetastoreBackend._fetch_variable, XCom.get_value, XCom.get_one, XCom.get_many, diff --git a/airflow/models/connection.py b/airflow/models/connection.py index a5653412093e9..92c486092d009 100644 --- a/airflow/models/connection.py +++ b/airflow/models/connection.py @@ -21,6 +21,7 @@ import logging import warnings from json import JSONDecodeError +from typing import Any from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit from sqlalchemy import Boolean, Column, Integer, String, Text @@ -433,6 +434,9 @@ def get_connection_from_secrets(cls, conn_id: str) -> Connection: raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined") + def to_dict(self) -> dict[str, Any]: + return {"conn_id": self.conn_id, "description": self.description, "uri": self.get_uri()} + @classmethod def from_json(cls, value, conn_id=None) -> Connection: kwargs = json.loads(value) diff --git a/airflow/secrets/metastore.py b/airflow/secrets/metastore.py index dc81675d771ff..d786ffe801d4b 100644 --- a/airflow/secrets/metastore.py +++ b/airflow/secrets/metastore.py @@ -23,6 +23,7 @@ from sqlalchemy.orm import Session +from airflow.api_internal.internal_api_call import internal_api_call from airflow.exceptions import RemovedInAirflow3Warning from airflow.secrets import BaseSecretsBackend from airflow.utils.session import NEW_SESSION, provide_session @@ -36,11 +37,7 @@ class MetastoreBackend(BaseSecretsBackend): @provide_session def get_connection(self, conn_id: str, session: Session = NEW_SESSION) -> Connection | None: - from airflow.models.connection import Connection - - conn = session.query(Connection).filter(Connection.conn_id == conn_id).first() - session.expunge_all() - return conn + return MetastoreBackend._fetch_connection(conn_id, session=session) @provide_session def get_connections(self, conn_id: str, session: Session = NEW_SESSION) -> list[Connection]: @@ -63,6 +60,22 @@ def get_variable(self, key: str, session: Session = NEW_SESSION) -> str | None: :param key: Variable Key :return: Variable Value """ + return MetastoreBackend._fetch_variable(key=key, session=NEW_SESSION) + + @staticmethod + @internal_api_call + @provide_session + def _fetch_connection(conn_id: str, session: Session = NEW_SESSION) -> Connection | None: + from airflow.models.connection import Connection + + conn = session.query(Connection).filter(Connection.conn_id == conn_id).first() + session.expunge_all() + return conn + + @staticmethod + @internal_api_call + @provide_session + def _fetch_variable(key: str, session: Session = NEW_SESSION) -> str | None: from airflow.models.variable import Variable var_value = session.query(Variable).filter(Variable.key == key).first() diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py index c83d9f53ef87c..32be3329c1e3d 100644 --- a/airflow/serialization/enums.py +++ b/airflow/serialization/enums.py @@ -55,3 +55,4 @@ class DagAttributeTypes(str, Enum): TASK_INSTANCE = "task_instance" DAG_RUN = "dag_run" DATA_SET = "data_set" + CONNECTION = "connection" diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 0f319a1afb363..2168be298c013 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -552,6 +552,8 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any: return Dataset(**var) elif type_ == DAT.SIMPLE_TASK_INSTANCE: return SimpleTaskInstance(**cls.deserialize(var)) + elif type_ == DAT.CONNECTION: + return Connection(**var) elif use_pydantic_models and _ENABLE_AIP_44: if type_ == DAT.BASE_JOB: return JobPydantic.parse_obj(var) diff --git a/tests/api_internal/endpoints/test_rpc_api_endpoint.py b/tests/api_internal/endpoints/test_rpc_api_endpoint.py index bd7a7dc189380..d72448324e930 100644 --- a/tests/api_internal/endpoints/test_rpc_api_endpoint.py +++ b/tests/api_internal/endpoints/test_rpc_api_endpoint.py @@ -23,6 +23,7 @@ import pytest from flask import Flask +from airflow.models.connection import Connection from airflow.models.taskinstance import TaskInstance from airflow.operators.empty import EmptyOperator from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic @@ -54,6 +55,10 @@ def factory() -> Flask: return factory() +def equals(a, b) -> bool: + return a == b + + @pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled") class TestRpcApiEndpoint: @pytest.fixture(autouse=True) @@ -71,65 +76,50 @@ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator: yield mock_initialize_map @pytest.mark.parametrize( - "input_data, method_result, method_params, expected_mock, expected_code", + "input_params, method_result, result_cmp_func, method_params", [ + ("", None, equals, {}), + ("", "test_me", equals, {}), ( - {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}, - "test_me", - {}, - mock_test_method, - 200, + json.dumps(BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"})), + ("dag_id_15", "fake-task", 1), + equals, + {"dag_id": 15, "task_id": "fake-task"}, ), ( - {"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}, - None, + "", + TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING), + lambda a, b: a == TaskInstancePydantic.from_orm(b), {}, - mock_test_method, - 200, ), ( - { - "jsonrpc": "2.0", - "method": TEST_METHOD_NAME, - "params": json.dumps(BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"})), - }, - ("dag_id_15", "fake-task", 1), - {"dag_id": 15, "task_id": "fake-task"}, - mock_test_method, - 200, + "", + Connection(conn_id="test_conn", conn_type="http", host="", password=""), + lambda a, b: a.get_uri() == b.get_uri() and a.conn_id == b.conn_id, + {}, ), ], ) - def test_method(self, input_data, method_result, method_params, expected_mock, expected_code): + def test_method(self, input_params, method_result, result_cmp_func, method_params): if method_result: - expected_mock.return_value = method_result + mock_test_method.return_value = method_result + input_data = { + "jsonrpc": "2.0", + "method": TEST_METHOD_NAME, + "params": input_params, + } response = self.client.post( "/internal_api/v1/rpcapi", headers={"Content-Type": "application/json"}, data=json.dumps(input_data), ) - assert response.status_code == expected_code + assert response.status_code == 200 if method_result: - response_data = BaseSerialization.deserialize(json.loads(response.data)) - assert response_data == method_result - - expected_mock.assert_called_once_with(**method_params) + response_data = BaseSerialization.deserialize(json.loads(response.data), use_pydantic_models=True) + assert result_cmp_func(response_data, method_result) - def test_method_with_pydantic_serialized_object(self): - ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING) - mock_test_method.return_value = ti - - response = self.client.post( - "/internal_api/v1/rpcapi", - headers={"Content-Type": "application/json"}, - data=json.dumps({"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}), - ) - assert response.status_code == 200 - print(response.data) - response_data = BaseSerialization.deserialize(json.loads(response.data), use_pydantic_models=True) - expected_data = TaskInstancePydantic.from_orm(ti) - assert response_data == expected_data + mock_test_method.assert_called_once_with(**method_params) def test_method_with_exception(self): mock_test_method.side_effect = ValueError("Error!!!") From b1d48147c39a110440d7894823b293aefd47ec48 Mon Sep 17 00:00:00 2001 From: Mateusz Henc Date: Wed, 19 Apr 2023 16:50:47 +0200 Subject: [PATCH 2/2] Fixes after merge --- airflow/secrets/metastore.py | 2 +- airflow/serialization/serialized_objects.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/airflow/secrets/metastore.py b/airflow/secrets/metastore.py index d786ffe801d4b..e5644f8c8e14f 100644 --- a/airflow/secrets/metastore.py +++ b/airflow/secrets/metastore.py @@ -60,7 +60,7 @@ def get_variable(self, key: str, session: Session = NEW_SESSION) -> str | None: :param key: Variable Key :return: Variable Value """ - return MetastoreBackend._fetch_variable(key=key, session=NEW_SESSION) + return MetastoreBackend._fetch_variable(key=key, session=session) @staticmethod @internal_api_call diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 2168be298c013..044b06e111112 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -478,6 +478,8 @@ def serialize( cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models), type_=DAT.SIMPLE_TASK_INSTANCE, ) + elif isinstance(var, Connection): + return cls._encode(var.to_dict(), type_=DAT.CONNECTION) elif use_pydantic_models and _ENABLE_AIP_44: if isinstance(var, Job): return cls._encode(JobPydantic.from_orm(var).dict(), type_=DAT.BASE_JOB)