Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4061,6 +4061,8 @@ paths:
summary: Delete Pool
description: Delete a pool entry.
operationId: delete_pool
security:
- OAuth2PasswordBearer: []
parameters:
- name: pool_name
in: path
Expand Down Expand Up @@ -4107,6 +4109,8 @@ paths:
summary: Get Pool
description: Get a pool.
operationId: get_pool
security:
- OAuth2PasswordBearer: []
parameters:
- name: pool_name
in: path
Expand Down Expand Up @@ -4151,6 +4155,8 @@ paths:
summary: Patch Pool
description: Update a Pool.
operationId: patch_pool
security:
- OAuth2PasswordBearer: []
parameters:
- name: pool_name
in: path
Expand Down Expand Up @@ -4218,6 +4224,8 @@ paths:
summary: Get Pools
description: Get all pools entries.
operationId: get_pools
security:
- OAuth2PasswordBearer: []
parameters:
- name: limit
in: query
Expand Down Expand Up @@ -4287,6 +4295,8 @@ paths:
summary: Post Pool
description: Create a Pool.
operationId: post_pool
security:
- OAuth2PasswordBearer: []
requestBody:
required: true
content:
Expand Down Expand Up @@ -4330,6 +4340,8 @@ paths:
summary: Bulk Pools
description: Bulk create, update, and delete pools.
operationId: bulk_pools
security:
- OAuth2PasswordBearer: []
requestBody:
required: true
content:
Expand Down Expand Up @@ -11306,3 +11318,10 @@ components:
- value
title: XComUpdateBody
description: Payload serializer for updating an XCom entry.
securitySchemes:
OAuth2PasswordBearer:
type: oauth2
flows:
password:
scopes: {}
tokenUrl: token
11 changes: 10 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
PoolResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import requires_access_pool
from airflow.api_fastapi.core_api.services.public.pools import BulkPoolService
from airflow.models.pool import Pool

