From 84e8eee0d29ec0bef59be53935bad347f8cb69ef Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 25 Feb 2025 22:42:12 +0800 Subject: [PATCH 01/14] AIP-84 | Add Auth for Dag Refactor conftest for api_fastapi and test_dags Add unauthorized 403 test cases Remove PATCH in requires_access Fix unauthorized_test_client, requires_access_dag Add EditableDagsFilterDep, ReadableDagsFilterDep Add permitted_dag_filter for dags API Fix test_security Add OrmFilterClause Fix mypy error --- airflow/api_fastapi/common/db/common.py | 52 +++++++------- airflow/api_fastapi/common/parameters.py | 11 ++- airflow/api_fastapi/core_api/base.py | 23 ++++++ .../core_api/openapi/v1-generated.yaml | 12 ++++ .../core_api/routes/public/dags.py | 17 ++++- airflow/api_fastapi/core_api/security.py | 54 ++++++++++++-- .../core_api/routes/public/test_dags.py | 72 ++++++++++++++++++- tests/api_fastapi/core_api/test_security.py | 11 +-- 8 files changed, 205 insertions(+), 47 deletions(-) diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py index 84297724ce025..363a7fbdc4b0f 100644 --- a/airflow/api_fastapi/common/db/common.py +++ b/airflow/api_fastapi/common/db/common.py @@ -35,7 +35,7 @@ if TYPE_CHECKING: from sqlalchemy.sql import Select - from airflow.api_fastapi.common.parameters import BaseParam + from airflow.api_fastapi.core_api.base import OrmClause def _get_session() -> Session: @@ -47,7 +47,7 @@ def _get_session() -> Session: def apply_filters_to_select( - *, statement: Select, filters: Sequence[BaseParam | None] | None = None + *, statement: Select, filters: Sequence[OrmClause | None] | None = None ) -> Select: if filters is None: return statement @@ -71,10 +71,10 @@ async def _get_async_session() -> AsyncSession: async def paginated_select_async( *, statement: Select, - filters: Sequence[BaseParam] | None = None, - order_by: BaseParam | None = None, - offset: BaseParam | None = None, - limit: BaseParam | None = None, + filters: Sequence[OrmClause] | None = None, + order_by: OrmClause | None = None, + offset: OrmClause | None = None, + limit: OrmClause | None = None, session: AsyncSession, return_total_entries: Literal[True] = True, ) -> tuple[Select, int]: ... @@ -84,10 +84,10 @@ async def paginated_select_async( async def paginated_select_async( *, statement: Select, - filters: Sequence[BaseParam] | None = None, - order_by: BaseParam | None = None, - offset: BaseParam | None = None, - limit: BaseParam | None = None, + filters: Sequence[OrmClause] | None = None, + order_by: OrmClause | None = None, + offset: OrmClause | None = None, + limit: OrmClause | None = None, session: AsyncSession, return_total_entries: Literal[False], ) -> tuple[Select, None]: ... @@ -96,10 +96,10 @@ async def paginated_select_async( async def paginated_select_async( *, statement: Select, - filters: Sequence[BaseParam | None] | None = None, - order_by: BaseParam | None = None, - offset: BaseParam | None = None, - limit: BaseParam | None = None, + filters: Sequence[OrmClause | None] | None = None, + order_by: OrmClause | None = None, + offset: OrmClause | None = None, + limit: OrmClause | None = None, session: AsyncSession, return_total_entries: bool = True, ) -> tuple[Select, int | None]: @@ -129,10 +129,10 @@ async def paginated_select_async( def paginated_select( *, statement: Select, - filters: Sequence[BaseParam] | None = None, - order_by: BaseParam | None = None, - offset: BaseParam | None = None, - limit: BaseParam | None = None, + filters: Sequence[OrmClause] | None = None, + order_by: OrmClause | None = None, + offset: OrmClause | None = None, + limit: OrmClause | None = None, session: Session = NEW_SESSION, return_total_entries: Literal[True] = True, ) -> tuple[Select, int]: ... @@ -142,10 +142,10 @@ def paginated_select( def paginated_select( *, statement: Select, - filters: Sequence[BaseParam] | None = None, - order_by: BaseParam | None = None, - offset: BaseParam | None = None, - limit: BaseParam | None = None, + filters: Sequence[OrmClause] | None = None, + order_by: OrmClause | None = None, + offset: OrmClause | None = None, + limit: OrmClause | None = None, session: Session = NEW_SESSION, return_total_entries: Literal[False], ) -> tuple[Select, None]: ... @@ -155,10 +155,10 @@ def paginated_select( def paginated_select( *, statement: Select, - filters: Sequence[BaseParam] | None = None, - order_by: BaseParam | None = None, - offset: BaseParam | None = None, - limit: BaseParam | None = None, + filters: Sequence[OrmClause] | None = None, + order_by: OrmClause | None = None, + offset: OrmClause | None = None, + limit: OrmClause | None = None, session: Session = NEW_SESSION, return_total_entries: bool = True, ) -> tuple[Select, int | None]: diff --git a/airflow/api_fastapi/common/parameters.py b/airflow/api_fastapi/common/parameters.py index f69d64bd5d031..d7ce038d10411 100644 --- a/airflow/api_fastapi/common/parameters.py +++ b/airflow/api_fastapi/common/parameters.py @@ -40,6 +40,7 @@ from sqlalchemy import Column, and_, case, or_ from sqlalchemy.inspection import inspect +from airflow.api_fastapi.core_api.base import OrmClause from airflow.models import Base from airflow.models.asset import ( AssetAliasModel, @@ -65,18 +66,14 @@ T = TypeVar("T") -class BaseParam(Generic[T], ABC): - """Base class for filters.""" +class BaseParam(OrmClause[T], ABC): + """Base class for path or query parameters with ORM transformation.""" def __init__(self, value: T | None = None, skip_none: bool = True) -> None: - self.value = value + super().__init__(value) self.attribute: ColumnElement | None = None self.skip_none = skip_none - @abstractmethod - def to_orm(self, select: Select) -> Select: - pass - def set_value(self, value: T | None) -> Self: self.value = value return self diff --git a/airflow/api_fastapi/core_api/base.py b/airflow/api_fastapi/core_api/base.py index d88ec1757eb60..887f528f197ef 100644 --- a/airflow/api_fastapi/core_api/base.py +++ b/airflow/api_fastapi/core_api/base.py @@ -16,8 +16,16 @@ # under the License. from __future__ import annotations +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Generic, TypeVar + from pydantic import BaseModel as PydanticBaseModel, ConfigDict +if TYPE_CHECKING: + from sqlalchemy.sql import Select + +T = TypeVar("T") + class BaseModel(PydanticBaseModel): """ @@ -39,3 +47,18 @@ class StrictBaseModel(BaseModel): """ model_config = ConfigDict(from_attributes=True, populate_by_name=True, extra="forbid") + + +class OrmClause(Generic[T], ABC): + """ + Base class for filtering clauses with paginated_select. + + The subclasses should implement the `to_orm` method and set the `value` attribute. + """ + + def __init__(self, value: T | None = None): + self.value = value + + @abstractmethod + def to_orm(self, select: Select) -> Select: + pass diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 7555e3e478856..ae00f8b688179 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -3058,6 +3058,8 @@ paths: summary: Get Dags description: Get all DAGs. operationId: get_dags + security: + - OAuth2PasswordBearer: [] parameters: - name: limit in: query @@ -3223,6 +3225,8 @@ paths: summary: Patch Dags description: Patch multiple DAGs. operationId: patch_dags + security: + - OAuth2PasswordBearer: [] parameters: - name: update_mask in: query @@ -3358,6 +3362,8 @@ paths: summary: Get Dag description: Get basic information about a DAG. operationId: get_dag + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -3408,6 +3414,8 @@ paths: summary: Patch Dag description: Patch the specific DAG. operationId: patch_dag + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -3474,6 +3482,8 @@ paths: summary: Delete Dag description: Delete the specific DAG. operationId: delete_dag + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -3524,6 +3534,8 @@ paths: summary: Get Dag Details description: Get details of DAG. operationId: get_dag_details + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path diff --git a/airflow/api_fastapi/core_api/routes/public/dags.py b/airflow/api_fastapi/core_api/routes/public/dags.py index 3eaf487948226..f38f650657cae 100644 --- a/airflow/api_fastapi/core_api/routes/public/dags.py +++ b/airflow/api_fastapi/core_api/routes/public/dags.py @@ -57,6 +57,11 @@ DAGResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc +from airflow.api_fastapi.core_api.security import ( + EditableDagsFilterDep, + ReadableDagsFilterDep, + requires_access_dag, +) from airflow.exceptions import AirflowException, DagNotFound from airflow.models import DAG, DagModel from airflow.models.dagrun import DagRun @@ -64,7 +69,7 @@ dags_router = AirflowRouter(tags=["DAG"], prefix="/dags") -@dags_router.get("") +@dags_router.get("", dependencies=[Depends(requires_access_dag(method="GET"))]) def get_dags( limit: QueryLimit, offset: QueryOffset, @@ -104,6 +109,7 @@ def get_dags( ).dynamic_depends() ), ], + readable_dags_filter: ReadableDagsFilterDep, session: SessionDep, ) -> DAGCollectionResponse: """Get all DAGs.""" @@ -131,6 +137,7 @@ def get_dags( tags, owners, last_dag_run_state, + readable_dags_filter, ], order_by=order_by, offset=offset, @@ -155,6 +162,7 @@ def get_dags( status.HTTP_422_UNPROCESSABLE_ENTITY, ] ), + dependencies=[Depends(requires_access_dag(method="GET"))], ) def get_dag(dag_id: str, session: SessionDep, request: Request) -> DAGResponse: """Get basic information about a DAG.""" @@ -181,6 +189,7 @@ def get_dag(dag_id: str, session: SessionDep, request: Request) -> DAGResponse: status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_dag(method="GET"))], ) def get_dag_details(dag_id: str, session: SessionDep, request: Request) -> DAGDetailsResponse: """Get details of DAG.""" @@ -207,6 +216,7 @@ def get_dag_details(dag_id: str, session: SessionDep, request: Request) -> DAGDe status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_dag(method="PUT"))], ) def patch_dag( dag_id: str, @@ -249,6 +259,7 @@ def patch_dag( status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_dag(method="PUT"))], ) def patch_dags( patch_body: DAGPatchBody, @@ -260,6 +271,7 @@ def patch_dags( only_active: QueryOnlyActiveFilter, paused: QueryPausedFilter, last_dag_run_state: QueryLastDagRunStateFilter, + editable_dags_filter: EditableDagsFilterDep, session: SessionDep, update_mask: list[str] | None = Query(None), ) -> DAGCollectionResponse: @@ -280,7 +292,7 @@ def patch_dags( dags_select, total_entries = paginated_select( statement=generate_dag_with_latest_run_query(), - filters=[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state], + filters=[only_active, paused, dag_id_pattern, tags, owners, last_dag_run_state, editable_dags_filter], order_by=None, offset=offset, limit=limit, @@ -310,6 +322,7 @@ def patch_dags( status.HTTP_422_UNPROCESSABLE_ENTITY, ] ), + dependencies=[Depends(requires_access_dag(method="DELETE"))], ) def delete_dag( dag_id: str, diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index 2e5745ffed272..bd6f26ef8dbce 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +from collections.abc import Container from functools import cache from typing import TYPE_CHECKING, Annotated, Callable @@ -36,11 +37,15 @@ PoolDetails, VariableDetails, ) +from airflow.api_fastapi.core_api.base import OrmClause from airflow.configuration import conf +from airflow.models.dag import DagModel from airflow.utils.jwt_signer import JWTSigner, get_signing_key if TYPE_CHECKING: - from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod + from sqlalchemy.sql import Select + + from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -63,6 +68,9 @@ def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser: raise HTTPException(status.HTTP_403_FORBIDDEN, "Forbidden") +GetUserDep = Annotated[BaseUser, Depends(get_user)] + + async def get_user_with_exception_handling(request: Request) -> BaseUser | None: # Currently the UI does not support JWT authentication, this method defines a fallback if no token is provided by the UI. # We can remove this method when issue https://github.com/apache/airflow/issues/44884 is done. @@ -80,12 +88,14 @@ async def get_user_with_exception_handling(request: Request) -> BaseUser | None: return get_user(token_str) -def requires_access_dag(method: ResourceMethod, access_entity: DagAccessEntity | None = None) -> Callable: +def requires_access_dag( + method: ResourceMethod, access_entity: DagAccessEntity | None = None +) -> Callable[[Request, BaseUser], None]: def inner( request: Request, - user: Annotated[BaseUser, Depends(get_user)], + user: GetUserDep, ) -> None: - dag_id = request.path_params.get("dag_id") or request.query_params.get("dag_id") + dag_id: str | None = request.path_params.get("dag_id") _requires_access( is_authorized_callback=lambda: get_auth_manager().is_authorized_dag( @@ -96,10 +106,42 @@ def inner( return inner +class PermittedDagFilter(OrmClause[set[str]]): + """A parameter that filters the permitted dags for the user.""" + + def to_orm(self, select: Select) -> Select: + return select.where(DagModel.dag_id.in_(self.value)) + + +def permitted_dag_filter_factory( + methods: Container[ResourceMethod], +) -> Callable[[Request, BaseUser], PermittedDagFilter]: + """ + Create a callable for Depends in FastAPI that returns a filter of the permitted dags for the user. + + :param methods: whether filter readable or writable. + :return: The callable that can be used as Depends in FastAPI. + """ + + def depends_permitted_dags_filter( + request: Request, + user: GetUserDep, + ) -> PermittedDagFilter: + auth_manager: BaseAuthManager = request.app.state.auth_manager + permitted_dags: set[str] = auth_manager.get_permitted_dag_ids(user=user, methods=methods) + return PermittedDagFilter(permitted_dags) + + return depends_permitted_dags_filter + + +EditableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory(["PUT"]))] +ReadableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory(["GET"]))] + + def requires_access_pool(method: ResourceMethod) -> Callable[[Request, BaseUser], None]: def inner( request: Request, - user: Annotated[BaseUser, Depends(get_user)], + user: GetUserDep, ) -> None: pool_name = request.path_params.get("pool_name") @@ -115,7 +157,7 @@ def inner( def requires_access_connection(method: ResourceMethod) -> Callable[[Request, BaseUser], None]: def inner( request: Request, - user: Annotated[BaseUser, Depends(get_user)], + user: GetUserDep, ) -> None: connection_id = request.path_params.get("connection_id") diff --git a/tests/api_fastapi/core_api/routes/public/test_dags.py b/tests/api_fastapi/core_api/routes/public/test_dags.py index 9c4650636b22f..c1793ed7fa5d6 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dags.py +++ b/tests/api_fastapi/core_api/routes/public/test_dags.py @@ -234,13 +234,31 @@ class TestGetDags(TestDagEndpoint): ) def test_get_dags(self, test_client, query_params, expected_total_entries, expected_ids): response = test_client.get("/public/dags", params=query_params) - assert response.status_code == 200 body = response.json() assert body["total_entries"] == expected_total_entries assert [dag["dag_id"] for dag in body["dags"]] == expected_ids + @mock.patch("airflow.auth.managers.base_auth_manager.BaseAuthManager.get_permitted_dag_ids") + def test_get_dags_should_call_permitted_dag_ids(self, mock_get_permitted_dag_ids, test_client): + mock_get_permitted_dag_ids.return_value = {DAG1_ID, DAG2_ID} + response = test_client.get("/public/dags") + mock_get_permitted_dag_ids.assert_called_once_with(user=mock.ANY, methods=["GET"]) + assert response.status_code == 200 + body = response.json() + + assert body["total_entries"] == 2 + assert [dag["dag_id"] for dag in body["dags"]] == [DAG1_ID, DAG2_ID] + + def test_get_dags_should_response_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get("/public/dags") + assert response.status_code == 401 + + def test_get_dags_should_response_403(self, unauthorized_test_client): + response = unauthorized_test_client.get("/public/dags") + assert response.status_code == 403 + class TestPatchDag(TestDagEndpoint): """Unit tests for Patch DAG.""" @@ -266,6 +284,14 @@ def test_patch_dag( body = response.json() assert body["is_paused"] == expected_is_paused + def test_patch_dag_should_response_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.patch(f"/public/dags/{DAG1_ID}", json={"is_paused": True}) + assert response.status_code == 401 + + def test_patch_dag_should_response_403(self, unauthorized_test_client): + response = unauthorized_test_client.patch(f"/public/dags/{DAG1_ID}", json={"is_paused": True}) + assert response.status_code == 403 + class TestPatchDags(TestDagEndpoint): """Unit tests for Patch DAGs.""" @@ -323,6 +349,26 @@ def test_patch_dags( paused_dag_ids = [dag["dag_id"] for dag in body["dags"] if dag["is_paused"]] assert paused_dag_ids == expected_paused_ids + @mock.patch("airflow.auth.managers.base_auth_manager.BaseAuthManager.get_permitted_dag_ids") + def test_patch_dags_should_call_permitted_dag_ids(self, mock_get_permitted_dag_ids, test_client): + mock_get_permitted_dag_ids.return_value = {DAG1_ID, DAG2_ID} + response = test_client.patch( + "/public/dags", json={"is_paused": False}, params={"only_active": False, "dag_id_pattern": "~"} + ) + mock_get_permitted_dag_ids.assert_called_once_with(user=mock.ANY, methods=["PUT"]) + assert response.status_code == 200 + body = response.json() + + assert [dag["dag_id"] for dag in body["dags"]] == [DAG1_ID, DAG2_ID] + + def test_patch_dags_should_response_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.patch("/public/dags", json={"is_paused": True}) + assert response.status_code == 401 + + def test_patch_dags_should_response_403(self, unauthorized_test_client): + response = unauthorized_test_client.patch("/public/dags", json={"is_paused": True}) + assert response.status_code == 403 + class TestDagDetails(TestDagEndpoint): """Unit tests for DAG Details.""" @@ -404,6 +450,14 @@ def test_dag_details( } assert res_json == expected + def test_dag_details_should_response_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get(f"/public/dags/{DAG1_ID}/details") + assert response.status_code == 401 + + def test_dag_details_should_response_403(self, unauthorized_test_client): + response = unauthorized_test_client.get(f"/public/dags/{DAG1_ID}/details") + assert response.status_code == 403 + class TestGetDag(TestDagEndpoint): """Unit tests for Get DAG.""" @@ -452,6 +506,14 @@ def test_get_dag(self, test_client, query_params, dag_id, expected_status_code, } assert res_json == expected + def test_get_dag_should_response_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get(f"/public/dags/{DAG1_ID}") + assert response.status_code == 401 + + def test_get_dag_should_response_403(self, unauthorized_test_client): + response = unauthorized_test_client.get(f"/public/dags/{DAG1_ID}") + assert response.status_code == 403 + class TestDeleteDAG(TestDagEndpoint): """Unit tests for Delete DAG.""" @@ -510,3 +572,11 @@ def test_delete_dag( details_response = test_client.get(f"{API_PREFIX}/{dag_id}/details") assert details_response.status_code == status_code_details + + def test_delete_dag_should_response_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.delete(f"{API_PREFIX}/{DAG1_ID}") + assert response.status_code == 401 + + def test_delete_dag_should_response_403(self, unauthorized_test_client): + response = unauthorized_test_client.delete(f"{API_PREFIX}/{DAG1_ID}") + assert response.status_code == 403 diff --git a/tests/api_fastapi/core_api/test_security.py b/tests/api_fastapi/core_api/test_security.py index f777cda5e83f0..193cc67798f91 100644 --- a/tests/api_fastapi/core_api/test_security.py +++ b/tests/api_fastapi/core_api/test_security.py @@ -88,11 +88,10 @@ def test_requires_access_dag_authorized(self, mock_get_auth_manager): auth_manager = Mock() auth_manager.is_authorized_dag.return_value = True mock_get_auth_manager.return_value = auth_manager + fastapi_request = Mock() + fastapi_request.path_params.return_value = {} - mock_request = Mock() - mock_request.path_params.return_value = {"dag_id": "test"} - - requires_access_dag("GET", DagAccessEntity.CODE)(mock_request, Mock()) + requires_access_dag("GET", DagAccessEntity.CODE)(fastapi_request, Mock()) auth_manager.is_authorized_dag.assert_called_once() @@ -101,11 +100,13 @@ def test_requires_access_dag_unauthorized(self, mock_get_auth_manager): auth_manager = Mock() auth_manager.is_authorized_dag.return_value = False mock_get_auth_manager.return_value = auth_manager + fastapi_request = Mock() + fastapi_request.path_params.return_value = {} mock_request = Mock() mock_request.path_params.return_value = {} with pytest.raises(HTTPException, match="Forbidden"): - requires_access_dag("GET", DagAccessEntity.CODE)(mock_request, Mock()) + requires_access_dag("GET", DagAccessEntity.CODE)(fastapi_request, Mock()) auth_manager.is_authorized_dag.assert_called_once() From 8cfc8160a0f1fe4702910dc60d4b8a5e6713ae4e Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Fri, 7 Mar 2025 17:50:51 +0800 Subject: [PATCH 02/14] fix(api_fastapi): rename methods argument to method --- airflow/api_fastapi/core_api/security.py | 13 +++++-------- .../api_fastapi/core_api/routes/public/test_dags.py | 4 ++-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index bd6f26ef8dbce..7ef10be39e3e2 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -from collections.abc import Container from functools import cache from typing import TYPE_CHECKING, Annotated, Callable @@ -113,13 +112,11 @@ def to_orm(self, select: Select) -> Select: return select.where(DagModel.dag_id.in_(self.value)) -def permitted_dag_filter_factory( - methods: Container[ResourceMethod], -) -> Callable[[Request, BaseUser], PermittedDagFilter]: +def permitted_dag_filter_factory(method: ResourceMethod) -> Callable[[Request, BaseUser], PermittedDagFilter]: """ Create a callable for Depends in FastAPI that returns a filter of the permitted dags for the user. - :param methods: whether filter readable or writable. + :param method: whether filter readable or writable. :return: The callable that can be used as Depends in FastAPI. """ @@ -128,14 +125,14 @@ def depends_permitted_dags_filter( user: GetUserDep, ) -> PermittedDagFilter: auth_manager: BaseAuthManager = request.app.state.auth_manager - permitted_dags: set[str] = auth_manager.get_permitted_dag_ids(user=user, methods=methods) + permitted_dags: set[str] = auth_manager.get_permitted_dag_ids(user=user, method=method) return PermittedDagFilter(permitted_dags) return depends_permitted_dags_filter -EditableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory(["PUT"]))] -ReadableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory(["GET"]))] +EditableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory("PUT"))] +ReadableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory("GET"))] def requires_access_pool(method: ResourceMethod) -> Callable[[Request, BaseUser], None]: diff --git a/tests/api_fastapi/core_api/routes/public/test_dags.py b/tests/api_fastapi/core_api/routes/public/test_dags.py index c1793ed7fa5d6..424588b227afd 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dags.py +++ b/tests/api_fastapi/core_api/routes/public/test_dags.py @@ -244,7 +244,7 @@ def test_get_dags(self, test_client, query_params, expected_total_entries, expec def test_get_dags_should_call_permitted_dag_ids(self, mock_get_permitted_dag_ids, test_client): mock_get_permitted_dag_ids.return_value = {DAG1_ID, DAG2_ID} response = test_client.get("/public/dags") - mock_get_permitted_dag_ids.assert_called_once_with(user=mock.ANY, methods=["GET"]) + mock_get_permitted_dag_ids.assert_called_once_with(user=mock.ANY, method="GET") assert response.status_code == 200 body = response.json() @@ -355,7 +355,7 @@ def test_patch_dags_should_call_permitted_dag_ids(self, mock_get_permitted_dag_i response = test_client.patch( "/public/dags", json={"is_paused": False}, params={"only_active": False, "dag_id_pattern": "~"} ) - mock_get_permitted_dag_ids.assert_called_once_with(user=mock.ANY, methods=["PUT"]) + mock_get_permitted_dag_ids.assert_called_once_with(user=mock.ANY, method="PUT") assert response.status_code == 200 body = response.json() From 4762eaabac006d07bb0fe6e4da20f435b9424533 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Sat, 8 Mar 2025 02:41:14 +0800 Subject: [PATCH 03/14] Fix kubernetes_tests --- kubernetes_tests/test_base.py | 56 ++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/kubernetes_tests/test_base.py b/kubernetes_tests/test_base.py index 31e1924c18ad8..d8ca7b1b80552 100644 --- a/kubernetes_tests/test_base.py +++ b/kubernetes_tests/test_base.py @@ -25,6 +25,7 @@ from datetime import datetime, timezone from pathlib import Path from subprocess import check_call, check_output +from urllib.parse import parse_qs, urlparse import pytest import requests @@ -126,9 +127,62 @@ def _delete_airflow_pod(name=""): if names: check_call(["kubectl", "delete", "pod", names[0]]) + @staticmethod + def _get_jwt_token(session: requests.Session, username: str, password: str) -> str: + """Get the JWT token for the given username and password. + + Note: API server is still using FAB Auth Manager. + + Steps: + 1. Get the login page to get the csrf token + - The csrf token is in the hidden input field with id "csrf_token" + 2. Login with the username and password + - Must use the same session to keep the csrf token session + 3. Extract the JWT token from the redirect url + - Expected to have a connection error + - The redirect url should have the JWT token as a query parameter + + :param session: The session to use for the request + :param username: The username to use for the login + :param password: The password to use for the login + :return: The JWT token + """ + # get csrf token from login page + get_login_form_response = session.get(f"http://{KUBERNETES_HOST_PORT}/auth/login") + # input id="csrf_token" + csrf_token = re.search( + r'', + get_login_form_response.text, + ) + assert csrf_token, "Failed to get csrf token from login page" + csrf_token_str = csrf_token.group(1) + assert csrf_token_str, "Failed to get csrf token from login page" + try: + # login with form data + session.post( + f"http://{KUBERNETES_HOST_PORT}/auth/login", + data={"username": username, "password": password, "csrf_token": csrf_token_str}, + ) + except requests.exceptions.ConnectionError as e: + # expected to have a connection error + # currently, the login page redirects to http://localhost:8080/?token=... with status code 308 + # but the KUBERNETES_HOST_PORT is *not* localhost:8080 + # TODO: remove this try/except block when the redirect url is fixed + redirect_url = e.request.url if e.request else None + # ensure redirect_url is a string + redirect_url_str = str(redirect_url) if redirect_url is not None else "" + assert "/?token" in redirect_url_str, f"Login failed with redirect url {redirect_url_str}" + parsed_url = urlparse(redirect_url_str) + query_params = parse_qs(str(parsed_url.query)) + jwt_token_list = query_params.get("token") + jwt_token = jwt_token_list[0] if jwt_token_list else None + assert jwt_token, f"Failed to get JWT token from redirect url {redirect_url_str}" + return jwt_token + def _get_session_with_retries(self): session = requests.Session() - session.auth = ("admin", "admin") + jwt_token = self._get_jwt_token(session, "admin", "admin") + session.headers.update({"Authorization": f"Bearer {jwt_token}"}) retries = Retry( total=3, backoff_factor=10, From 7429c0c1e6d6d7f06c5b2a6af0b6045261031e95 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Sat, 8 Mar 2025 17:06:51 +0800 Subject: [PATCH 04/14] Fix api_fastapi/test_dags --- tests/api_fastapi/core_api/routes/public/test_dags.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/api_fastapi/core_api/routes/public/test_dags.py b/tests/api_fastapi/core_api/routes/public/test_dags.py index 424588b227afd..5ebb14c75204c 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dags.py +++ b/tests/api_fastapi/core_api/routes/public/test_dags.py @@ -240,7 +240,7 @@ def test_get_dags(self, test_client, query_params, expected_total_entries, expec assert body["total_entries"] == expected_total_entries assert [dag["dag_id"] for dag in body["dags"]] == expected_ids - @mock.patch("airflow.auth.managers.base_auth_manager.BaseAuthManager.get_permitted_dag_ids") + @mock.patch("airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_permitted_dag_ids") def test_get_dags_should_call_permitted_dag_ids(self, mock_get_permitted_dag_ids, test_client): mock_get_permitted_dag_ids.return_value = {DAG1_ID, DAG2_ID} response = test_client.get("/public/dags") @@ -349,7 +349,7 @@ def test_patch_dags( paused_dag_ids = [dag["dag_id"] for dag in body["dags"] if dag["is_paused"]] assert paused_dag_ids == expected_paused_ids - @mock.patch("airflow.auth.managers.base_auth_manager.BaseAuthManager.get_permitted_dag_ids") + @mock.patch("airflow.api_fastapi.auth.managers.base_auth_manager.BaseAuthManager.get_permitted_dag_ids") def test_patch_dags_should_call_permitted_dag_ids(self, mock_get_permitted_dag_ids, test_client): mock_get_permitted_dag_ids.return_value = {DAG1_ID, DAG2_ID} response = test_client.patch( From 2ce1909b904743157189223acd92c5a0dc9c6ac9 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Sat, 8 Mar 2025 22:10:20 +0800 Subject: [PATCH 05/14] Add dags_reserialize for k8s tests Refactor _get_jwt_token --- kubernetes_tests/test_base.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/kubernetes_tests/test_base.py b/kubernetes_tests/test_base.py index d8ca7b1b80552..e48943ac529b6 100644 --- a/kubernetes_tests/test_base.py +++ b/kubernetes_tests/test_base.py @@ -63,6 +63,11 @@ def base_tests_setup(self, request): self.set_api_server_base_url_config() self.rollout_restart_deployment("airflow-api-server") self.ensure_deployment_health("airflow-api-server") + + # Sometimes the DAGs are not serialized yet, will cause the trigger dagRun API to get 404 error + # so we need to do it manually + self.dags_reserialize() + # Replacement for unittests.TestCase.id() self.test_id = f"{request.node.cls.__name__}_{request.node.name}" self.session = self._get_session_with_retries() @@ -128,7 +133,7 @@ def _delete_airflow_pod(name=""): check_call(["kubectl", "delete", "pod", names[0]]) @staticmethod - def _get_jwt_token(session: requests.Session, username: str, password: str) -> str: + def _get_jwt_token(username: str, password: str) -> str: """Get the JWT token for the given username and password. Note: API server is still using FAB Auth Manager. @@ -148,6 +153,7 @@ def _get_jwt_token(session: requests.Session, username: str, password: str) -> s :return: The JWT token """ # get csrf token from login page + session = requests.Session() get_login_form_response = session.get(f"http://{KUBERNETES_HOST_PORT}/auth/login") # input id="csrf_token" csrf_token = re.search( @@ -180,8 +186,8 @@ def _get_jwt_token(session: requests.Session, username: str, password: str) -> s return jwt_token def _get_session_with_retries(self): + jwt_token = self._get_jwt_token("admin", "admin") session = requests.Session() - jwt_token = self._get_jwt_token(session, "admin", "admin") session.headers.update({"Authorization": f"Bearer {jwt_token}"}) retries = Retry( total=3, @@ -319,6 +325,23 @@ def set_api_server_base_url_config(self): ] ) + @staticmethod + def dags_reserialize(): + """Reserialize the DAGs using the airflow CLI in the scheduler pod.""" + check_call( + [ + "kubectl", + "-n", + "airflow", + "exec", + "deployment/airflow-scheduler", + "--", + "airflow", + "dags", + "reserialize", + ] + ) + def ensure_dag_expected_state(self, host, logical_date, dag_id, expected_final_state, timeout): tries = 0 state = "" From 960d62c5e32bb9006ae03d0aa213d1852f64e56f Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Sun, 9 Mar 2025 23:29:05 +0800 Subject: [PATCH 06/14] Increase threshold of test_integration_run_dag_with_scheduler_failure --- kubernetes_tests/test_kubernetes_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kubernetes_tests/test_kubernetes_executor.py b/kubernetes_tests/test_kubernetes_executor.py index 63d1389b17103..3858d3ddd55ac 100644 --- a/kubernetes_tests/test_kubernetes_executor.py +++ b/kubernetes_tests/test_kubernetes_executor.py @@ -50,7 +50,7 @@ def test_integration_run_dag(self): timeout=300, ) - @pytest.mark.execution_timeout(300) + @pytest.mark.execution_timeout(400) def test_integration_run_dag_with_scheduler_failure(self): dag_id = "example_kubernetes_executor" From d9215ff2e0641aa30bb172e5a68de377930f9337 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Mon, 10 Mar 2025 13:20:48 +0800 Subject: [PATCH 07/14] test: raise if we cannot get jwt_token not due to connection error --- kubernetes_tests/test_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kubernetes_tests/test_base.py b/kubernetes_tests/test_base.py index e48943ac529b6..080cca93dee6e 100644 --- a/kubernetes_tests/test_base.py +++ b/kubernetes_tests/test_base.py @@ -155,7 +155,6 @@ def _get_jwt_token(username: str, password: str) -> str: # get csrf token from login page session = requests.Session() get_login_form_response = session.get(f"http://{KUBERNETES_HOST_PORT}/auth/login") - # input id="csrf_token" csrf_token = re.search( r'', get_login_form_response.text, @@ -182,6 +181,8 @@ def _get_jwt_token(username: str, password: str) -> str: query_params = parse_qs(str(parsed_url.query)) jwt_token_list = query_params.get("token") jwt_token = jwt_token_list[0] if jwt_token_list else None + else: + raise assert jwt_token, f"Failed to get JWT token from redirect url {redirect_url_str}" return jwt_token From 5f1a09087c6526770a53e5ff7397493008cee8bf Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Mon, 10 Mar 2025 14:11:08 +0800 Subject: [PATCH 08/14] Fix _get_jwt_token after dynamic patching k8s configMap --- kubernetes_tests/test_base.py | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/kubernetes_tests/test_base.py b/kubernetes_tests/test_base.py index 080cca93dee6e..e33e435ee92a8 100644 --- a/kubernetes_tests/test_base.py +++ b/kubernetes_tests/test_base.py @@ -162,27 +162,19 @@ def _get_jwt_token(username: str, password: str) -> str: assert csrf_token, "Failed to get csrf token from login page" csrf_token_str = csrf_token.group(1) assert csrf_token_str, "Failed to get csrf token from login page" - try: - # login with form data - session.post( - f"http://{KUBERNETES_HOST_PORT}/auth/login", - data={"username": username, "password": password, "csrf_token": csrf_token_str}, - ) - except requests.exceptions.ConnectionError as e: - # expected to have a connection error - # currently, the login page redirects to http://localhost:8080/?token=... with status code 308 - # but the KUBERNETES_HOST_PORT is *not* localhost:8080 - # TODO: remove this try/except block when the redirect url is fixed - redirect_url = e.request.url if e.request else None - # ensure redirect_url is a string - redirect_url_str = str(redirect_url) if redirect_url is not None else "" - assert "/?token" in redirect_url_str, f"Login failed with redirect url {redirect_url_str}" - parsed_url = urlparse(redirect_url_str) - query_params = parse_qs(str(parsed_url.query)) - jwt_token_list = query_params.get("token") - jwt_token = jwt_token_list[0] if jwt_token_list else None - else: - raise + # login with form data + login_response = session.post( + f"http://{KUBERNETES_HOST_PORT}/auth/login", + data={"username": username, "password": password, "csrf_token": csrf_token_str}, + ) + redirect_url = login_response.url + # ensure redirect_url is a string + redirect_url_str = str(redirect_url) if redirect_url is not None else "" + assert "/?token" in redirect_url_str, f"Login failed with redirect url {redirect_url_str}" + parsed_url = urlparse(redirect_url_str) + query_params = parse_qs(str(parsed_url.query)) + jwt_token_list = query_params.get("token") + jwt_token = jwt_token_list[0] if jwt_token_list else None assert jwt_token, f"Failed to get JWT token from redirect url {redirect_url_str}" return jwt_token From db602b0fc422694fd78214e851bb46519c479ac5 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Mon, 10 Mar 2025 17:05:03 +0800 Subject: [PATCH 09/14] Remove dags_reserialize setup in BaseK8STest --- kubernetes_tests/test_base.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/kubernetes_tests/test_base.py b/kubernetes_tests/test_base.py index e33e435ee92a8..025777b01172b 100644 --- a/kubernetes_tests/test_base.py +++ b/kubernetes_tests/test_base.py @@ -64,10 +64,6 @@ def base_tests_setup(self, request): self.rollout_restart_deployment("airflow-api-server") self.ensure_deployment_health("airflow-api-server") - # Sometimes the DAGs are not serialized yet, will cause the trigger dagRun API to get 404 error - # so we need to do it manually - self.dags_reserialize() - # Replacement for unittests.TestCase.id() self.test_id = f"{request.node.cls.__name__}_{request.node.name}" self.session = self._get_session_with_retries() @@ -318,23 +314,6 @@ def set_api_server_base_url_config(self): ] ) - @staticmethod - def dags_reserialize(): - """Reserialize the DAGs using the airflow CLI in the scheduler pod.""" - check_call( - [ - "kubectl", - "-n", - "airflow", - "exec", - "deployment/airflow-scheduler", - "--", - "airflow", - "dags", - "reserialize", - ] - ) - def ensure_dag_expected_state(self, host, logical_date, dag_id, expected_final_state, timeout): tries = 0 state = "" From 722d0ad115cd09b324d919118b70e01be97763dc Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Mon, 10 Mar 2025 19:56:46 +0800 Subject: [PATCH 10/14] Fix test_docker_compose_quick_start --- .../test_docker_compose_quick_start.py | 57 ++++++++++++++++++- 1 file changed, 54 insertions(+), 3 deletions(-) diff --git a/docker_tests/test_docker_compose_quick_start.py b/docker_tests/test_docker_compose_quick_start.py index c91dbb36dba6d..ccf4ef18fcd2e 100644 --- a/docker_tests/test_docker_compose_quick_start.py +++ b/docker_tests/test_docker_compose_quick_start.py @@ -18,10 +18,12 @@ import json import os +import re import shlex from pprint import pprint from shutil import copyfile from time import sleep +from urllib.parse import parse_qs, urlparse import pytest import requests @@ -34,18 +36,67 @@ # isort:on (needed to workaround isort bug) +DOCKER_COMPOSE_HOST_PORT = os.environ.get("HOST_PORT", "localhost:8080") AIRFLOW_WWW_USER_USERNAME = os.environ.get("_AIRFLOW_WWW_USER_USERNAME", "airflow") AIRFLOW_WWW_USER_PASSWORD = os.environ.get("_AIRFLOW_WWW_USER_PASSWORD", "airflow") DAG_ID = "example_bash_operator" DAG_RUN_ID = "test_dag_run_id" -def api_request(method: str, path: str, base_url: str = "http://localhost:8080/public", **kwargs) -> dict: +def get_jwt_token() -> str: + """Get the JWT token. + + Note: API server is still using FAB Auth Manager. + + Steps: + 1. Get the login page to get the csrf token + - The csrf token is in the hidden input field with id "csrf_token" + 2. Login with the username and password + - Must use the same session to keep the csrf token session + 3. Extract the JWT token from the redirect url + - Expected to have a connection error + - The redirect url should have the JWT token as a query parameter + + :return: The JWT token + """ + # get csrf token from login page + session = requests.Session() + get_login_form_response = session.get(f"http://{DOCKER_COMPOSE_HOST_PORT}/auth/login") + csrf_token = re.search( + r'', + get_login_form_response.text, + ) + assert csrf_token, "Failed to get csrf token from login page" + csrf_token_str = csrf_token.group(1) + assert csrf_token_str, "Failed to get csrf token from login page" + # login with form data + login_response = session.post( + f"http://{DOCKER_COMPOSE_HOST_PORT}/auth/login", + data={ + "username": AIRFLOW_WWW_USER_USERNAME, + "password": AIRFLOW_WWW_USER_PASSWORD, + "csrf_token": csrf_token_str, + }, + ) + redirect_url = login_response.url + # ensure redirect_url is a string + redirect_url_str = str(redirect_url) if redirect_url is not None else "" + assert "/?token" in redirect_url_str, f"Login failed with redirect url {redirect_url_str}" + parsed_url = urlparse(redirect_url_str) + query_params = parse_qs(str(parsed_url.query)) + jwt_token_list = query_params.get("token") + jwt_token = jwt_token_list[0] if jwt_token_list else None + assert jwt_token, f"Failed to get JWT token from redirect url {redirect_url_str}" + return jwt_token + + +def api_request( + method: str, path: str, base_url: str = f"http://{DOCKER_COMPOSE_HOST_PORT}/public", **kwargs +) -> dict: response = requests.request( method=method, url=f"{base_url}/{path}", - auth=(AIRFLOW_WWW_USER_USERNAME, AIRFLOW_WWW_USER_PASSWORD), - headers={"Content-Type": "application/json"}, + headers={"Authorization": f"Bearer {get_jwt_token()}", "Content-Type": "application/json"}, **kwargs, ) response.raise_for_status() From 7359c762f60095e79d8c41687dcc2e856eb5f861 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Mon, 10 Mar 2025 21:11:42 +0800 Subject: [PATCH 11/14] Ensure scheduler health in test_integration_run_dag_with_scheduler_failure --- kubernetes_tests/test_other_executors.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/kubernetes_tests/test_other_executors.py b/kubernetes_tests/test_other_executors.py index f8203b069b404..327e252825a37 100644 --- a/kubernetes_tests/test_other_executors.py +++ b/kubernetes_tests/test_other_executors.py @@ -16,8 +16,6 @@ # under the License. from __future__ import annotations -import time - import pytest from kubernetes_tests.test_base import ( @@ -68,8 +66,7 @@ def test_integration_run_dag_with_scheduler_failure(self): dag_run_id, logical_date = self.start_job_in_kubernetes(dag_id, self.host) self._delete_airflow_pod("scheduler") - - time.sleep(10) # give time for pod to restart + self.ensure_deployment_health("airflow-scheduler") # Wait some time for the operator to complete self.monitor_task( From 5313b50ca8b68b1dcc84068df31caccb20e5aab2 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Mon, 10 Mar 2025 23:47:45 +0800 Subject: [PATCH 12/14] Increase timeout threshold --- kubernetes_tests/test_base.py | 2 +- kubernetes_tests/test_kubernetes_executor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/kubernetes_tests/test_base.py b/kubernetes_tests/test_base.py index 025777b01172b..5868593664fb7 100644 --- a/kubernetes_tests/test_base.py +++ b/kubernetes_tests/test_base.py @@ -179,7 +179,7 @@ def _get_session_with_retries(self): session = requests.Session() session.headers.update({"Authorization": f"Bearer {jwt_token}"}) retries = Retry( - total=3, + total=5, backoff_factor=10, status_forcelist=[404], allowed_methods=Retry.DEFAULT_ALLOWED_METHODS | frozenset(["PATCH", "POST"]), diff --git a/kubernetes_tests/test_kubernetes_executor.py b/kubernetes_tests/test_kubernetes_executor.py index 3858d3ddd55ac..8a7596f3cda70 100644 --- a/kubernetes_tests/test_kubernetes_executor.py +++ b/kubernetes_tests/test_kubernetes_executor.py @@ -50,7 +50,7 @@ def test_integration_run_dag(self): timeout=300, ) - @pytest.mark.execution_timeout(400) + @pytest.mark.execution_timeout(500) def test_integration_run_dag_with_scheduler_failure(self): dag_id = "example_kubernetes_executor" From 737f506a40d9000744faa492ec97d2cfb69138fd Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 11 Mar 2025 10:42:12 +0800 Subject: [PATCH 13/14] Add HTTP retry for _get_jwt_token --- kubernetes_tests/test_base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/kubernetes_tests/test_base.py b/kubernetes_tests/test_base.py index 5868593664fb7..a159c21c3728b 100644 --- a/kubernetes_tests/test_base.py +++ b/kubernetes_tests/test_base.py @@ -149,7 +149,10 @@ def _get_jwt_token(username: str, password: str) -> str: :return: The JWT token """ # get csrf token from login page + retry = Retry(total=5, backoff_factor=10) session = requests.Session() + session.mount("http://", HTTPAdapter(max_retries=retry)) + session.mount("https://", HTTPAdapter(max_retries=retry)) get_login_form_response = session.get(f"http://{KUBERNETES_HOST_PORT}/auth/login") csrf_token = re.search( r'', From 75efb2697b4d29ea5d36ab470804684f10651670 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Tue, 11 Mar 2025 15:45:30 +0800 Subject: [PATCH 14/14] Add JWTRefreshAdapter and restart api-server if needed --- kubernetes_tests/test_base.py | 51 ++++++++++++++++++++++++++++------- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/kubernetes_tests/test_base.py b/kubernetes_tests/test_base.py index a159c21c3728b..31248b02ac420 100644 --- a/kubernetes_tests/test_base.py +++ b/kubernetes_tests/test_base.py @@ -60,9 +60,11 @@ class BaseK8STest: @pytest.fixture(autouse=True) def base_tests_setup(self, request): - self.set_api_server_base_url_config() - self.rollout_restart_deployment("airflow-api-server") - self.ensure_deployment_health("airflow-api-server") + if self.set_api_server_base_url_config(): + # only restart the deployment if the configmap was updated + # speed up the test and make the airflow-api-server deployment more stable + self.rollout_restart_deployment("airflow-api-server") + self.ensure_deployment_health("airflow-api-server") # Replacement for unittests.TestCase.id() self.test_id = f"{request.node.cls.__name__}_{request.node.name}" @@ -178,6 +180,30 @@ def _get_jwt_token(username: str, password: str) -> str: return jwt_token def _get_session_with_retries(self): + class JWTRefreshAdapter(HTTPAdapter): + def __init__(self, base_instance, **kwargs): + self.base_instance = base_instance + super().__init__(**kwargs) + + def send(self, request, **kwargs): + response = super().send(request, **kwargs) + if response.status_code in (401, 403): + # Refresh token and update the Authorization header with retry logic. + attempts = 0 + jwt_token = None + while attempts < 5: + try: + jwt_token = self.base_instance._get_jwt_token("admin", "admin") + break + except Exception: + attempts += 1 + time.sleep(1) + if jwt_token is None: + raise Exception("Failed to refresh JWT token after 5 attempts") + request.headers["Authorization"] = f"Bearer {jwt_token}" + response = super().send(request, **kwargs) + return response + jwt_token = self._get_jwt_token("admin", "admin") session = requests.Session() session.headers.update({"Authorization": f"Bearer {jwt_token}"}) @@ -187,8 +213,9 @@ def _get_session_with_retries(self): status_forcelist=[404], allowed_methods=Retry.DEFAULT_ALLOWED_METHODS | frozenset(["PATCH", "POST"]), ) - session.mount("http://", HTTPAdapter(max_retries=retries)) - session.mount("https://", HTTPAdapter(max_retries=retries)) + adapter = JWTRefreshAdapter(self, max_retries=retries) + session.mount("http://", adapter) + session.mount("https://", adapter) return session def _ensure_airflow_api_server_is_healthy(self): @@ -288,8 +315,11 @@ def _parse_airflow_cfg_dict_as_escaped_toml(self, airflow_cfg_dict: dict) -> str # escape newlines and double quotes return airflow_cfg_str.replace("\n", "\\n").replace('"', '\\"') - def set_api_server_base_url_config(self): - """Set [api/base_url] with `f"http://{KUBERNETES_HOST_PORT}"` as env in k8s configmap.""" + def set_api_server_base_url_config(self) -> bool: + """Set [api/base_url] with `f"http://{KUBERNETES_HOST_PORT}"` as env in k8s configmap. + + :return: True if the configmap was updated successfully, False otherwise + """ configmap_name = "airflow-config" configmap_key = "airflow.cfg" original_configmap_json_str = check_output( @@ -302,7 +332,7 @@ def set_api_server_base_url_config(self): airflow_cfg_dict = self._parse_airflow_cfg_as_dict(original_airflow_cfg) airflow_cfg_dict["api"]["base_url"] = f"http://{KUBERNETES_HOST_PORT}" # update the configmap with the new airflow.cfg - check_call( + patch_configmap_result = check_output( [ "kubectl", "patch", @@ -315,7 +345,10 @@ def set_api_server_base_url_config(self): "-p", f'{{"data": {{"{configmap_key}": "{self._parse_airflow_cfg_dict_as_escaped_toml(airflow_cfg_dict)}"}}}}', ] - ) + ).decode() + if "(no change)" in patch_configmap_result: + return False + return True def ensure_dag_expected_state(self, host, logical_date, dag_id, expected_final_state, timeout): tries = 0