diff --git a/airflow/api_connexion/endpoints/dag_warning_endpoint.py b/airflow/api_connexion/endpoints/dag_warning_endpoint.py index 203bb9d6a5e97..367b0ae104571 100644 --- a/airflow/api_connexion/endpoints/dag_warning_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_warning_endpoint.py @@ -18,9 +18,11 @@ from typing import TYPE_CHECKING +from flask import g from sqlalchemy import select from airflow.api_connexion import security +from airflow.api_connexion.exceptions import PermissionDenied from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.dag_warning_schema import ( DagWarningCollection, @@ -28,6 +30,7 @@ ) from airflow.models.dagwarning import DagWarning as DagWarningModel from airflow.security import permissions +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session @@ -57,7 +60,12 @@ def get_dag_warnings( allowed_filter_attrs = ["dag_id", "warning_type", "message", "timestamp"] query = select(DagWarningModel) if dag_id: + if not get_airflow_app().appbuilder.sm.can_read_dag(dag_id, g.user): + raise PermissionDenied(detail=f"User not allowed to access this DAG: {dag_id}") query = query.where(DagWarningModel.dag_id == dag_id) + else: + readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + query = query.where(DagWarningModel.dag_id.in_(readable_dags)) if warning_type: query = query.where(DagWarningModel.warning_type == warning_type) total_entries = get_query_count(query, session=session) diff --git a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py index 621b04366702b..041a61634eb0c 100644 --- a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py @@ -35,14 +35,27 @@ def configured_app(minimal_app_for_api): app, # type:ignore username="test", role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING)], # type: ignore + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + ], # type: ignore ) create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user( + app, # type:ignore + username="test_with_dag2_read", + role_name="TestWithDag2Read", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), + (permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"), + ], # type: ignore + ) yield minimal_app_for_api delete_user(app, username="test") # type: ignore delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test_with_dag2_read") # type: ignore class TestBaseDagWarning: @@ -147,3 +160,11 @@ def test_should_raise_403_forbidden(self): "/api/v1/dagWarnings", environ_overrides={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 + + def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): + response = self.client.get( + "/api/v1/dagWarnings", + environ_overrides={"REMOTE_USER": "test_with_dag2_read"}, + query_string={"dag_id": "dag1"}, + ) + assert response.status_code == 403