diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 452b4cbc97412..f3c88ad661870 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -4061,6 +4061,8 @@ paths: summary: Delete Pool description: Delete a pool entry. operationId: delete_pool + security: + - OAuth2PasswordBearer: [] parameters: - name: pool_name in: path @@ -4107,6 +4109,8 @@ paths: summary: Get Pool description: Get a pool. operationId: get_pool + security: + - OAuth2PasswordBearer: [] parameters: - name: pool_name in: path @@ -4151,6 +4155,8 @@ paths: summary: Patch Pool description: Update a Pool. operationId: patch_pool + security: + - OAuth2PasswordBearer: [] parameters: - name: pool_name in: path @@ -4218,6 +4224,8 @@ paths: summary: Get Pools description: Get all pools entries. operationId: get_pools + security: + - OAuth2PasswordBearer: [] parameters: - name: limit in: query @@ -4287,6 +4295,8 @@ paths: summary: Post Pool description: Create a Pool. operationId: post_pool + security: + - OAuth2PasswordBearer: [] requestBody: required: true content: @@ -4330,6 +4340,8 @@ paths: summary: Bulk Pools description: Bulk create, update, and delete pools. operationId: bulk_pools + security: + - OAuth2PasswordBearer: [] requestBody: required: true content: @@ -11306,3 +11318,10 @@ components: - value title: XComUpdateBody description: Payload serializer for updating an XCom entry. + securitySchemes: + OAuth2PasswordBearer: + type: oauth2 + flows: + password: + scopes: {} + tokenUrl: token diff --git a/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow/api_fastapi/core_api/routes/public/pools.py index 67079e5542acb..7fd40fb7edd49 100644 --- a/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow/api_fastapi/core_api/routes/public/pools.py @@ -40,6 +40,7 @@ PoolResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc +from airflow.api_fastapi.core_api.security import requires_access_pool from airflow.api_fastapi.core_api.services.public.pools import BulkPoolService from airflow.models.pool import Pool @@ -55,6 +56,7 @@ status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_pool(method="DELETE"))], ) def delete_pool( pool_name: str, @@ -73,6 +75,7 @@ def delete_pool( @pools_router.get( "/{pool_name}", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_pool(method="GET"))], ) def get_pool( pool_name: str, @@ -89,6 +92,7 @@ def get_pool( @pools_router.get( "", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_pool(method="GET"))], ) def get_pools( limit: QueryLimit, @@ -126,6 +130,7 @@ def get_pools( status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_pool(method="PUT"))], ) def patch_pool( pool_name: str, @@ -177,6 +182,7 @@ def patch_pool( responses=create_openapi_http_exception_doc( [status.HTTP_409_CONFLICT] ), # handled by global exception handler + dependencies=[Depends(requires_access_pool(method="POST"))], ) def post_pool( body: PoolBody, @@ -188,7 +194,10 @@ def post_pool( return pool -@pools_router.patch("") +@pools_router.patch( + "", + dependencies=[Depends(requires_access_pool(method="PUT"))], +) def bulk_pools( request: BulkBody[PoolBody], session: SessionDep, diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index eac282e7e3cab..6a073c1c61f7a 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -25,7 +25,7 @@ from airflow.api_fastapi.app import get_auth_manager from airflow.auth.managers.models.base_user import BaseUser -from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails +from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails, PoolDetails from airflow.configuration import conf from airflow.utils.jwt_signer import JWTSigner, get_signing_key @@ -82,6 +82,25 @@ def callback(): return inner +def requires_access_pool(method: ResourceMethod) -> Callable: + def inner( + request: Request, + user: Annotated[BaseUser | None, Depends(get_user)] = None, + ) -> None: + pool_name = request.path_params.get("pool_name") + + def callback(): + return get_auth_manager().is_authorized_pool( + method=method, details=PoolDetails(name=pool_name), user=user + ) + + _requires_access( + is_authorized_callback=callback, + ) + + return inner + + def _requires_access( *, is_authorized_callback: Callable[[], bool], diff --git a/tests/api_fastapi/conftest.py b/tests/api_fastapi/conftest.py index 3e5b44651f0dd..dfbd0b4b42748 100644 --- a/tests/api_fastapi/conftest.py +++ b/tests/api_fastapi/conftest.py @@ -23,6 +23,8 @@ from fastapi.testclient import TestClient from airflow.api_fastapi.app import create_app +from airflow.auth.managers.simple.simple_auth_manager import SimpleAuthManager +from airflow.auth.managers.simple.user import SimpleAuthManagerUser from airflow.models import Connection from airflow.models.dag_version import DagVersion from airflow.models.serialized_dag import SerializedDagModel @@ -34,9 +36,52 @@ @pytest.fixture def test_client(): + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.auth.managers.simple.simple_auth_manager.SimpleAuthManager", + } + ): + auth_manager = SimpleAuthManager() + # set time_very_before to 2014-01-01 00:00:00 and time_very_after to tomorrow + # to make the JWT token always valid for all test cases with time_machine + time_very_before = datetime.datetime(2014, 1, 1, 0, 0, 0) + time_very_after = datetime.datetime.now() + datetime.timedelta(days=1) + token = auth_manager._get_token_signer().generate_signed_token( + { + "iat": time_very_before, + "nbf": time_very_before, + "exp": time_very_after, + **auth_manager.serialize_user(SimpleAuthManagerUser(username="test", role="admin")), + } + ) + yield TestClient(create_app(), headers={"Authorization": f"Bearer {token}"}) + + +@pytest.fixture +def unauthenticated_test_client(): return TestClient(create_app()) +@pytest.fixture +def unauthorized_test_client(): + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.auth.managers.simple.simple_auth_manager.SimpleAuthManager", + } + ): + auth_manager = SimpleAuthManager() + token = auth_manager._get_token_signer().generate_signed_token( + auth_manager.serialize_user(SimpleAuthManagerUser(username="dummy", role=None)) + ) + yield TestClient(create_app(), headers={"Authorization": f"Bearer {token}"}) + + @pytest.fixture def client(): """This fixture is more flexible than test_client, as it allows to specify which apps to include.""" @@ -89,7 +134,6 @@ def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_ with dag_maker(dag_id) as dag: for task_number in range(version_number): EmptyOperator(task_id=f"task{task_number + 1}") - dag.sync_to_db() SerializedDagModel.write_dag( dag, bundle_name="dag_maker", bundle_version=f"some_commit_hash{version_number}" ) @@ -98,6 +142,7 @@ def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_ logical_date=datetime.datetime(2020, 1, version_number, tzinfo=datetime.timezone.utc), dag_version=DagVersion.get_version(dag_id=dag_id, version_number=version_number), ) + dag.sync_to_db() @pytest.fixture(scope="module") diff --git a/tests/api_fastapi/core_api/routes/public/test_assets.py b/tests/api_fastapi/core_api/routes/public/test_assets.py index 764d16b5e8326..c8f148af12406 100644 --- a/tests/api_fastapi/core_api/routes/public/test_assets.py +++ b/tests/api_fastapi/core_api/routes/public/test_assets.py @@ -990,6 +990,7 @@ def create_dags(self, setup, dag_maker, session): EmptyOperator(task_id="task", outlets=assets[2]) with dag_maker(self.DAG_ASSET_NO, schedule=None, session=session): EmptyOperator(task_id="task") + session.commit() @pytest.mark.usefixtures("configure_git_connection_for_dag_bundle") def test_should_respond_200(self, test_client): diff --git a/tests/api_fastapi/core_api/routes/public/test_pools.py b/tests/api_fastapi/core_api/routes/public/test_pools.py index 1121c190c7ace..c4b4a4a8052b1 100644 --- a/tests/api_fastapi/core_api/routes/public/test_pools.py +++ b/tests/api_fastapi/core_api/routes/public/test_pools.py @@ -65,6 +65,14 @@ def test_delete_should_respond_204(self, test_client, session): pools = session.query(Pool).all() assert len(pools) == 2 + def test_delete_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.delete(f"/public/pools/{POOL1_NAME}") + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.delete(f"/public/pools/{POOL1_NAME}") + assert response.status_code == 403 + def test_delete_should_respond_400(self, test_client): response = test_client.delete("/public/pools/default_pool") assert response.status_code == 400 @@ -96,6 +104,14 @@ def test_get_should_respond_200(self, test_client, session): "slots": 3, } + def test_get_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get(f"/public/pools/{POOL1_NAME}") + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get(f"/public/pools/{POOL1_NAME}") + assert response.status_code == 403 + def test_get_should_respond_404(self, test_client): response = test_client.get(f"/public/pools/{POOL1_NAME}") assert response.status_code == 404 @@ -134,6 +150,14 @@ def test_should_respond_200( assert body["total_entries"] == expected_total_entries assert [pool["name"] for pool in body["pools"]] == expected_ids + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get("/public/pools", params={"pool_name_pattern": "~"}) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get("/public/pools", params={"pool_name_pattern": "~"}) + assert response.status_code == 403 + class TestPatchPool(TestPoolsEndpoint): @pytest.mark.parametrize( @@ -277,6 +301,14 @@ def test_should_respond_200( assert body == expected_response + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.patch(f"/public/pools/{POOL1_NAME}", params={}, json={}) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.patch(f"/public/pools/{POOL1_NAME}", params={}, json={}) + assert response.status_code == 403 + class TestPostPool(TestPoolsEndpoint): @pytest.mark.parametrize( @@ -325,6 +357,14 @@ def test_should_respond_200(self, test_client, session, body, expected_status_co assert response.json() == expected_response assert session.query(Pool).count() == n_pools + 1 + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.post("/public/pools", json={}) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.post("/public/pools", json={}) + assert response.status_code == 403 + @pytest.mark.parametrize( "body,first_expected_status_code, first_expected_response, second_expected_status_code, second_expected_response", [ @@ -711,3 +751,11 @@ def test_bulk_pools(self, test_client, actions, expected_results, session): response_data = response.json() for key, value in expected_results.items(): assert response_data[key] == value + + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.patch("/public/pools", json={}) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.patch("/public/pools", json={}) + assert response.status_code == 403