Expand All @@ -55,6 +56,7 @@
status.HTTP_404_NOT_FOUND,
]
),
dependencies=[Depends(requires_access_pool(method="DELETE"))],
)
def delete_pool(
pool_name: str,
Expand All @@ -73,6 +75,7 @@ def delete_pool(
@pools_router.get(
"/{pool_name}",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
dependencies=[Depends(requires_access_pool(method="GET"))],
)
def get_pool(
pool_name: str,
Expand All @@ -89,6 +92,7 @@ def get_pool(
@pools_router.get(
"",
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
dependencies=[Depends(requires_access_pool(method="GET"))],
)
def get_pools(
limit: QueryLimit,
Expand Down Expand Up @@ -126,6 +130,7 @@ def get_pools(
status.HTTP_404_NOT_FOUND,
]
),
dependencies=[Depends(requires_access_pool(method="PUT"))],
)
def patch_pool(
pool_name: str,
Expand Down Expand Up @@ -177,6 +182,7 @@ def patch_pool(
responses=create_openapi_http_exception_doc(
[status.HTTP_409_CONFLICT]
), # handled by global exception handler
dependencies=[Depends(requires_access_pool(method="POST"))],
)
def post_pool(
body: PoolBody,
Expand All @@ -188,7 +194,10 @@ def post_pool(
return pool


@pools_router.patch("")
@pools_router.patch(
"",
dependencies=[Depends(requires_access_pool(method="PUT"))],
)
def bulk_pools(
request: BulkBody[PoolBody],
session: SessionDep,
Expand Down
21 changes: 20 additions & 1 deletion airflow/api_fastapi/core_api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from airflow.api_fastapi.app import get_auth_manager
from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails
from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails, PoolDetails
from airflow.configuration import conf
from airflow.utils.jwt_signer import JWTSigner, get_signing_key

Expand Down Expand Up @@ -82,6 +82,25 @@ def callback():
return inner


def requires_access_pool(method: ResourceMethod) -> Callable:
def inner(
request: Request,
user: Annotated[BaseUser | None, Depends(get_user)] = None,
) -> None:
pool_name = request.path_params.get("pool_name")

def callback():
return get_auth_manager().is_authorized_pool(
method=method, details=PoolDetails(name=pool_name), user=user
)

_requires_access(
is_authorized_callback=callback,
)

return inner


def _requires_access(
*,
is_authorized_callback: Callable[[], bool],
Expand Down
47 changes: 46 additions & 1 deletion tests/api_fastapi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from fastapi.testclient import TestClient

from airflow.api_fastapi.app import create_app
from airflow.auth.managers.simple.simple_auth_manager import SimpleAuthManager
from airflow.auth.managers.simple.user import SimpleAuthManagerUser
from airflow.models import Connection
from airflow.models.dag_version import DagVersion
from airflow.models.serialized_dag import SerializedDagModel
Expand All @@ -34,9 +36,52 @@

@pytest.fixture
def test_client():
with conf_vars(
{
(
"core",
"auth_manager",
): "airflow.auth.managers.simple.simple_auth_manager.SimpleAuthManager",
}
):
auth_manager = SimpleAuthManager()
# set time_very_before to 2014-01-01 00:00:00 and time_very_after to tomorrow
# to make the JWT token always valid for all test cases with time_machine
time_very_before = datetime.datetime(2014, 1, 1, 0, 0, 0)
time_very_after = datetime.datetime.now() + datetime.timedelta(days=1)
token = auth_manager._get_token_signer().generate_signed_token(
{
"iat": time_very_before,
"nbf": time_very_before,
"exp": time_very_after,
**auth_manager.serialize_user(SimpleAuthManagerUser(username="test", role="admin")),
}
)
yield TestClient(create_app(), headers={"Authorization": f"Bearer {token}"})


@pytest.fixture
def unauthenticated_test_client():
return TestClient(create_app())


@pytest.fixture
def unauthorized_test_client():
with conf_vars(
{
(
"core",
"auth_manager",
): "airflow.auth.managers.simple.simple_auth_manager.SimpleAuthManager",
}
):
auth_manager = SimpleAuthManager()
token = auth_manager._get_token_signer().generate_signed_token(
auth_manager.serialize_user(SimpleAuthManagerUser(username="dummy", role=None))
)
yield TestClient(create_app(), headers={"Authorization": f"Bearer {token}"})


@pytest.fixture
def client():
"""This fixture is more flexible than test_client, as it allows to specify which apps to include."""
Expand Down Expand Up @@ -89,7 +134,6 @@ def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_
with dag_maker(dag_id) as dag:
for task_number in range(version_number):
EmptyOperator(task_id=f"task{task_number + 1}")
dag.sync_to_db()
SerializedDagModel.write_dag(
dag, bundle_name="dag_maker", bundle_version=f"some_commit_hash{version_number}"
)
Expand All @@ -98,6 +142,7 @@ def make_dag_with_multiple_versions(dag_maker, configure_git_connection_for_dag_
logical_date=datetime.datetime(2020, 1, version_number, tzinfo=datetime.timezone.utc),
dag_version=DagVersion.get_version(dag_id=dag_id, version_number=version_number),
)
dag.sync_to_db()


@pytest.fixture(scope="module")
Expand Down
1 change: 1 addition & 0 deletions tests/api_fastapi/core_api/routes/public/test_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,7 @@ def create_dags(self, setup, dag_maker, session):
EmptyOperator(task_id="task", outlets=assets[2])
with dag_maker(self.DAG_ASSET_NO, schedule=None, session=session):
EmptyOperator(task_id="task")
session.commit()

@pytest.mark.usefixtures("configure_git_connection_for_dag_bundle")
def test_should_respond_200(self, test_client):
Expand Down
48 changes: 48 additions & 0 deletions tests/api_fastapi/core_api/routes/public/test_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ def test_delete_should_respond_204(self, test_client, session):
pools = session.query(Pool).all()
assert len(pools) == 2

def test_delete_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.delete(f"/public/pools/{POOL1_NAME}")
assert response.status_code == 401

def test_should_respond_403(self, unauthorized_test_client):
response = unauthorized_test_client.delete(f"/public/pools/{POOL1_NAME}")
assert response.status_code == 403

def test_delete_should_respond_400(self, test_client):
response = test_client.delete("/public/pools/default_pool")
assert response.status_code == 400
Expand Down Expand Up @@ -96,6 +104,14 @@ def test_get_should_respond_200(self, test_client, session):
"slots": 3,
}

def test_get_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.get(f"/public/pools/{POOL1_NAME}")
assert response.status_code == 401

def test_should_respond_403(self, unauthorized_test_client):
response = unauthorized_test_client.get(f"/public/pools/{POOL1_NAME}")
assert response.status_code == 403

def test_get_should_respond_404(self, test_client):
response = test_client.get(f"/public/pools/{POOL1_NAME}")
assert response.status_code == 404
Expand Down Expand Up @@ -134,6 +150,14 @@ def test_should_respond_200(
assert body["total_entries"] == expected_total_entries
assert [pool["name"] for pool in body["pools"]] == expected_ids

def test_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.get("/public/pools", params={"pool_name_pattern": "~"})
assert response.status_code == 401

def test_should_respond_403(self, unauthorized_test_client):
response = unauthorized_test_client.get("/public/pools", params={"pool_name_pattern": "~"})
assert response.status_code == 403


class TestPatchPool(TestPoolsEndpoint):
@pytest.mark.parametrize(
Expand Down Expand Up @@ -277,6 +301,14 @@ def test_should_respond_200(

assert body == expected_response

def test_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.patch(f"/public/pools/{POOL1_NAME}", params={}, json={})
assert response.status_code == 401

def test_should_respond_403(self, unauthorized_test_client):
response = unauthorized_test_client.patch(f"/public/pools/{POOL1_NAME}", params={}, json={})
assert response.status_code == 403


class TestPostPool(TestPoolsEndpoint):
@pytest.mark.parametrize(
Expand Down Expand Up @@ -325,6 +357,14 @@ def test_should_respond_200(self, test_client, session, body, expected_status_co
assert response.json() == expected_response
assert session.query(Pool).count() == n_pools + 1

def test_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.post("/public/pools", json={})
assert response.status_code == 401

def test_should_respond_403(self, unauthorized_test_client):
response = unauthorized_test_client.post("/public/pools", json={})
assert response.status_code == 403

@pytest.mark.parametrize(
"body,first_expected_status_code, first_expected_response, second_expected_status_code, second_expected_response",
[
Expand Down Expand Up @@ -711,3 +751,11 @@ def test_bulk_pools(self, test_client, actions, expected_results, session):
response_data = response.json()
for key, value in expected_results.items():
assert response_data[key] == value

def test_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.patch("/public/pools", json={})
assert response.status_code == 401

def test_should_respond_403(self, unauthorized_test_client):
response = unauthorized_test_client.patch("/public/pools", json={})
assert response.status_code == 403