Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion providers/fab/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ dependencies = [
# Every time we update FAB version here, please make sure that you review the classes and models in
# `airflow/providers/fab/auth_manager/security_manager/override.py` with their upstream counterparts.
# In particular, make sure any breaking changes, for example any new methods, are accounted for.
"flask-appbuilder==4.5.3",
"flask-appbuilder==4.6.3",
"flask-login>=0.6.2",
# Flask-Session 0.6 add new arguments into the SqlAlchemySessionInterface constructor as well as
# all parameters now are mandatory which make AirflowDatabaseSessionInterface incompatible with this version.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse
"""Update a role."""
security_manager = cast("FabAuthManager", get_auth_manager()).security_manager
body = request.json
if body is None:
raise BadRequest("Request body is required")
try:
data = role_schema.load(body)
except ValidationError as err:
Expand Down Expand Up @@ -156,6 +158,8 @@ def post_role() -> APIResponse:
"""Create a new role."""
security_manager = cast("FabAuthManager", get_auth_manager()).security_manager
body = request.json
if body is None:
raise BadRequest("Request body is required")
try:
data = role_schema.load(body)
except ValidationError as err:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def get_users(*, limit: int, order_by: str = "id", offset: str | None = None) ->
@requires_access_custom_view("POST", permissions.RESOURCE_USER)
def post_user() -> APIResponse:
"""Create a new user."""
if request.json is None:
raise BadRequest("Request body is required")
try:
data = user_schema.load(request.json)
except ValidationError as e:
Expand Down Expand Up @@ -131,6 +133,8 @@ def post_user() -> APIResponse:
@requires_access_custom_view("PUT", permissions.RESOURCE_USER)
def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse:
"""Update a user."""
if request.json is None:
raise BadRequest("Request body is required")
try:
data = user_schema.load(request.json)
except ValidationError as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,37 @@ def __repr__(self):
"ab_permission_view_role",
Model.metadata,
Column("id", Integer, primary_key=True),
Column("permission_view_id", Integer, ForeignKey("ab_permission_view.id")),
Column("role_id", Integer, ForeignKey("ab_role.id")),
Column(
"permission_view_id",
Integer,
ForeignKey("ab_permission_view.id", ondelete="CASCADE"),
),
Column("role_id", Integer, ForeignKey("ab_role.id", ondelete="CASCADE")),
UniqueConstraint("permission_view_id", "role_id"),
)

assoc_user_group = Table(
"ab_user_group",
Model.metadata,
Column("id", Integer, primary_key=True),
Column("user_id", Integer, ForeignKey("ab_user.id", ondelete="CASCADE")),
Column("group_id", Integer, ForeignKey("ab_group.id", ondelete="CASCADE")),
UniqueConstraint("user_id", "group_id"),
Index("idx_user_id", "user_id"),
Index("idx_user_group_id", "group_id"),
)

assoc_group_role = Table(
"ab_group_role",
Model.metadata,
Column("id", Integer, primary_key=True),
Column("group_id", Integer, ForeignKey("ab_group.id", ondelete="CASCADE")),
Column("role_id", Integer, ForeignKey("ab_role.id", ondelete="CASCADE")),
UniqueConstraint("group_id", "role_id"),
Index("idx_group_id", "group_id"),
Index("idx_group_role_id", "role_id"),
)


class Role(Model):
"""Represents a user role to which permissions can be assigned."""
Expand All @@ -115,7 +141,29 @@ class Role(Model):

id = Column(Integer, primary_key=True)
name = Column(String(64), unique=True, nullable=False)
permissions = relationship("Permission", secondary=assoc_permission_role, backref="role", lazy="joined")
permissions = relationship(
"Permission",
secondary=assoc_permission_role,
backref="role",
lazy="joined",
passive_deletes=True,
)

def __repr__(self):
return self.name


class Group(Model):
"""Represents a user group."""

__tablename__ = "ab_group"

id = Column(Integer, primary_key=True)
name = Column(String(100), unique=True, nullable=False)
label = Column(String(150))
description = Column(String(512))
users = relationship("User", secondary=assoc_user_group, backref="groups", passive_deletes=True)
roles = relationship("Role", secondary=assoc_group_role, backref="groups", passive_deletes=True)

