From c6aaa8cacf560b496f3d6df57e382d8eaef1fcdb Mon Sep 17 00:00:00 2001 From: vincbeck Date: Mon, 12 May 2025 15:59:35 -0400 Subject: [PATCH] Upgrade `flask-appbuilder` to 4.6.3 in FAB provider --- providers/fab/pyproject.toml | 2 +- .../role_and_permission_endpoint.py | 4 ++ .../api_endpoints/user_endpoint.py | 4 ++ .../fab/auth_manager/models/__init__.py | 62 +++++++++++++++++-- .../fab/auth_manager/models/anonymous_user.py | 6 +- .../auth_manager/security_manager/override.py | 46 +++++++------- .../auth_manager/api_endpoints/test_auth.py | 2 - .../models/test_anonymous_user.py | 2 +- .../unit/fab/auth_manager/test_security.py | 12 ++-- 9 files changed, 102 insertions(+), 38 deletions(-) diff --git a/providers/fab/pyproject.toml b/providers/fab/pyproject.toml index e75eaa30eaa3c..6e73985c22a13 100644 --- a/providers/fab/pyproject.toml +++ b/providers/fab/pyproject.toml @@ -71,7 +71,7 @@ dependencies = [ # Every time we update FAB version here, please make sure that you review the classes and models in # `airflow/providers/fab/auth_manager/security_manager/override.py` with their upstream counterparts. # In particular, make sure any breaking changes, for example any new methods, are accounted for. - "flask-appbuilder==4.5.3", + "flask-appbuilder==4.6.3", "flask-login>=0.6.2", # Flask-Session 0.6 add new arguments into the SqlAlchemySessionInterface constructor as well as # all parameters now are mandatory which make AirflowDatabaseSessionInterface incompatible with this version. diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py b/providers/fab/src/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py index e8aacedea8c83..fa5e29782b9c2 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py @@ -123,6 +123,8 @@ def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse """Update a role.""" security_manager = cast("FabAuthManager", get_auth_manager()).security_manager body = request.json + if body is None: + raise BadRequest("Request body is required") try: data = role_schema.load(body) except ValidationError as err: @@ -156,6 +158,8 @@ def post_role() -> APIResponse: """Create a new role.""" security_manager = cast("FabAuthManager", get_auth_manager()).security_manager body = request.json + if body is None: + raise BadRequest("Request body is required") try: data = role_schema.load(body) except ValidationError as err: diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py b/providers/fab/src/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py index 8c504d6446694..e8f9fc83d9059 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py @@ -88,6 +88,8 @@ def get_users(*, limit: int, order_by: str = "id", offset: str | None = None) -> @requires_access_custom_view("POST", permissions.RESOURCE_USER) def post_user() -> APIResponse: """Create a new user.""" + if request.json is None: + raise BadRequest("Request body is required") try: data = user_schema.load(request.json) except ValidationError as e: @@ -131,6 +133,8 @@ def post_user() -> APIResponse: @requires_access_custom_view("PUT", permissions.RESOURCE_USER) def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse: """Update a user.""" + if request.json is None: + raise BadRequest("Request body is required") try: data = user_schema.load(request.json) except ValidationError as e: diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/models/__init__.py b/providers/fab/src/airflow/providers/fab/auth_manager/models/__init__.py index 1a34c9b6884eb..cb6f59f6ad576 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/models/__init__.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/models/__init__.py @@ -102,11 +102,37 @@ def __repr__(self): "ab_permission_view_role", Model.metadata, Column("id", Integer, primary_key=True), - Column("permission_view_id", Integer, ForeignKey("ab_permission_view.id")), - Column("role_id", Integer, ForeignKey("ab_role.id")), + Column( + "permission_view_id", + Integer, + ForeignKey("ab_permission_view.id", ondelete="CASCADE"), + ), + Column("role_id", Integer, ForeignKey("ab_role.id", ondelete="CASCADE")), UniqueConstraint("permission_view_id", "role_id"), ) +assoc_user_group = Table( + "ab_user_group", + Model.metadata, + Column("id", Integer, primary_key=True), + Column("user_id", Integer, ForeignKey("ab_user.id", ondelete="CASCADE")), + Column("group_id", Integer, ForeignKey("ab_group.id", ondelete="CASCADE")), + UniqueConstraint("user_id", "group_id"), + Index("idx_user_id", "user_id"), + Index("idx_user_group_id", "group_id"), +) + +assoc_group_role = Table( + "ab_group_role", + Model.metadata, + Column("id", Integer, primary_key=True), + Column("group_id", Integer, ForeignKey("ab_group.id", ondelete="CASCADE")), + Column("role_id", Integer, ForeignKey("ab_role.id", ondelete="CASCADE")), + UniqueConstraint("group_id", "role_id"), + Index("idx_group_id", "group_id"), + Index("idx_group_role_id", "role_id"), +) + class Role(Model): """Represents a user role to which permissions can be assigned.""" @@ -115,7 +141,29 @@ class Role(Model): id = Column(Integer, primary_key=True) name = Column(String(64), unique=True, nullable=False) - permissions = relationship("Permission", secondary=assoc_permission_role, backref="role", lazy="joined") + permissions = relationship( + "Permission", + secondary=assoc_permission_role, + backref="role", + lazy="joined", + passive_deletes=True, + ) + + def __repr__(self): + return self.name + + +class Group(Model): + """Represents a user group.""" + + __tablename__ = "ab_group" + + id = Column(Integer, primary_key=True) + name = Column(String(100), unique=True, nullable=False) + label = Column(String(150)) + description = Column(String(512)) + users = relationship("User", secondary=assoc_user_group, backref="groups", passive_deletes=True) + roles = relationship("Role", secondary=assoc_group_role, backref="groups", passive_deletes=True) def __repr__(self): return self.name @@ -148,8 +196,8 @@ def __repr__(self): "ab_user_role", Model.metadata, Column("id", Integer, primary_key=True), - Column("user_id", Integer, ForeignKey("ab_user.id")), - Column("role_id", Integer, ForeignKey("ab_role.id")), + Column("user_id", Integer, ForeignKey("ab_user.id", ondelete="CASCADE")), + Column("role_id", Integer, ForeignKey("ab_role.id", ondelete="CASCADE")), UniqueConstraint("user_id", "role_id"), ) @@ -170,7 +218,9 @@ class User(Model, BaseUser): last_login = Column(DateTime) login_count = Column(Integer) fail_login_count = Column(Integer) - roles = relationship("Role", secondary=assoc_user_role, backref="user", lazy="selectin") + roles = relationship( + "Role", secondary=assoc_user_role, backref="user", lazy="selectin", passive_deletes=True + ) created_on = Column(DateTime, default=datetime.datetime.now, nullable=True) changed_on = Column(DateTime, default=datetime.datetime.now, nullable=True) diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/models/anonymous_user.py b/providers/fab/src/airflow/providers/fab/auth_manager/models/anonymous_user.py index b9abd5f165378..f50341b2304c2 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/models/anonymous_user.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/models/anonymous_user.py @@ -37,13 +37,17 @@ def roles(self): if not self._roles: public_role = current_app.config.get("AUTH_ROLE_PUBLIC", None) self._roles = {current_app.appbuilder.sm.find_role(public_role)} if public_role else set() - return self._roles + return list(self._roles) @roles.setter def roles(self, roles): self._roles = roles self._perms = set() + @property + def groups(self): + return [] + @property def perms(self): if not self._perms: diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py index 06bebab9f22be..3e63f7053f40e 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py @@ -55,6 +55,7 @@ AuthOIDView, AuthRemoteUserView, RegisterUserModelView, + UserGroupModelView, ) from flask_babel import lazy_gettext from flask_jwt_extended import JWTManager @@ -71,6 +72,7 @@ from airflow.models import DagBag from airflow.providers.fab.auth_manager.models import ( Action, + Group, Permission, RegisterUser, Resource, @@ -100,10 +102,7 @@ from airflow.providers.fab.auth_manager.views.user_stats import CustomUserStatsChartView from airflow.providers.fab.www.security import permissions from airflow.providers.fab.www.security_manager import AirflowSecurityManagerV2 -from airflow.providers.fab.www.session import ( - AirflowDatabaseSessionInterface, - AirflowDatabaseSessionInterface as FabAirflowDatabaseSessionInterface, -) +from airflow.providers.fab.www.session import AirflowDatabaseSessionInterface from airflow.security.permissions import RESOURCE_BACKFILL if TYPE_CHECKING: @@ -149,6 +148,7 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): """ Models """ user_model = User role_model = Role + group_model = Group action_model = Action resource_model = Resource permission_model = Permission @@ -173,6 +173,7 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): actionmodelview = ActionModelView permissionmodelview = PermissionPairModelView rolemodelview = CustomRoleModelView + groupmodelview = UserGroupModelView registeruser_model = RegisterUser registerusermodelview = RegisterUserModelView resourcemodelview = ResourceModelView @@ -450,7 +451,7 @@ def register_views(self): role_view = self.appbuilder.add_view( self.rolemodelview, "List Roles", - icon="fa-group", + icon="fa-user-gear", label=lazy_gettext("List Roles"), category="Security", category_icon="fa-cogs", @@ -532,12 +533,7 @@ def reset_password(self, userid: int, password: str) -> bool: return self.update_user(user) def reset_user_sessions(self, user: User) -> None: - if isinstance( - self.appbuilder.get_app.session_interface, AirflowDatabaseSessionInterface - ) or isinstance( - self.appbuilder.get_app.session_interface, - FabAirflowDatabaseSessionInterface, - ): + if isinstance(self.appbuilder.get_app.session_interface, AirflowDatabaseSessionInterface): interface = self.appbuilder.get_app.session_interface session = interface.db.session user_session_model = interface.sql_session_model @@ -859,6 +855,7 @@ def _init_data_model(self): self.registerusermodelview.datamodel = SQLAInterface(self.registeruser_model) self.rolemodelview.datamodel = SQLAInterface(self.role_model) + self.groupmodelview.datamodel = SQLAInterface(self.group_model) self.actionmodelview.datamodel = SQLAInterface(self.action_model) self.resourcemodelview.datamodel = SQLAInterface(self.resource_model) self.permissionmodelview.datamodel = SQLAInterface(self.permission_model) @@ -875,7 +872,8 @@ def create_db(self): try: engine = self.get_session.get_bind(mapper=None, clause=None) inspector = inspect(engine) - if "ab_user" not in inspector.get_table_names(): + existing_tables = inspector.get_table_names() + if "ab_user" not in existing_tables or "ab_group" not in existing_tables: log.info(const.LOGMSG_INF_SEC_NO_DB) Base.metadata.create_all(engine) log.info(const.LOGMSG_INF_SEC_ADD_DB) @@ -1311,15 +1309,20 @@ def get_public_role(self): def add_user( self, - username, - first_name, - last_name, - email, - role, - password="", - hashed_password="", + username: str, + first_name: str, + last_name: str, + email: str, + role: list[Role] | Role | None = None, + password: str = "", + hashed_password: str = "", + groups: list[Group] | None = None, ): """Create a user.""" + roles: list[Role] = [] + if role: + roles = role if isinstance(role, list) else [role] + try: user = self.user_model() user.first_name = first_name @@ -1328,7 +1331,8 @@ def add_user( user.email = email user.active = True self.get_session.add(user) - user.roles = role if isinstance(role, list) else [role] + user.roles = roles + user.groups = groups or [] if hashed_password: user.password = hashed_password else: @@ -1692,7 +1696,7 @@ def get_user_roles(user=None): """ if user is None: user = g.user - return user.roles + return user.roles + [role for group in user.groups for role in group.roles] """ -------------------- diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_auth.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_auth.py index 63b88495715f6..c59c7dcf78221 100644 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_auth.py +++ b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_auth.py @@ -19,7 +19,6 @@ from base64 import b64encode import pytest -from flask_login import current_user from tests_common.test_utils.config import conf_vars from tests_common.test_utils.db import clear_db_pools @@ -74,6 +73,5 @@ def test_success(self): with self.app.test_client() as test_client: response = test_client.get("/fab/v1/users", headers={"Authorization": token}) - assert current_user.email == "test@fab.org" assert response.status_code == 200 diff --git a/providers/fab/tests/unit/fab/auth_manager/models/test_anonymous_user.py b/providers/fab/tests/unit/fab/auth_manager/models/test_anonymous_user.py index eaf6b357f9264..ea959bd25bfb2 100644 --- a/providers/fab/tests/unit/fab/auth_manager/models/test_anonymous_user.py +++ b/providers/fab/tests/unit/fab/auth_manager/models/test_anonymous_user.py @@ -25,7 +25,7 @@ class TestAnonymousUser: def test_roles(self): - roles = {"role1"} + roles = ["role1"] user = AnonymousUser() user.roles = roles assert user.roles == roles diff --git a/providers/fab/tests/unit/fab/auth_manager/test_security.py b/providers/fab/tests/unit/fab/auth_manager/test_security.py index 17f5f562b1134..6285dff9dc300 100644 --- a/providers/fab/tests/unit/fab/auth_manager/test_security.py +++ b/providers/fab/tests/unit/fab/auth_manager/test_security.py @@ -353,7 +353,7 @@ def test_verify_default_anon_user_has_no_accessible_dag_ids( mock_is_logged_in.return_value = False user = AnonymousUser() app.config["AUTH_ROLE_PUBLIC"] = "Public" - assert security_manager.get_user_roles(user) == {security_manager.get_public_role()} + assert security_manager.get_user_roles(user) == [security_manager.get_public_role()] with _create_dag_model_context("test_dag_id", session, security_manager): security_manager.sync_roles() @@ -365,7 +365,7 @@ def test_verify_default_anon_user_has_no_access_to_specific_dag(app, session, se with app.app_context(): user = AnonymousUser() app.config["AUTH_ROLE_PUBLIC"] = "Public" - assert security_manager.get_user_roles(user) == {security_manager.get_public_role()} + assert security_manager.get_user_roles(user) == [security_manager.get_public_role()] dag_id = "test_dag_id" with _create_dag_model_context(dag_id, session, security_manager): @@ -392,7 +392,7 @@ def test_verify_anon_user_with_admin_role_has_all_dag_access( mock_is_logged_in.return_value = False user = AnonymousUser() - assert security_manager.get_user_roles(user) == {security_manager.get_public_role()} + assert security_manager.get_user_roles(user) == [security_manager.get_public_role()] security_manager.sync_roles() @@ -408,7 +408,7 @@ def test_verify_anon_user_with_admin_role_has_access_to_each_dag( # Call `.get_user_roles` bc `user` is a mock and the `user.roles` prop needs to be set. user.roles = security_manager.get_user_roles(user) - assert user.roles == {security_manager.get_public_role()} + assert user.roles == [security_manager.get_public_role()] test_dag_ids = ["test_dag_id_1", "test_dag_id_2", "test_dag_id_3", "test_dag_id_4.with_dot"] @@ -425,8 +425,8 @@ def test_verify_anon_user_with_admin_role_has_access_to_each_dag( def test_get_user_roles(app_builder, security_manager): user = mock.MagicMock() roles = app_builder.sm.find_role("Admin") - user.roles = roles - assert security_manager.get_user_roles(user) == roles + user.roles = [roles] + assert security_manager.get_user_roles(user) == [roles] def test_get_user_roles_for_anonymous_user(app, security_manager):