3 changes: 3 additions & 0 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -37,6 +37,7 @@ def _initialize_map() -> dict[str, Callable]:
from airflow.models import Trigger, Variable, XCom
from airflow.models.dag import DagModel
from airflow.models.dagwarning import DagWarning
from airflow.secrets.metastore import MetastoreBackend

functions: list[Callable] = [
DagFileProcessor.update_import_errors,
@@ -46,6 +47,8 @@ def _initialize_map() -> dict[str, Callable]:
DagModel.get_paused_dag_ids,
DagFileProcessorManager.clear_nonexistent_import_errors,
DagWarning.purge_inactive_dag_warnings,
MetastoreBackend._fetch_connection,
MetastoreBackend._fetch_variable,
XCom.get_value,
XCom.get_one,
XCom.get_many,
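Registering MetastoreBackend._fetch_connection and MetastoreBackend._fetch_variable in _initialize_map makes them callable over the internal RPC endpoint. The sketch below shows how such a call could look from a client; the method key format and the webserver URL are assumptions for illustration only, and the payload shape mirrors the test file later in this diff.

import json

import requests  # assumed to be available in the calling environment

from airflow.serialization.serialized_objects import BaseSerialization

payload = {
    "jsonrpc": "2.0",
    # Assumed key format under which _initialize_map registers the callable:
    "method": "airflow.secrets.metastore.MetastoreBackend._fetch_variable",
    "params": json.dumps(BaseSerialization.serialize({"key": "my_variable"})),
}
response = requests.post(
    "http://localhost:8080/internal_api/v1/rpcapi",  # hypothetical webserver address
    headers={"Content-Type": "application/json"},
    data=json.dumps(payload),
)
# The endpoint responds with a BaseSerialization-encoded value.
value = BaseSerialization.deserialize(response.json())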
4 changes: 4 additions & 0 deletions airflow/models/connection.py
@@ -21,6 +21,7 @@
import logging
import warnings
from json import JSONDecodeError
from typing import Any
from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit

from sqlalchemy import Boolean, Column, Integer, String, Text
@@ -433,6 +434,9 @@ def get_connection_from_secrets(cls, conn_id: str) -> Connection:

raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")

def to_dict(self) -> dict[str, Any]:
return {"conn_id": self.conn_id, "description": self.description, "uri": self.get_uri()}

@classmethod
def from_json(cls, value, conn_id=None) -> Connection:
kwargs = json.loads(value)
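The new Connection.to_dict() helper gives a JSON-friendly view of a connection, which the serialization change further down relies on. A minimal sketch, with made-up values:

# Illustrative only: the connection values below are invented for the example.
from airflow.models.connection import Connection

conn = Connection(conn_id="my_http", conn_type="http", host="example.com", login="user")
print(conn.to_dict())
# e.g. {'conn_id': 'my_http', 'description': None, 'uri': 'http://user@example.com'}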
23 changes: 18 additions & 5 deletions airflow/secrets/metastore.py
@@ -23,6 +23,7 @@

from sqlalchemy.orm import Session

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.secrets import BaseSecretsBackend
from airflow.utils.session import NEW_SESSION, provide_session
@@ -36,11 +37,7 @@ class MetastoreBackend(BaseSecretsBackend):

@provide_session
def get_connection(self, conn_id: str, session: Session = NEW_SESSION) -> Connection | None:
from airflow.models.connection import Connection

conn = session.query(Connection).filter(Connection.conn_id == conn_id).first()
session.expunge_all()
return conn
return MetastoreBackend._fetch_connection(conn_id, session=session)

@provide_session
def get_connections(self, conn_id: str, session: Session = NEW_SESSION) -> list[Connection]:
@@ -63,6 +60,22 @@ def get_variable(self, key: str, session: Session = NEW_SESSION) -> str | None:
:param key: Variable Key
:return: Variable Value
"""
return MetastoreBackend._fetch_variable(key=key, session=session)

@staticmethod
@internal_api_call
@provide_session
def _fetch_connection(conn_id: str, session: Session = NEW_SESSION) -> Connection | None:
from airflow.models.connection import Connection

conn = session.query(Connection).filter(Connection.conn_id == conn_id).first()
session.expunge_all()
return conn

@staticmethod
@internal_api_call
@provide_session
def _fetch_variable(key: str, session: Session = NEW_SESSION) -> str | None:
from airflow.models.variable import Variable

var_value = session.query(Variable).filter(Variable.key == key).first()
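Callers of the public accessors are unaffected by this refactor. A usage sketch, assuming a configured Airflow metadata database:

# get_connection()/get_variable() keep their signatures and now delegate to the
# @internal_api_call static methods, which can run against the local database or,
# under AIP-44, over the internal RPC API.
from airflow.secrets.metastore import MetastoreBackend

backend = MetastoreBackend()
conn = backend.get_connection("my_http")      # delegates to MetastoreBackend._fetch_connection
value = backend.get_variable("my_variable")   # delegates to MetastoreBackend._fetch_variable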
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
@@ -55,3 +55,4 @@ class DagAttributeTypes(str, Enum):
TASK_INSTANCE = "task_instance"
DAG_RUN = "dag_run"
DATA_SET = "data_set"
CONNECTION = "connection"
4 changes: 4 additions & 0 deletions airflow/serialization/serialized_objects.py
@@ -478,6 +478,8 @@ def serialize(
cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
type_=DAT.SIMPLE_TASK_INSTANCE,
)
elif isinstance(var, Connection):
return cls._encode(var.to_dict(), type_=DAT.CONNECTION)
elif use_pydantic_models and _ENABLE_AIP_44:
if isinstance(var, Job):
return cls._encode(JobPydantic.from_orm(var).dict(), type_=DAT.BASE_JOB)
@@ -552,6 +554,8 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
return Dataset(**var)
elif type_ == DAT.SIMPLE_TASK_INSTANCE:
return SimpleTaskInstance(**cls.deserialize(var))
elif type_ == DAT.CONNECTION:
return Connection(**var)
elif use_pydantic_models and _ENABLE_AIP_44:
if type_ == DAT.BASE_JOB:
return JobPydantic.parse_obj(var)
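With the DAT.CONNECTION branch in place, a Connection round-trips through BaseSerialization via to_dict() on encode and Connection(**var) on decode. A hedged sketch with made-up values:

# Sketch of the round trip added here; the connection values are invented.
from airflow.models.connection import Connection
from airflow.serialization.serialized_objects import BaseSerialization

conn = Connection(conn_id="my_http", conn_type="http", host="example.com")
encoded = BaseSerialization.serialize(conn)        # tagged with DAT.CONNECTION
decoded = BaseSerialization.deserialize(encoded)   # rebuilt from the to_dict() fields
assert decoded.conn_id == conn.conn_id
assert decoded.get_uri() == conn.get_uri()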
70 changes: 30 additions & 40 deletions tests/api_internal/endpoints/test_rpc_api_endpoint.py
@@ -23,6 +23,7 @@
import pytest
from flask import Flask

from airflow.models.connection import Connection
from airflow.models.taskinstance import TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
@@ -54,6 +55,10 @@ def factory() -> Flask:
return factory()


def equals(a, b) -> bool:
return a == b


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
class TestRpcApiEndpoint:
@pytest.fixture(autouse=True)
@@ -71,65 +76,50 @@ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator:
yield mock_initialize_map

@pytest.mark.parametrize(
"input_data, method_result, method_params, expected_mock, expected_code",
"input_params, method_result, result_cmp_func, method_params",
[
("", None, equals, {}),
("", "test_me", equals, {}),
(
{"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""},
"test_me",
{},
mock_test_method,
200,
json.dumps(BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"})),
("dag_id_15", "fake-task", 1),
equals,
{"dag_id": 15, "task_id": "fake-task"},
),
(
{"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""},
None,
"",
TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING),
lambda a, b: a == TaskInstancePydantic.from_orm(b),
{},
mock_test_method,
200,
),
(
{
"jsonrpc": "2.0",
"method": TEST_METHOD_NAME,
"params": json.dumps(BaseSerialization.serialize({"dag_id": 15, "task_id": "fake-task"})),
},
("dag_id_15", "fake-task", 1),
{"dag_id": 15, "task_id": "fake-task"},
mock_test_method,
200,
"",
Connection(conn_id="test_conn", conn_type="http", host="", password=""),
lambda a, b: a.get_uri() == b.get_uri() and a.conn_id == b.conn_id,
{},
),
],
)
def test_method(self, input_data, method_result, method_params, expected_mock, expected_code):
def test_method(self, input_params, method_result, result_cmp_func, method_params):
if method_result:
expected_mock.return_value = method_result
mock_test_method.return_value = method_result

input_data = {
"jsonrpc": "2.0",
"method": TEST_METHOD_NAME,
"params": input_params,
}
response = self.client.post(
"/internal_api/v1/rpcapi",
headers={"Content-Type": "application/json"},
data=json.dumps(input_data),
)
assert response.status_code == expected_code
assert response.status_code == 200
if method_result:
response_data = BaseSerialization.deserialize(json.loads(response.data))
assert response_data == method_result

expected_mock.assert_called_once_with(**method_params)
response_data = BaseSerialization.deserialize(json.loads(response.data), use_pydantic_models=True)
assert result_cmp_func(response_data, method_result)

def test_method_with_pydantic_serialized_object(self):
ti = TaskInstance(task=EmptyOperator(task_id="task"), run_id="run_id", state=State.RUNNING)
mock_test_method.return_value = ti

response = self.client.post(
"/internal_api/v1/rpcapi",
headers={"Content-Type": "application/json"},
data=json.dumps({"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}),
)
assert response.status_code == 200
print(response.data)
response_data = BaseSerialization.deserialize(json.loads(response.data), use_pydantic_models=True)
expected_data = TaskInstancePydantic.from_orm(ti)
assert response_data == expected_data
mock_test_method.assert_called_once_with(**method_params)

def test_method_with_exception(self):
mock_test_method.side_effect = ValueError("Error!!!")