diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index 10662e5ab21c3..996d5bc9fe60e 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -545,11 +545,7 @@ def _serialize_dag_capturing_errors(dag, session): session=session, ) if dag_was_updated: - self.log.debug("Syncing DAG permissions: %s to the DB", dag.dag_id) - from airflow.www.security import ApplessAirflowSecurityManager - - security_manager = ApplessAirflowSecurityManager(session=session) - security_manager.sync_perm_for_dag(dag.dag_id, dag.access_control) + self._sync_perm_for_dag(dag, session=session) return [] except OperationalError: raise @@ -580,3 +576,31 @@ def _serialize_dag_capturing_errors(dag, session): # Only now we are "complete" do we update import_errors - don't want to record errors from # previous failed attempts self.import_errors.update(dict(serialize_errors)) + + @provide_session + def _sync_perm_for_dag(self, dag, session: Optional[Session] = None): + """Sync DAG specific permissions, if necessary""" + from flask_appbuilder.security.sqla import models as sqla_models + + from airflow.security.permissions import DAG_PERMS, permission_name_for_dag + + def needs_perm_views(dag_id: str) -> bool: + view_menu_name = permission_name_for_dag(dag_id) + for permission_name in DAG_PERMS: + if not ( + session.query(sqla_models.PermissionView) + .join(sqla_models.Permission) + .join(sqla_models.ViewMenu) + .filter(sqla_models.Permission.name == permission_name) + .filter(sqla_models.ViewMenu.name == view_menu_name) + .one_or_none() + ): + return True + return False + + if dag.access_control or needs_perm_views(dag.dag_id): + self.log.debug("Syncing DAG permissions: %s to the DB", dag.dag_id) + from airflow.www.security import ApplessAirflowSecurityManager + + security_manager = ApplessAirflowSecurityManager(session=session) + security_manager.sync_perm_for_dag(dag.dag_id, dag.access_control) diff --git a/airflow/security/permissions.py b/airflow/security/permissions.py index 8e09c750a3ede..f52a1c3dca6b1 100644 --- a/airflow/security/permissions.py +++ b/airflow/security/permissions.py @@ -59,3 +59,15 @@ ACTION_CAN_ACCESS_MENU = "menu_access" DEPRECATED_ACTION_CAN_DAG_READ = "can_dag_read" DEPRECATED_ACTION_CAN_DAG_EDIT = "can_dag_edit" + +DAG_PERMS = {ACTION_CAN_READ, ACTION_CAN_EDIT} + + +def permission_name_for_dag(dag_id): + """Returns the permission name for a DAG id.""" + if dag_id == RESOURCE_DAG: + return dag_id + + if dag_id.startswith(RESOURCE_DAG_PREFIX): + return dag_id + return f"{RESOURCE_DAG_PREFIX}{dag_id}" diff --git a/airflow/www/security.py b/airflow/www/security.py index dcccf47233ef3..bfe1e7973d133 100644 --- a/airflow/www/security.py +++ b/airflow/www/security.py @@ -143,7 +143,7 @@ class AirflowSecurityManager(SecurityManager, LoggingMixin): # pylint: disable= DAG_VMS = {permissions.RESOURCE_DAG} READ_DAG_PERMS = {permissions.ACTION_CAN_READ} - DAG_PERMS = {permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT} + DAG_PERMS = permissions.DAG_PERMS ########################################################################### # DEFAULT ROLE CONFIGURATIONS @@ -362,12 +362,7 @@ def can_edit_dag(self, dag_id, user=None) -> bool: def prefixed_dag_id(self, dag_id): """Returns the permission name for a DAG id.""" - if dag_id == permissions.RESOURCE_DAG: - return dag_id - - if dag_id.startswith(permissions.RESOURCE_DAG_PREFIX): - return dag_id - return f"{permissions.RESOURCE_DAG_PREFIX}{dag_id}" + return permissions.permission_name_for_dag(dag_id) def is_dag_resource(self, resource_name): """Determines if a permission belongs to a DAG or all DAGs.""" diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index 2a0d589aff7ac..ae03db9cc604a 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -35,11 +35,13 @@ from airflow.models.serialized_dag import SerializedDagModel from airflow.utils.dates import timezone as tz from airflow.utils.session import create_session +from airflow.www.security import ApplessAirflowSecurityManager from tests import cluster_policies from tests.models import TEST_DAGS_FOLDER from tests.test_utils import db from tests.test_utils.asserts import assert_queries_count from tests.test_utils.config import conf_vars +from tests.test_utils.permissions import delete_dag_specific_permissions class TestDagBag(unittest.TestCase): @@ -688,41 +690,74 @@ def test_sync_to_db_is_retried(self, mock_bulk_write_to_db, mock_s10n_write_dag, ) @patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL", 5) - @patch("airflow.www.security.ApplessAirflowSecurityManager") - def test_sync_to_db_handles_dag_specific_permissions(self, mock_security_manager): + @freeze_time(tz.datetime(2020, 1, 5, 0, 0, 0), as_kwarg="frozen_time") + def test_sync_to_db_syncs_dag_specific_perms_on_update(self, frozen_time): """ - Test that when dagbag.sync_to_db is called new DAGs and updates DAGs have their - DAG specific permissions synced + Test that dagbag.sync_to_db will sync DAG specific permissions when a DAG is + new or updated """ with create_session() as session: - # New DAG dagbag = DagBag( dag_folder=os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py"), include_examples=False, ) - with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 0)): + mock_sync_perm_for_dag = mock.MagicMock() + dagbag._sync_perm_for_dag = mock_sync_perm_for_dag + + def _sync_to_db(): + mock_sync_perm_for_dag.reset_mock() + frozen_time.tick(20) dagbag.sync_to_db(session=session) - mock_security_manager.return_value.sync_perm_for_dag.assert_called_once_with( - "test_example_bash_operator", None - ) + dag = dagbag.dags["test_example_bash_operator"] + _sync_to_db() + mock_sync_perm_for_dag.assert_called_once_with(dag, session=session) + + # DAG isn't updated + _sync_to_db() + mock_sync_perm_for_dag.assert_not_called() # DAG is updated - mock_security_manager.reset_mock() - dagbag.dags["test_example_bash_operator"].tags = ["new_tag"] - with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 20)): - dagbag.sync_to_db(session=session) + dag.tags = ["new_tag"] + _sync_to_db() + mock_sync_perm_for_dag.assert_called_once_with(dag, session=session) - mock_security_manager.return_value.sync_perm_for_dag.assert_called_once_with( - "test_example_bash_operator", None + @patch("airflow.www.security.ApplessAirflowSecurityManager") + def test_sync_perm_for_dag(self, mock_security_manager): + """ + Test that dagbag._sync_perm_for_dag will call ApplessAirflowSecurityManager.sync_perm_for_dag + when DAG specific perm views don't exist already or the DAG has access_control set. + """ + delete_dag_specific_permissions() + with create_session() as session: + security_manager = ApplessAirflowSecurityManager(session) + mock_sync_perm_for_dag = mock_security_manager.return_value.sync_perm_for_dag + mock_sync_perm_for_dag.side_effect = security_manager.sync_perm_for_dag + + dagbag = DagBag( + dag_folder=os.path.join(TEST_DAGS_FOLDER, "test_example_bash_operator.py"), + include_examples=False, ) + dag = dagbag.dags["test_example_bash_operator"] - # DAG isn't updated - mock_security_manager.reset_mock() - with freeze_time(tz.datetime(2020, 1, 5, 0, 0, 40)): - dagbag.sync_to_db(session=session) + def _sync_perms(): + mock_sync_perm_for_dag.reset_mock() + dagbag._sync_perm_for_dag(dag, session=session) + + # permviews dont exist + _sync_perms() + mock_sync_perm_for_dag.assert_called_once_with("test_example_bash_operator", None) - mock_security_manager.return_value.sync_perm_for_dag.assert_not_called() + # permviews now exist + _sync_perms() + mock_sync_perm_for_dag.assert_not_called() + + # Always sync if we have access_control + dag.access_control = {"Public": {"can_read"}} + _sync_perms() + mock_sync_perm_for_dag.assert_called_once_with( + "test_example_bash_operator", {"Public": {"can_read"}} + ) @patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_UPDATE_INTERVAL", 5) @patch("airflow.models.dagbag.settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL", 5) diff --git a/tests/test_utils/permissions.py b/tests/test_utils/permissions.py new file mode 100644 index 0000000000000..71ec2e19ab74f --- /dev/null +++ b/tests/test_utils/permissions.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from flask_appbuilder.security.sqla import models as sqla_models + +from airflow.security.permissions import RESOURCE_DAG_PREFIX +from airflow.utils.session import create_session + + +def delete_dag_specific_permissions(): + with create_session() as session: + dag_vms = ( + session.query(sqla_models.ViewMenu) + .filter(sqla_models.ViewMenu.name.like(f"{RESOURCE_DAG_PREFIX}%")) + .all() + ) + vm_ids = [d.id for d in dag_vms] + + dag_pvms = ( + session.query(sqla_models.PermissionView) + .filter(sqla_models.PermissionView.view_menu_id.in_(vm_ids)) + .all() + ) + pvm_ids = [d.id for d in dag_pvms] + + session.query(sqla_models.assoc_permissionview_role).filter( + sqla_models.assoc_permissionview_role.c.permission_view_id.in_(pvm_ids) + ).delete(synchronize_session=False) + session.query(sqla_models.PermissionView).filter( + sqla_models.PermissionView.view_menu_id.in_(vm_ids) + ).delete(synchronize_session=False) + session.query(sqla_models.ViewMenu).filter(sqla_models.ViewMenu.id.in_(vm_ids)).delete( + synchronize_session=False + )