def __repr__(self):
return self.name
Expand Down Expand Up @@ -148,8 +196,8 @@ def __repr__(self):
"ab_user_role",
Model.metadata,
Column("id", Integer, primary_key=True),
Column("user_id", Integer, ForeignKey("ab_user.id")),
Column("role_id", Integer, ForeignKey("ab_role.id")),
Column("user_id", Integer, ForeignKey("ab_user.id", ondelete="CASCADE")),
Column("role_id", Integer, ForeignKey("ab_role.id", ondelete="CASCADE")),
UniqueConstraint("user_id", "role_id"),
)

Expand All @@ -170,7 +218,9 @@ class User(Model, BaseUser):
last_login = Column(DateTime)
login_count = Column(Integer)
fail_login_count = Column(Integer)
roles = relationship("Role", secondary=assoc_user_role, backref="user", lazy="selectin")
roles = relationship(
"Role", secondary=assoc_user_role, backref="user", lazy="selectin", passive_deletes=True
)
created_on = Column(DateTime, default=datetime.datetime.now, nullable=True)
changed_on = Column(DateTime, default=datetime.datetime.now, nullable=True)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,17 @@ def roles(self):
if not self._roles:
public_role = current_app.config.get("AUTH_ROLE_PUBLIC", None)
self._roles = {current_app.appbuilder.sm.find_role(public_role)} if public_role else set()
return self._roles
return list(self._roles)

@roles.setter
def roles(self, roles):
self._roles = roles
self._perms = set()

@property
def groups(self):
return []

