From 3782ffab8f42aa181e99a334a9b07f3144567cb4 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Thu, 14 Sep 2023 00:42:13 +0200 Subject: [PATCH 1/4] Fix dag warning endpoint permissions --- .../endpoints/dag_warning_endpoint.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_warning_endpoint.py b/airflow/api_connexion/endpoints/dag_warning_endpoint.py index 203bb9d6a5e97..9cb9430b256a0 100644 --- a/airflow/api_connexion/endpoints/dag_warning_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_warning_endpoint.py @@ -18,9 +18,11 @@ from typing import TYPE_CHECKING +from flask import g from sqlalchemy import select from airflow.api_connexion import security +from airflow.api_connexion.exceptions import PermissionDenied from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.dag_warning_schema import ( DagWarningCollection, @@ -28,7 +30,7 @@ ) from airflow.models.dagwarning import DagWarning as DagWarningModel from airflow.security import permissions -from airflow.utils.db import get_query_count +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: @@ -57,12 +59,20 @@ def get_dag_warnings( allowed_filter_attrs = ["dag_id", "warning_type", "message", "timestamp"] query = select(DagWarningModel) if dag_id: + if not get_airflow_app().appbuilder.sm.can_read_dag(dag_id, g.user): + raise PermissionDenied(detail=f"User not allowed to access this DAG: {dag_id}") query = query.where(DagWarningModel.dag_id == dag_id) if warning_type: query = query.where(DagWarningModel.warning_type == warning_type) - total_entries = get_query_count(query, session=session) query = apply_sorting(query=query, order_by=order_by, allowed_attrs=allowed_filter_attrs) dag_warnings = session.scalars(query.offset(offset).limit(limit)).all() + if not dag_id: + dag_warnings = [ + dag_warning + for dag_warning in dag_warnings + if get_airflow_app().appbuilder.sm.can_read_dag(dag_warning.dag_id, g.user) + ] + total_entries = len(dag_warnings) return dag_warning_collection_schema.dump( DagWarningCollection(dag_warnings=dag_warnings, total_entries=total_entries) ) From e212aedc65b85671062fb1d4cc2de7be7fa3d449 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Thu, 14 Sep 2023 01:15:14 +0200 Subject: [PATCH 2/4] update the query to have an accurate result for total entries and pagination --- .../api_connexion/endpoints/dag_warning_endpoint.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_warning_endpoint.py b/airflow/api_connexion/endpoints/dag_warning_endpoint.py index 9cb9430b256a0..367b0ae104571 100644 --- a/airflow/api_connexion/endpoints/dag_warning_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_warning_endpoint.py @@ -31,6 +31,7 @@ from airflow.models.dagwarning import DagWarning as DagWarningModel from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app +from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: @@ -62,17 +63,14 @@ def get_dag_warnings( if not get_airflow_app().appbuilder.sm.can_read_dag(dag_id, g.user): raise PermissionDenied(detail=f"User not allowed to access this DAG: {dag_id}") query = query.where(DagWarningModel.dag_id == dag_id) + else: + readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + query = query.where(DagWarningModel.dag_id.in_(readable_dags)) if warning_type: query = query.where(DagWarningModel.warning_type == warning_type) + total_entries = get_query_count(query, session=session) query = apply_sorting(query=query, order_by=order_by, allowed_attrs=allowed_filter_attrs) dag_warnings = session.scalars(query.offset(offset).limit(limit)).all() - if not dag_id: - dag_warnings = [ - dag_warning - for dag_warning in dag_warnings - if get_airflow_app().appbuilder.sm.can_read_dag(dag_warning.dag_id, g.user) - ] - total_entries = len(dag_warnings) return dag_warning_collection_schema.dump( DagWarningCollection(dag_warnings=dag_warnings, total_entries=total_entries) ) From 658a85307782998ad441e7e31df2bc5074e745f1 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Thu, 14 Sep 2023 01:33:55 +0200 Subject: [PATCH 3/4] add unit tests --- .../endpoints/test_dag_warning_endpoint.py | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py index 621b04366702b..36891946de906 100644 --- a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py @@ -35,14 +35,27 @@ def configured_app(minimal_app_for_api): app, # type:ignore username="test", role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING)], # type: ignore + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + ], # type: ignore ) create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user( + app, # type:ignore + username="test_with_dag2_read", + role_name="TestWithDag2Read", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_PREFIX + "dag2"), + ], # type: ignore + ) yield minimal_app_for_api delete_user(app, username="test") # type: ignore delete_user(app, username="test_no_permissions") # type: ignore + delete_user(app, username="test_with_dag2_read") # type: ignore class TestBaseDagWarning: @@ -147,3 +160,11 @@ def test_should_raise_403_forbidden(self): "/api/v1/dagWarnings", environ_overrides={"REMOTE_USER": "test_no_permissions"} ) assert response.status_code == 403 + + def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): + response = self.client.get( + "/api/v1/dagWarnings", + environ_overrides={"REMOTE_USER": "test_with_dag2_read"}, + query_string={"dag_id": "dag1"}, + ) + assert response.status_code == 403 From b6510c4de6581f04ff60e839f4340ced99c7ee1a Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Fri, 15 Sep 2023 10:25:28 +0200 Subject: [PATCH 4/4] Update test_dag_warning_endpoint.py Co-authored-by: Tzu-ping Chung --- tests/api_connexion/endpoints/test_dag_warning_endpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py index 36891946de906..041a61634eb0c 100644 --- a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py @@ -47,7 +47,7 @@ def configured_app(minimal_app_for_api): role_name="TestWithDag2Read", permissions=[ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_PREFIX + "dag2"), + (permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"), ], # type: ignore )