From cefd551a320971d0958a5c4806071f9b986e13a2 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Fri, 28 Feb 2025 10:54:07 +0530 Subject: [PATCH 1/6] add auth to pools --- .../core_api/openapi/v1-generated.yaml | 57 ++++++++++++++++++- .../core_api/routes/public/pools.py | 11 +++- airflow/api_fastapi/core_api/security.py | 19 ++++++- airflow/ui/openapi-gen/queries/common.ts | 4 +- airflow/ui/openapi-gen/queries/prefetch.ts | 7 ++- airflow/ui/openapi-gen/queries/queries.ts | 22 +++++-- airflow/ui/openapi-gen/queries/suspense.ts | 10 +++- .../ui/openapi-gen/requests/services.gen.ts | 10 ++++ airflow/ui/openapi-gen/requests/types.gen.ts | 9 ++- tests/api_fastapi/conftest.py | 45 +++++++++++++++ .../core_api/routes/public/test_pools.py | 24 ++++++++ 11 files changed, 201 insertions(+), 17 deletions(-) diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index d9d38cd39d2c2..1bff5b4967e97 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -4061,12 +4061,16 @@ paths: summary: Delete Pool description: Delete a pool entry. operationId: delete_pool + security: + - OAuth2PasswordBearer: [] parameters: - name: pool_name in: path required: true schema: - type: string + anyOf: + - type: string + - type: 'null' title: Pool Name responses: '204': @@ -4107,12 +4111,16 @@ paths: summary: Get Pool description: Get a pool. operationId: get_pool + security: + - OAuth2PasswordBearer: [] parameters: - name: pool_name in: path required: true schema: - type: string + anyOf: + - type: string + - type: 'null' title: Pool Name responses: '200': @@ -4151,12 +4159,16 @@ paths: summary: Patch Pool description: Update a Pool. operationId: patch_pool + security: + - OAuth2PasswordBearer: [] parameters: - name: pool_name in: path required: true schema: - type: string + anyOf: + - type: string + - type: 'null' title: Pool Name - name: update_mask in: query @@ -4218,7 +4230,17 @@ paths: summary: Get Pools description: Get all pools entries. operationId: get_pools + security: + - OAuth2PasswordBearer: [] parameters: + - name: pool_name + in: query + required: false + schema: + anyOf: + - type: string + - type: 'null' + title: Pool Name - name: limit in: query required: false @@ -4287,6 +4309,17 @@ paths: summary: Post Pool description: Create a Pool. operationId: post_pool + security: + - OAuth2PasswordBearer: [] + parameters: + - name: pool_name + in: query + required: false + schema: + anyOf: + - type: string + - type: 'null' + title: Pool Name requestBody: required: true content: @@ -4330,6 +4363,17 @@ paths: summary: Bulk Pools description: Bulk create, update, and delete pools. operationId: bulk_pools + security: + - OAuth2PasswordBearer: [] + parameters: + - name: pool_name + in: query + required: false + schema: + anyOf: + - type: string + - type: 'null' + title: Pool Name requestBody: required: true content: @@ -11307,3 +11351,10 @@ components: - value title: XComUpdateBody description: Payload serializer for updating an XCom entry. + securitySchemes: + OAuth2PasswordBearer: + type: oauth2 + flows: + password: + scopes: {} + tokenUrl: token diff --git a/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow/api_fastapi/core_api/routes/public/pools.py index 67079e5542acb..7fd40fb7edd49 100644 --- a/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow/api_fastapi/core_api/routes/public/pools.py @@ -40,6 +40,7 @@ PoolResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc +from airflow.api_fastapi.core_api.security import requires_access_pool from airflow.api_fastapi.core_api.services.public.pools import BulkPoolService from airflow.models.pool import Pool @@ -55,6 +56,7 @@ status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_pool(method="DELETE"))], ) def delete_pool( pool_name: str, @@ -73,6 +75,7 @@ def delete_pool( @pools_router.get( "/{pool_name}", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_pool(method="GET"))], ) def get_pool( pool_name: str, @@ -89,6 +92,7 @@ def get_pool( @pools_router.get( "", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_pool(method="GET"))], ) def get_pools( limit: QueryLimit, @@ -126,6 +130,7 @@ def get_pools( status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_pool(method="PUT"))], ) def patch_pool( pool_name: str, @@ -177,6 +182,7 @@ def patch_pool( responses=create_openapi_http_exception_doc( [status.HTTP_409_CONFLICT] ), # handled by global exception handler + dependencies=[Depends(requires_access_pool(method="POST"))], ) def post_pool( body: PoolBody, @@ -188,7 +194,10 @@ def post_pool( return pool -@pools_router.patch("") +@pools_router.patch( + "", + dependencies=[Depends(requires_access_pool(method="PUT"))], +) def bulk_pools( request: BulkBody[PoolBody], session: SessionDep, diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index eac282e7e3cab..ecdb9079b6154 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -25,7 +25,7 @@ from airflow.api_fastapi.app import get_auth_manager from airflow.auth.managers.models.base_user import BaseUser -from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails +from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails, PoolDetails from airflow.configuration import conf from airflow.utils.jwt_signer import JWTSigner, get_signing_key @@ -82,6 +82,23 @@ def callback(): return inner +def requires_access_pool(method: ResourceMethod) -> Callable: + def inner( + pool_name: str | None = None, + user: Annotated[BaseUser | None, Depends(get_user)] = None, + ) -> None: + def callback(): + return get_auth_manager().is_authorized_dag( + method=method, details=PoolDetails(name=pool_name), user=user + ) + + _requires_access( + is_authorized_callback=callback, + ) + + return inner + + def _requires_access( *, is_authorized_callback: Callable[[], bool], diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 151e3b100a9cf..51dd8f4b24b6d 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -1510,15 +1510,17 @@ export const UsePoolServiceGetPoolsKeyFn = ( limit, offset, orderBy, + poolName, poolNamePattern, }: { limit?: number; offset?: number; orderBy?: string; + poolName?: string; poolNamePattern?: string; } = {}, queryKey?: Array, -) => [usePoolServiceGetPoolsKey, ...(queryKey ?? [{ limit, offset, orderBy, poolNamePattern }])]; +) => [usePoolServiceGetPoolsKey, ...(queryKey ?? [{ limit, offset, orderBy, poolName, poolNamePattern }])]; export type ProviderServiceGetProvidersDefaultResponse = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow/ui/openapi-gen/queries/prefetch.ts index c833e4268706b..f3c85af050157 100644 --- a/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow/ui/openapi-gen/queries/prefetch.ts @@ -2102,6 +2102,7 @@ export const prefetchUsePoolServiceGetPool = ( * Get Pools * Get all pools entries. * @param data The data for the request. + * @param data.poolName * @param data.limit * @param data.offset * @param data.orderBy @@ -2115,17 +2116,19 @@ export const prefetchUsePoolServiceGetPools = ( limit, offset, orderBy, + poolName, poolNamePattern, }: { limit?: number; offset?: number; orderBy?: string; + poolName?: string; poolNamePattern?: string; } = {}, ) => queryClient.prefetchQuery({ - queryKey: Common.UsePoolServiceGetPoolsKeyFn({ limit, offset, orderBy, poolNamePattern }), - queryFn: () => PoolService.getPools({ limit, offset, orderBy, poolNamePattern }), + queryKey: Common.UsePoolServiceGetPoolsKeyFn({ limit, offset, orderBy, poolName, poolNamePattern }), + queryFn: () => PoolService.getPools({ limit, offset, orderBy, poolName, poolNamePattern }), }); /** * Get Providers diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 2fdd58c4ec273..c46039ce9129c 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -2471,6 +2471,7 @@ export const usePoolServiceGetPool = < * Get Pools * Get all pools entries. * @param data The data for the request. + * @param data.poolName * @param data.limit * @param data.offset * @param data.orderBy @@ -2487,19 +2488,24 @@ export const usePoolServiceGetPools = < limit, offset, orderBy, + poolName, poolNamePattern, }: { limit?: number; offset?: number; orderBy?: string; + poolName?: string; poolNamePattern?: string; } = {}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">, ) => useQuery({ - queryKey: Common.UsePoolServiceGetPoolsKeyFn({ limit, offset, orderBy, poolNamePattern }, queryKey), - queryFn: () => PoolService.getPools({ limit, offset, orderBy, poolNamePattern }) as TData, + queryKey: Common.UsePoolServiceGetPoolsKeyFn( + { limit, offset, orderBy, poolName, poolNamePattern }, + queryKey, + ), + queryFn: () => PoolService.getPools({ limit, offset, orderBy, poolName, poolNamePattern }) as TData, ...options, }); /** @@ -3334,6 +3340,7 @@ export const useTaskInstanceServicePostClearTaskInstances = < * Create a Pool. * @param data The data for the request. * @param data.requestBody + * @param data.poolName * @returns PoolResponse Successful Response * @throws ApiError */ @@ -3347,6 +3354,7 @@ export const usePoolServicePostPool = < TData, TError, { + poolName?: string; requestBody: PoolBody; }, TContext @@ -3358,11 +3366,13 @@ export const usePoolServicePostPool = < TData, TError, { + poolName?: string; requestBody: PoolBody; }, TContext >({ - mutationFn: ({ requestBody }) => PoolService.postPool({ requestBody }) as unknown as Promise, + mutationFn: ({ poolName, requestBody }) => + PoolService.postPool({ poolName, requestBody }) as unknown as Promise, ...options, }); /** @@ -4141,6 +4151,7 @@ export const usePoolServicePatchPool = < * Bulk create, update, and delete pools. * @param data The data for the request. * @param data.requestBody + * @param data.poolName * @returns BulkResponse Successful Response * @throws ApiError */ @@ -4154,6 +4165,7 @@ export const usePoolServiceBulkPools = < TData, TError, { + poolName?: string; requestBody: BulkBody_PoolBody_; }, TContext @@ -4165,11 +4177,13 @@ export const usePoolServiceBulkPools = < TData, TError, { + poolName?: string; requestBody: BulkBody_PoolBody_; }, TContext >({ - mutationFn: ({ requestBody }) => PoolService.bulkPools({ requestBody }) as unknown as Promise, + mutationFn: ({ poolName, requestBody }) => + PoolService.bulkPools({ poolName, requestBody }) as unknown as Promise, ...options, }); /** diff --git a/airflow/ui/openapi-gen/queries/suspense.ts b/airflow/ui/openapi-gen/queries/suspense.ts index 48b17556fd9ec..d408243fce83b 100644 --- a/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow/ui/openapi-gen/queries/suspense.ts @@ -2448,6 +2448,7 @@ export const usePoolServiceGetPoolSuspense = < * Get Pools * Get all pools entries. * @param data The data for the request. + * @param data.poolName * @param data.limit * @param data.offset * @param data.orderBy @@ -2464,19 +2465,24 @@ export const usePoolServiceGetPoolsSuspense = < limit, offset, orderBy, + poolName, poolNamePattern, }: { limit?: number; offset?: number; orderBy?: string; + poolName?: string; poolNamePattern?: string; } = {}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">, ) => useSuspenseQuery({ - queryKey: Common.UsePoolServiceGetPoolsKeyFn({ limit, offset, orderBy, poolNamePattern }, queryKey), - queryFn: () => PoolService.getPools({ limit, offset, orderBy, poolNamePattern }) as TData, + queryKey: Common.UsePoolServiceGetPoolsKeyFn( + { limit, offset, orderBy, poolName, poolNamePattern }, + queryKey, + ), + queryFn: () => PoolService.getPools({ limit, offset, orderBy, poolName, poolNamePattern }) as TData, ...options, }); /** diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 81b649d31cb97..54a947342b5f5 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -2917,6 +2917,7 @@ export class PoolService { * Get Pools * Get all pools entries. * @param data The data for the request. + * @param data.poolName * @param data.limit * @param data.offset * @param data.orderBy @@ -2929,6 +2930,7 @@ export class PoolService { method: "GET", url: "/public/pools", query: { + pool_name: data.poolName, limit: data.limit, offset: data.offset, order_by: data.orderBy, @@ -2948,6 +2950,7 @@ export class PoolService { * Create a Pool. * @param data The data for the request. * @param data.requestBody + * @param data.poolName * @returns PoolResponse Successful Response * @throws ApiError */ @@ -2955,6 +2958,9 @@ export class PoolService { return __request(OpenAPI, { method: "POST", url: "/public/pools", + query: { + pool_name: data.poolName, + }, body: data.requestBody, mediaType: "application/json", errors: { @@ -2971,6 +2977,7 @@ export class PoolService { * Bulk create, update, and delete pools. * @param data The data for the request. * @param data.requestBody + * @param data.poolName * @returns BulkResponse Successful Response * @throws ApiError */ @@ -2978,6 +2985,9 @@ export class PoolService { return __request(OpenAPI, { method: "PATCH", url: "/public/pools", + query: { + pool_name: data.poolName, + }, body: data.requestBody, mediaType: "application/json", errors: { diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 2ca718644c5fd..cab4f8a3a6bae 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -2346,19 +2346,19 @@ export type GetPluginsData = { export type GetPluginsResponse = PluginCollectionResponse; export type DeletePoolData = { - poolName: string; + poolName: string | null; }; export type DeletePoolResponse = void; export type GetPoolData = { - poolName: string; + poolName: string | null; }; export type GetPoolResponse = PoolResponse; export type PatchPoolData = { - poolName: string; + poolName: string | null; requestBody: PoolPatchBody; updateMask?: Array | null; }; @@ -2369,18 +2369,21 @@ export type GetPoolsData = { limit?: number; offset?: number; orderBy?: string; + poolName?: string | null; poolNamePattern?: string | null; }; export type GetPoolsResponse = PoolCollectionResponse; export type PostPoolData = { + poolName?: string | null; requestBody: PoolBody; }; export type PostPoolResponse = PoolResponse; export type BulkPoolsData = { + poolName?: string | null; requestBody: BulkBody_PoolBody_; }; diff --git a/tests/api_fastapi/conftest.py b/tests/api_fastapi/conftest.py index 3e5b44651f0dd..b9070b199337d 100644 --- a/tests/api_fastapi/conftest.py +++ b/tests/api_fastapi/conftest.py @@ -23,6 +23,8 @@ from fastapi.testclient import TestClient from airflow.api_fastapi.app import create_app +from airflow.auth.managers.simple.simple_auth_manager import SimpleAuthManager +from airflow.auth.managers.simple.user import SimpleAuthManagerUser from airflow.models import Connection from airflow.models.dag_version import DagVersion from airflow.models.serialized_dag import SerializedDagModel @@ -34,9 +36,52 @@ @pytest.fixture def test_client(): + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.auth.managers.simple.simple_auth_manager.SimpleAuthManager", + } + ): + auth_manager = SimpleAuthManager() + # set time_very_before to 2014-01-01 00:00:00 and time_very_after to tomorrow + # to make the JWT token always valid for all test cases with time_machine + time_very_before = datetime.datetime(2014, 1, 1, 0, 0, 0) + time_very_after = datetime.datetime.now() + datetime.timedelta(days=1) + token = auth_manager._get_token_signer().generate_signed_token( + { + "iat": time_very_before, + "nbf": time_very_before, + "exp": time_very_after, + **auth_manager.serialize_user(SimpleAuthManagerUser(username="test", role="admin")), + } + ) + yield TestClient(create_app(), headers={"Authorization": f"Bearer {token}"}) + + +@pytest.fixture +def unauthenticated_test_client(): return TestClient(create_app()) +@pytest.fixture +def unauthorized_test_client(): + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.auth.managers.simple.simple_auth_manager.SimpleAuthManager", + } + ): + auth_manager = SimpleAuthManager() + token = auth_manager._get_token_signer().generate_signed_token( + auth_manager.serialize_user(SimpleAuthManagerUser(username="dummy", role=None)) + ) + yield TestClient(create_app(), headers={"Authorization": f"Bearer {token}"}) + + @pytest.fixture def client(): """This fixture is more flexible than test_client, as it allows to specify which apps to include.""" diff --git a/tests/api_fastapi/core_api/routes/public/test_pools.py b/tests/api_fastapi/core_api/routes/public/test_pools.py index 1121c190c7ace..ddc7666bfdc7b 100644 --- a/tests/api_fastapi/core_api/routes/public/test_pools.py +++ b/tests/api_fastapi/core_api/routes/public/test_pools.py @@ -65,6 +65,10 @@ def test_delete_should_respond_204(self, test_client, session): pools = session.query(Pool).all() assert len(pools) == 2 + def test_delete_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.delete(f"/public/pools/{POOL1_NAME}") + assert response.status_code == 401 + def test_delete_should_respond_400(self, test_client): response = test_client.delete("/public/pools/default_pool") assert response.status_code == 400 @@ -96,6 +100,10 @@ def test_get_should_respond_200(self, test_client, session): "slots": 3, } + def test_get_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get(f"/public/pools/{POOL1_NAME}") + assert response.status_code == 401 + def test_get_should_respond_404(self, test_client): response = test_client.get(f"/public/pools/{POOL1_NAME}") assert response.status_code == 404 @@ -134,6 +142,10 @@ def test_should_respond_200( assert body["total_entries"] == expected_total_entries assert [pool["name"] for pool in body["pools"]] == expected_ids + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get("/public/pools", params={"pool_name_pattern": "~"}) + assert response.status_code == 401 + class TestPatchPool(TestPoolsEndpoint): @pytest.mark.parametrize( @@ -277,6 +289,10 @@ def test_should_respond_200( assert body == expected_response + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.patch(f"/public/pools/{POOL1_NAME}", params={}, json={}) + assert response.status_code == 401 + class TestPostPool(TestPoolsEndpoint): @pytest.mark.parametrize( @@ -325,6 +341,10 @@ def test_should_respond_200(self, test_client, session, body, expected_status_co assert response.json() == expected_response assert session.query(Pool).count() == n_pools + 1 + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.post("/public/pools", json={}) + assert response.status_code == 401 + @pytest.mark.parametrize( "body,first_expected_status_code, first_expected_response, second_expected_status_code, second_expected_response", [ @@ -711,3 +731,7 @@ def test_bulk_pools(self, test_client, actions, expected_results, session): response_data = response.json() for key, value in expected_results.items(): assert response_data[key] == value + + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.patch("/public/pools", json={}) + assert response.status_code == 401 From 9df676b580ecfc7ecae8c2128d58d29337a12644 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Fri, 28 Feb 2025 11:03:56 +0530 Subject: [PATCH 2/6] add 403 tests --- .../core_api/routes/public/test_pools.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/api_fastapi/core_api/routes/public/test_pools.py b/tests/api_fastapi/core_api/routes/public/test_pools.py index ddc7666bfdc7b..c4b4a4a8052b1 100644 --- a/tests/api_fastapi/core_api/routes/public/test_pools.py +++ b/tests/api_fastapi/core_api/routes/public/test_pools.py @@ -69,6 +69,10 @@ def test_delete_should_respond_401(self, unauthenticated_test_client): response = unauthenticated_test_client.delete(f"/public/pools/{POOL1_NAME}") assert response.status_code == 401 + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.delete(f"/public/pools/{POOL1_NAME}") + assert response.status_code == 403 + def test_delete_should_respond_400(self, test_client): response = test_client.delete("/public/pools/default_pool") assert response.status_code == 400 @@ -104,6 +108,10 @@ def test_get_should_respond_401(self, unauthenticated_test_client): response = unauthenticated_test_client.get(f"/public/pools/{POOL1_NAME}") assert response.status_code == 401 + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get(f"/public/pools/{POOL1_NAME}") + assert response.status_code == 403 + def test_get_should_respond_404(self, test_client): response = test_client.get(f"/public/pools/{POOL1_NAME}") assert response.status_code == 404 @@ -146,6 +154,10 @@ def test_should_respond_401(self, unauthenticated_test_client): response = unauthenticated_test_client.get("/public/pools", params={"pool_name_pattern": "~"}) assert response.status_code == 401 + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get("/public/pools", params={"pool_name_pattern": "~"}) + assert response.status_code == 403 + class TestPatchPool(TestPoolsEndpoint): @pytest.mark.parametrize( @@ -293,6 +305,10 @@ def test_should_respond_401(self, unauthenticated_test_client): response = unauthenticated_test_client.patch(f"/public/pools/{POOL1_NAME}", params={}, json={}) assert response.status_code == 401 + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.patch(f"/public/pools/{POOL1_NAME}", params={}, json={}) + assert response.status_code == 403 + class TestPostPool(TestPoolsEndpoint): @pytest.mark.parametrize( @@ -345,6 +361,10 @@ def test_should_respond_401(self, unauthenticated_test_client): response = unauthenticated_test_client.post("/public/pools", json={}) assert response.status_code == 401 + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.post("/public/pools", json={}) + assert response.status_code == 403 + @pytest.mark.parametrize( "body,first_expected_status_code, first_expected_response, second_expected_status_code, second_expected_response", [ @@ -735,3 +755,7 @@ def test_bulk_pools(self, test_client, actions, expected_results, session): def test_should_respond_401(self, unauthenticated_test_client): response = unauthenticated_test_client.patch("/public/pools", json={}) assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.patch("/public/pools", json={}) + assert response.status_code == 403 From 7fff1e1300813bdf4c13da2e2345ca2f901c8a17 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Fri, 28 Feb 2025 18:42:51 +0530 Subject: [PATCH 3/6] fix --- airflow/api_fastapi/core_api/security.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index ecdb9079b6154..84959ef68fcba 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -88,7 +88,7 @@ def inner( user: Annotated[BaseUser | None, Depends(get_user)] = None, ) -> None: def callback(): - return get_auth_manager().is_authorized_dag( + return get_auth_manager().is_authorized_pool( method=method, details=PoolDetails(name=pool_name), user=user ) From ec283dc4adae24cd6a5275f3c99a29d158ee7029 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Mon, 3 Mar 2025 15:50:16 +0530 Subject: [PATCH 4/6] fix tests --- tests/api_fastapi/conftest.py | 2 +- tests/api_fastapi/core_api/routes/public/test_assets.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/api_fastapi/conftest.py b/tests/api_fastapi/conftest.py index b9070b199337d..dfbd0b4b42748 100644 --- a/tests/api_fastapi/conftest.py +++ b/tests/api_fastapi/conftest.py @@ -134,7 +134,6 @@ def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_ with dag_maker(dag_id) as dag: for task_number in range(version_number): EmptyOperator(task_id=f"task{task_number + 1}") - dag.sync_to_db() SerializedDagModel.write_dag( dag, bundle_name="dag_maker", bundle_version=f"some_commit_hash{version_number}" ) @@ -143,6 +142,7 @@ def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_ logical_date=datetime.datetime(2020, 1, version_number, tzinfo=datetime.timezone.utc), dag_version=DagVersion.get_version(dag_id=dag_id, version_number=version_number), ) + dag.sync_to_db() @pytest.fixture(scope="module") diff --git a/tests/api_fastapi/core_api/routes/public/test_assets.py b/tests/api_fastapi/core_api/routes/public/test_assets.py index 764d16b5e8326..c8f148af12406 100644 --- a/tests/api_fastapi/core_api/routes/public/test_assets.py +++ b/tests/api_fastapi/core_api/routes/public/test_assets.py @@ -990,6 +990,7 @@ def create_dags(self, setup, dag_maker, session): EmptyOperator(task_id="task", outlets=assets[2]) with dag_maker(self.DAG_ASSET_NO, schedule=None, session=session): EmptyOperator(task_id="task") + session.commit() @pytest.mark.usefixtures("configure_git_connection_for_dag_bundle") def test_should_respond_200(self, test_client): From 404e014be9507cf8581a2ecc17e55dfdf78aa626 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Mon, 3 Mar 2025 17:21:15 +0530 Subject: [PATCH 5/6] use Request in requires_access_pool --- .../core_api/openapi/v1-generated.yaml | 38 ++----------------- airflow/api_fastapi/core_api/security.py | 4 +- airflow/ui/openapi-gen/queries/common.ts | 4 +- airflow/ui/openapi-gen/queries/prefetch.ts | 7 +--- airflow/ui/openapi-gen/queries/queries.ts | 22 ++--------- airflow/ui/openapi-gen/queries/suspense.ts | 10 +---- .../ui/openapi-gen/requests/services.gen.ts | 10 ----- airflow/ui/openapi-gen/requests/types.gen.ts | 9 ++--- 8 files changed, 18 insertions(+), 86 deletions(-) diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 6af9b97844d99..f3c88ad661870 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -4068,9 +4068,7 @@ paths: in: path required: true schema: - anyOf: - - type: string - - type: 'null' + type: string title: Pool Name responses: '204': @@ -4118,9 +4116,7 @@ paths: in: path required: true schema: - anyOf: - - type: string - - type: 'null' + type: string title: Pool Name responses: '200': @@ -4166,9 +4162,7 @@ paths: in: path required: true schema: - anyOf: - - type: string - - type: 'null' + type: string title: Pool Name - name: update_mask in: query @@ -4233,14 +4227,6 @@ paths: security: - OAuth2PasswordBearer: [] parameters: - - name: pool_name - in: query - required: false - schema: - anyOf: - - type: string - - type: 'null' - title: Pool Name - name: limit in: query required: false @@ -4311,15 +4297,6 @@ paths: operationId: post_pool security: - OAuth2PasswordBearer: [] - parameters: - - name: pool_name - in: query - required: false - schema: - anyOf: - - type: string - - type: 'null' - title: Pool Name requestBody: required: true content: @@ -4365,15 +4342,6 @@ paths: operationId: bulk_pools security: - OAuth2PasswordBearer: [] - parameters: - - name: pool_name - in: query - required: false - schema: - anyOf: - - type: string - - type: 'null' - title: Pool Name requestBody: required: true content: diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index 84959ef68fcba..4df2f3ccf4781 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -84,9 +84,11 @@ def callback(): def requires_access_pool(method: ResourceMethod) -> Callable: def inner( - pool_name: str | None = None, + request: Request, user: Annotated[BaseUser | None, Depends(get_user)] = None, ) -> None: + pool_name = request.path_params.get("pool_name", "None") + def callback(): return get_auth_manager().is_authorized_pool( method=method, details=PoolDetails(name=pool_name), user=user diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 51dd8f4b24b6d..151e3b100a9cf 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -1510,17 +1510,15 @@ export const UsePoolServiceGetPoolsKeyFn = ( limit, offset, orderBy, - poolName, poolNamePattern, }: { limit?: number; offset?: number; orderBy?: string; - poolName?: string; poolNamePattern?: string; } = {}, queryKey?: Array, -) => [usePoolServiceGetPoolsKey, ...(queryKey ?? [{ limit, offset, orderBy, poolName, poolNamePattern }])]; +) => [usePoolServiceGetPoolsKey, ...(queryKey ?? [{ limit, offset, orderBy, poolNamePattern }])]; export type ProviderServiceGetProvidersDefaultResponse = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow/ui/openapi-gen/queries/prefetch.ts index f3c85af050157..c833e4268706b 100644 --- a/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow/ui/openapi-gen/queries/prefetch.ts @@ -2102,7 +2102,6 @@ export const prefetchUsePoolServiceGetPool = ( * Get Pools * Get all pools entries. * @param data The data for the request. - * @param data.poolName * @param data.limit * @param data.offset * @param data.orderBy @@ -2116,19 +2115,17 @@ export const prefetchUsePoolServiceGetPools = ( limit, offset, orderBy, - poolName, poolNamePattern, }: { limit?: number; offset?: number; orderBy?: string; - poolName?: string; poolNamePattern?: string; } = {}, ) => queryClient.prefetchQuery({ - queryKey: Common.UsePoolServiceGetPoolsKeyFn({ limit, offset, orderBy, poolName, poolNamePattern }), - queryFn: () => PoolService.getPools({ limit, offset, orderBy, poolName, poolNamePattern }), + queryKey: Common.UsePoolServiceGetPoolsKeyFn({ limit, offset, orderBy, poolNamePattern }), + queryFn: () => PoolService.getPools({ limit, offset, orderBy, poolNamePattern }), }); /** * Get Providers diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index c46039ce9129c..2fdd58c4ec273 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -2471,7 +2471,6 @@ export const usePoolServiceGetPool = < * Get Pools * Get all pools entries. * @param data The data for the request. - * @param data.poolName * @param data.limit * @param data.offset * @param data.orderBy @@ -2488,24 +2487,19 @@ export const usePoolServiceGetPools = < limit, offset, orderBy, - poolName, poolNamePattern, }: { limit?: number; offset?: number; orderBy?: string; - poolName?: string; poolNamePattern?: string; } = {}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">, ) => useQuery({ - queryKey: Common.UsePoolServiceGetPoolsKeyFn( - { limit, offset, orderBy, poolName, poolNamePattern }, - queryKey, - ), - queryFn: () => PoolService.getPools({ limit, offset, orderBy, poolName, poolNamePattern }) as TData, + queryKey: Common.UsePoolServiceGetPoolsKeyFn({ limit, offset, orderBy, poolNamePattern }, queryKey), + queryFn: () => PoolService.getPools({ limit, offset, orderBy, poolNamePattern }) as TData, ...options, }); /** @@ -3340,7 +3334,6 @@ export const useTaskInstanceServicePostClearTaskInstances = < * Create a Pool. * @param data The data for the request. * @param data.requestBody - * @param data.poolName * @returns PoolResponse Successful Response * @throws ApiError */ @@ -3354,7 +3347,6 @@ export const usePoolServicePostPool = < TData, TError, { - poolName?: string; requestBody: PoolBody; }, TContext @@ -3366,13 +3358,11 @@ export const usePoolServicePostPool = < TData, TError, { - poolName?: string; requestBody: PoolBody; }, TContext >({ - mutationFn: ({ poolName, requestBody }) => - PoolService.postPool({ poolName, requestBody }) as unknown as Promise, + mutationFn: ({ requestBody }) => PoolService.postPool({ requestBody }) as unknown as Promise, ...options, }); /** @@ -4151,7 +4141,6 @@ export const usePoolServicePatchPool = < * Bulk create, update, and delete pools. * @param data The data for the request. * @param data.requestBody - * @param data.poolName * @returns BulkResponse Successful Response * @throws ApiError */ @@ -4165,7 +4154,6 @@ export const usePoolServiceBulkPools = < TData, TError, { - poolName?: string; requestBody: BulkBody_PoolBody_; }, TContext @@ -4177,13 +4165,11 @@ export const usePoolServiceBulkPools = < TData, TError, { - poolName?: string; requestBody: BulkBody_PoolBody_; }, TContext >({ - mutationFn: ({ poolName, requestBody }) => - PoolService.bulkPools({ poolName, requestBody }) as unknown as Promise, + mutationFn: ({ requestBody }) => PoolService.bulkPools({ requestBody }) as unknown as Promise, ...options, }); /** diff --git a/airflow/ui/openapi-gen/queries/suspense.ts b/airflow/ui/openapi-gen/queries/suspense.ts index d408243fce83b..48b17556fd9ec 100644 --- a/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow/ui/openapi-gen/queries/suspense.ts @@ -2448,7 +2448,6 @@ export const usePoolServiceGetPoolSuspense = < * Get Pools * Get all pools entries. * @param data The data for the request. - * @param data.poolName * @param data.limit * @param data.offset * @param data.orderBy @@ -2465,24 +2464,19 @@ export const usePoolServiceGetPoolsSuspense = < limit, offset, orderBy, - poolName, poolNamePattern, }: { limit?: number; offset?: number; orderBy?: string; - poolName?: string; poolNamePattern?: string; } = {}, queryKey?: TQueryKey, options?: Omit, "queryKey" | "queryFn">, ) => useSuspenseQuery({ - queryKey: Common.UsePoolServiceGetPoolsKeyFn( - { limit, offset, orderBy, poolName, poolNamePattern }, - queryKey, - ), - queryFn: () => PoolService.getPools({ limit, offset, orderBy, poolName, poolNamePattern }) as TData, + queryKey: Common.UsePoolServiceGetPoolsKeyFn({ limit, offset, orderBy, poolNamePattern }, queryKey), + queryFn: () => PoolService.getPools({ limit, offset, orderBy, poolNamePattern }) as TData, ...options, }); /** diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 54a947342b5f5..81b649d31cb97 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -2917,7 +2917,6 @@ export class PoolService { * Get Pools * Get all pools entries. * @param data The data for the request. - * @param data.poolName * @param data.limit * @param data.offset * @param data.orderBy @@ -2930,7 +2929,6 @@ export class PoolService { method: "GET", url: "/public/pools", query: { - pool_name: data.poolName, limit: data.limit, offset: data.offset, order_by: data.orderBy, @@ -2950,7 +2948,6 @@ export class PoolService { * Create a Pool. * @param data The data for the request. * @param data.requestBody - * @param data.poolName * @returns PoolResponse Successful Response * @throws ApiError */ @@ -2958,9 +2955,6 @@ export class PoolService { return __request(OpenAPI, { method: "POST", url: "/public/pools", - query: { - pool_name: data.poolName, - }, body: data.requestBody, mediaType: "application/json", errors: { @@ -2977,7 +2971,6 @@ export class PoolService { * Bulk create, update, and delete pools. * @param data The data for the request. * @param data.requestBody - * @param data.poolName * @returns BulkResponse Successful Response * @throws ApiError */ @@ -2985,9 +2978,6 @@ export class PoolService { return __request(OpenAPI, { method: "PATCH", url: "/public/pools", - query: { - pool_name: data.poolName, - }, body: data.requestBody, mediaType: "application/json", errors: { diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index cab4f8a3a6bae..2ca718644c5fd 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -2346,19 +2346,19 @@ export type GetPluginsData = { export type GetPluginsResponse = PluginCollectionResponse; export type DeletePoolData = { - poolName: string | null; + poolName: string; }; export type DeletePoolResponse = void; export type GetPoolData = { - poolName: string | null; + poolName: string; }; export type GetPoolResponse = PoolResponse; export type PatchPoolData = { - poolName: string | null; + poolName: string; requestBody: PoolPatchBody; updateMask?: Array | null; }; @@ -2369,21 +2369,18 @@ export type GetPoolsData = { limit?: number; offset?: number; orderBy?: string; - poolName?: string | null; poolNamePattern?: string | null; }; export type GetPoolsResponse = PoolCollectionResponse; export type PostPoolData = { - poolName?: string | null; requestBody: PoolBody; }; export type PostPoolResponse = PoolResponse; export type BulkPoolsData = { - poolName?: string | null; requestBody: BulkBody_PoolBody_; }; From 58543f0d60cfede686b34a77c357cd01e41cd9d1 Mon Sep 17 00:00:00 2001 From: kalyanr Date: Mon, 3 Mar 2025 17:26:03 +0530 Subject: [PATCH 6/6] fix --- airflow/api_fastapi/core_api/security.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index 4df2f3ccf4781..6a073c1c61f7a 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -87,7 +87,7 @@ def inner( request: Request, user: Annotated[BaseUser | None, Depends(get_user)] = None, ) -> None: - pool_name = request.path_params.get("pool_name", "None") + pool_name = request.path_params.get("pool_name") def callback(): return get_auth_manager().is_authorized_pool(