@property
def perms(self):
if not self._perms:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
AuthOIDView,
AuthRemoteUserView,
RegisterUserModelView,
UserGroupModelView,
)
from flask_babel import lazy_gettext
from flask_jwt_extended import JWTManager
Expand All @@ -71,6 +72,7 @@
from airflow.models import DagBag
from airflow.providers.fab.auth_manager.models import (
Action,
Group,
Permission,
RegisterUser,
Resource,
Expand Down Expand Up @@ -100,10 +102,7 @@
from airflow.providers.fab.auth_manager.views.user_stats import CustomUserStatsChartView
from airflow.providers.fab.www.security import permissions
from airflow.providers.fab.www.security_manager import AirflowSecurityManagerV2
from airflow.providers.fab.www.session import (
AirflowDatabaseSessionInterface,
AirflowDatabaseSessionInterface as FabAirflowDatabaseSessionInterface,
)
from airflow.providers.fab.www.session import AirflowDatabaseSessionInterface
from airflow.security.permissions import RESOURCE_BACKFILL

if TYPE_CHECKING:
Expand Down Expand Up @@ -149,6 +148,7 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2):
""" Models """
user_model = User
role_model = Role
group_model = Group
action_model = Action
resource_model = Resource
permission_model = Permission
Expand All @@ -173,6 +173,7 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2):
actionmodelview = ActionModelView
permissionmodelview = PermissionPairModelView
rolemodelview = CustomRoleModelView
groupmodelview = UserGroupModelView
registeruser_model = RegisterUser
registerusermodelview = RegisterUserModelView
resourcemodelview = ResourceModelView
Expand Down Expand Up @@ -450,7 +451,7 @@ def register_views(self):
role_view = self.appbuilder.add_view(
self.rolemodelview,
"List Roles",
icon="fa-group",
icon="fa-user-gear",
label=lazy_gettext("List Roles"),
category="Security",
category_icon="fa-cogs",
Expand Down Expand Up @@ -532,12 +533,7 @@ def reset_password(self, userid: int, password: str) -> bool:
return self.update_user(user)

def reset_user_sessions(self, user: User) -> None:
if isinstance(
self.appbuilder.get_app.session_interface, AirflowDatabaseSessionInterface
) or isinstance(
self.appbuilder.get_app.session_interface,
FabAirflowDatabaseSessionInterface,
):
if isinstance(self.appbuilder.get_app.session_interface, AirflowDatabaseSessionInterface):
interface = self.appbuilder.get_app.session_interface
session = interface.db.session
user_session_model = interface.sql_session_model
Expand Down Expand Up @@ -859,6 +855,7 @@ def _init_data_model(self):
self.registerusermodelview.datamodel = SQLAInterface(self.registeruser_model)

self.rolemodelview.datamodel = SQLAInterface(self.role_model)
self.groupmodelview.datamodel = SQLAInterface(self.group_model)
self.actionmodelview.datamodel = SQLAInterface(self.action_model)
self.resourcemodelview.datamodel = SQLAInterface(self.resource_model)
self.permissionmodelview.datamodel = SQLAInterface(self.permission_model)
Expand All @@ -875,7 +872,8 @@ def create_db(self):
try:
engine = self.get_session.get_bind(mapper=None, clause=None)
inspector = inspect(engine)
if "ab_user" not in inspector.get_table_names():
existing_tables = inspector.get_table_names()
if "ab_user" not in existing_tables or "ab_group" not in existing_tables:
log.info(const.LOGMSG_INF_SEC_NO_DB)
Base.metadata.create_all(engine)
log.info(const.LOGMSG_INF_SEC_ADD_DB)
Expand Down Expand Up @@ -1311,15 +1309,20 @@ def get_public_role(self):

def add_user(
self,
username,
first_name,
last_name,
email,
role,
password="",
hashed_password="",
username: str,
first_name: str,
last_name: str,
email: str,
role: list[Role] | Role | None = None,
password: str = "",
hashed_password: str = "",
groups: list[Group] | None = None,
):
"""Create a user."""
roles: list[Role] = []
if role:
roles = role if isinstance(role, list) else [role]

try:
user = self.user_model()
user.first_name = first_name
Expand All @@ -1328,7 +1331,8 @@ def add_user(
user.email = email
user.active = True
self.get_session.add(user)
user.roles = role if isinstance(role, list) else [role]
user.roles = roles
user.groups = groups or []
if hashed_password:
user.password = hashed_password
else:
Expand Down Expand Up @@ -1692,7 +1696,7 @@ def get_user_roles(user=None):
"""
if user is None:
user = g.user
return user.roles
return user.roles + [role for group in user.groups for role in group.roles]

"""
--------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from base64 import b64encode

import pytest
from flask_login import current_user

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import clear_db_pools
Expand Down Expand Up @@ -74,6 +73,5 @@ def test_success(self):

with self.app.test_client() as test_client:
response = test_client.get("/fab/v1/users", headers={"Authorization": token})
assert current_user.email == "test@fab.org"

assert response.status_code == 200
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

class TestAnonymousUser:
def test_roles(self):
roles = {"role1"}
roles = ["role1"]
user = AnonymousUser()
user.roles = roles
assert user.roles == roles
Expand Down
12 changes: 6 additions & 6 deletions providers/fab/tests/unit/fab/auth_manager/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def test_verify_default_anon_user_has_no_accessible_dag_ids(
mock_is_logged_in.return_value = False
user = AnonymousUser()
app.config["AUTH_ROLE_PUBLIC"] = "Public"
assert security_manager.get_user_roles(user) == {security_manager.get_public_role()}
assert security_manager.get_user_roles(user) == [security_manager.get_public_role()]

with _create_dag_model_context("test_dag_id", session, security_manager):
security_manager.sync_roles()
Expand All @@ -365,7 +365,7 @@ def test_verify_default_anon_user_has_no_access_to_specific_dag(app, session, se
with app.app_context():
user = AnonymousUser()
app.config["AUTH_ROLE_PUBLIC"] = "Public"
assert security_manager.get_user_roles(user) == {security_manager.get_public_role()}
assert security_manager.get_user_roles(user) == [security_manager.get_public_role()]

dag_id = "test_dag_id"
with _create_dag_model_context(dag_id, session, security_manager):
Expand All @@ -392,7 +392,7 @@ def test_verify_anon_user_with_admin_role_has_all_dag_access(
mock_is_logged_in.return_value = False
user = AnonymousUser()

assert security_manager.get_user_roles(user) == {security_manager.get_public_role()}
assert security_manager.get_user_roles(user) == [security_manager.get_public_role()]

security_manager.sync_roles()

Expand All @@ -408,7 +408,7 @@ def test_verify_anon_user_with_admin_role_has_access_to_each_dag(

# Call `.get_user_roles` bc `user` is a mock and the `user.roles` prop needs to be set.
user.roles = security_manager.get_user_roles(user)
assert user.roles == {security_manager.get_public_role()}
assert user.roles == [security_manager.get_public_role()]

test_dag_ids = ["test_dag_id_1", "test_dag_id_2", "test_dag_id_3", "test_dag_id_4.with_dot"]

Expand All @@ -425,8 +425,8 @@ def test_verify_anon_user_with_admin_role_has_access_to_each_dag(
def test_get_user_roles(app_builder, security_manager):
user = mock.MagicMock()
roles = app_builder.sm.find_role("Admin")
user.roles = roles
assert security_manager.get_user_roles(user) == roles
user.roles = [roles]
assert security_manager.get_user_roles(user) == [roles]


def test_get_user_roles_for_anonymous_user(app, security_manager):
Expand Down
Loading