Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions airflow/api_connexion/endpoints/dag_warning_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@

from typing import TYPE_CHECKING

from flask import g
from sqlalchemy import select

from airflow.api_connexion import security
from airflow.api_connexion.exceptions import PermissionDenied
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
from airflow.api_connexion.schemas.dag_warning_schema import (
DagWarningCollection,
dag_warning_collection_schema,
)
from airflow.models.dagwarning import DagWarning as DagWarningModel
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, provide_session

Expand Down Expand Up @@ -57,7 +60,12 @@ def get_dag_warnings(
allowed_filter_attrs = ["dag_id", "warning_type", "message", "timestamp"]
query = select(DagWarningModel)
if dag_id:
if not get_airflow_app().appbuilder.sm.can_read_dag(dag_id, g.user):
Copy link
Contributor

@ephraimbuddy ephraimbuddy Sep 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel we should have the DAG permission in the decorator instead of handling this by ourselves. My vote would go to just adding (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG) to the list of permissions.
That should solve this. Assuming it's dag_ids instead of dag_id, then we can do as you're doing now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see that if dag_id is not provided, then it gets all dag warning, my bad.

raise PermissionDenied(detail=f"User not allowed to access this DAG: {dag_id}")
query = query.where(DagWarningModel.dag_id == dag_id)
else:
readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
query = query.where(DagWarningModel.dag_id.in_(readable_dags))
if warning_type:
query = query.where(DagWarningModel.warning_type == warning_type)
total_entries = get_query_count(query, session=session)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

useless query, the count could be calculated from the list

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually not useless, because the is limit in the query, I will fix it

Expand Down
23 changes: 22 additions & 1 deletion tests/api_connexion/endpoints/test_dag_warning_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,27 @@ def configured_app(minimal_app_for_api):
app, # type:ignore
username="test",
role_name="Test",
permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING)], # type: ignore
permissions=[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
], # type: ignore
)
create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
create_user(
app, # type:ignore
username="test_with_dag2_read",
role_name="TestWithDag2Read",
permissions=[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING),
(permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"),
], # type: ignore
)

yield minimal_app_for_api

delete_user(app, username="test") # type: ignore
delete_user(app, username="test_no_permissions") # type: ignore
delete_user(app, username="test_with_dag2_read") # type: ignore


class TestBaseDagWarning:
Expand Down Expand Up @@ -147,3 +160,11 @@ def test_should_raise_403_forbidden(self):
"/api/v1/dagWarnings", environ_overrides={"REMOTE_USER": "test_no_permissions"}
)
assert response.status_code == 403

def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self):
response = self.client.get(
"/api/v1/dagWarnings",
environ_overrides={"REMOTE_USER": "test_with_dag2_read"},
query_string={"dag_id": "dag1"},
)
assert response.status_code == 403