diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/simple/middleware.py b/airflow-core/src/airflow/api_fastapi/auth/managers/simple/middleware.py new file mode 100644 index 0000000000000..6c73cd015fa9f --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/auth/managers/simple/middleware.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware + +from airflow.api_fastapi.auth.managers.simple.services.login import SimpleAuthManagerLogin + + +class SimpleAllAdminMiddleware(BaseHTTPMiddleware): + """Middleware that automatically generates and includes auth header for simple auth manager.""" + + async def dispatch(self, request: Request, call_next): + # Starlette Request is expected to be immutable, but we modify it to add the auth header + # https://github.com/fastapi/fastapi/issues/2727#issuecomment-770202019 + token = SimpleAuthManagerLogin.create_token_all_admins() + request.scope["headers"].append((b"authorization", f"Bearer {token}".encode())) + return await call_next(request) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/app.py b/airflow-core/src/airflow/api_fastapi/core_api/app.py index 49f522994313a..90327ecc26677 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/app.py @@ -167,4 +167,10 @@ def init_error_handlers(app: FastAPI) -> None: def init_middlewares(app: FastAPI) -> None: + from airflow.configuration import conf + app.add_middleware(FlaskExceptionsMiddleware) + if conf.getboolean("core", "simple_auth_manager_all_admins"): + from airflow.api_fastapi.auth.managers.simple.middleware import SimpleAllAdminMiddleware + + app.add_middleware(SimpleAllAdminMiddleware) diff --git a/airflow-core/tests/unit/api_fastapi/auth/managers/simple/test_middleware.py b/airflow-core/tests/unit/api_fastapi/auth/managers/simple/test_middleware.py new file mode 100644 index 0000000000000..dcd39ef19ed62 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/auth/managers/simple/test_middleware.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pytest +from fastapi.testclient import TestClient + +from airflow.api_fastapi.app import create_app + +from tests_common.test_utils.config import conf_vars + +pytestmark = pytest.mark.db_test + + +@pytest.fixture +def all_access_test_client(): + with conf_vars( + { + ("core", "simple_auth_manager_all_admins"): "true", + ("webserver", "expose_config"): "true", + } + ): + app = create_app() + yield TestClient(app) + + +@pytest.mark.parametrize( + "method, path", + [ + ("GET", "/api/v2/assets"), + ("POST", "/api/v2/backfills"), + ("GET", "/api/v2/config"), + ("GET", "/api/v2/dags"), + ("POST", "/api/v2/dags/{dag_id}/clearTaskInstances"), + ("GET", "/api/v2/dags/{dag_id}/dagRuns"), + ("GET", "/api/v2/eventLogs"), + ("GET", "/api/v2/jobs"), + ("GET", "/api/v2/variables"), + ("GET", "/api/v2/version"), + ], +) +def test_all_endpoints_without_auth_header(all_access_test_client, method, path): + response = all_access_test_client.request(method, path) + assert response.status_code not in {401, 403}, ( + f"Unexpected status code {response.status_code} for {method} {path}" + )