diff --git a/server/mergin/sync/models.py b/server/mergin/sync/models.py index 690b2860..e7c69159 100644 --- a/server/mergin/sync/models.py +++ b/server/mergin/sync/models.py @@ -89,6 +89,7 @@ def __init__( self.public = kwargs.get("public", False) latest_files = LatestProjectFiles(project=self) db.session.add(latest_files) + self.set_role(creator.id, ProjectRole.OWNER) @property def storage(self): @@ -280,6 +281,33 @@ def unset_role(self, user_id: int) -> None: if member: self.project_users.remove(member) + def members_by_role(self, role: ProjectRole) -> List[int]: + """Project members' ids with at least required role (or higher)""" + return [u.user_id for u in self.project_users if ProjectRole(u.role) >= role] + + def bulk_roles_update(self, access: Dict) -> Set[int]: + """Update roles from access lists and return users ids of those affected by any action""" + id_diffs = [] + for role in list(ProjectRole.__reversed__()): + # we might not want to modify all roles + if role not in access: + continue + + for user_id in access.get(role): + if self.get_role(user_id) != role: + self.set_role(user_id, role) + id_diffs.append(user_id) + + # make sure we do not have other user ids than in the list at this role + for user in self.project_users: + if ProjectRole(user.role) == role and user.user_id not in access.get( + role + ): + self.unset_role(user.user_id) + id_diffs.append(user.user_id) + + return set(id_diffs) + class ProjectRole(Enum): """Project roles ordered by rank (do not change)""" @@ -294,6 +322,14 @@ def __ge__(self, other): members = list(ProjectRole.__members__) return members.index(self.name) >= members.index(other.name) + def __gt__(self, other): + members = list(ProjectRole.__members__) + return members.index(self.name) > members.index(other.name) + + def __lt__(self, other): + members = list(ProjectRole.__members__) + return members.index(self.name) < members.index(other.name) + @dataclass class ProjectAccessDetail: diff --git a/server/mergin/sync/permissions.py b/server/mergin/sync/permissions.py index d28ad48b..4e3901d5 100644 --- a/server/mergin/sync/permissions.py +++ b/server/mergin/sync/permissions.py @@ -10,6 +10,7 @@ from sqlalchemy import or_ from .utils import is_valid_uuid +from ..app import db from ..auth.models import User from .models import Project, Upload, ProjectRole, ProjectUser @@ -62,10 +63,8 @@ def query(cls, user, as_admin=True, public=True): if user.is_authenticated and user.is_admin and as_admin: return Project.query - query = ( - Project.query.join(ProjectUser) - .filter(Project.storage_params.isnot(None)) - .filter(Project.removed_at.is_(None)) + query = Project.query.filter(Project.storage_params.isnot(None)).filter( + Project.removed_at.is_(None) ) if user.is_authenticated and user.active: all_workspaces = current_app.ws_handler.list_user_workspaces( @@ -76,19 +75,24 @@ def query(cls, user, as_admin=True, public=True): for ws in all_workspaces if ws.user_has_permissions(user, "read") ] + subquery = ( + db.session.query(ProjectUser.project_id) + .filter(ProjectUser.user_id == user.id) + .subquery() + ) if public: query = query.filter( or_( Project.public.is_(True), Project.workspace_id.in_(user_workspace_ids), - ProjectUser.user_id == user.id, + Project.id.in_(subquery), ) ) else: query = query.filter( or_( Project.workspace_id.in_(user_workspace_ids), - ProjectUser.user_id == user.id, + Project.id.in_(subquery), ) ) else: diff --git a/server/mergin/sync/public_api.yaml b/server/mergin/sync/public_api.yaml index c066b400..48d39c13 100644 --- a/server/mergin/sync/public_api.yaml +++ b/server/mergin/sync/public_api.yaml @@ -232,6 +232,7 @@ paths: summary: Update an existing project description: Updates 'public' flag and access list for project operationId: update_project + deprecated: true requestBody: description: Data to be updated required: true diff --git a/server/mergin/sync/public_api_controller.py b/server/mergin/sync/public_api_controller.py index 18ab2ab8..a325ef67 100644 --- a/server/mergin/sync/public_api_controller.py +++ b/server/mergin/sync/public_api_controller.py @@ -26,11 +26,13 @@ ) from pygeodiff import GeoDiffLibError from flask_login import current_user -from sqlalchemy import and_, desc, asc, text +from sqlalchemy import and_, desc, asc, text, func, select from sqlalchemy.exc import IntegrityError from binaryornot.check import is_binary from gevent import sleep import base64 + +from sqlalchemy.orm import load_only from werkzeug.exceptions import HTTPException from ..app import db from ..auth import auth_required @@ -42,6 +44,8 @@ PushChangeType, FileHistory, ProjectFilePath, + ProjectUser, + ProjectRole, ) from .files import ( UploadChanges, @@ -93,14 +97,19 @@ def parse_project_access_update_request(access: Dict) -> Dict: """Parse raw project access update request and filter out invalid entries. New access can be specified either by list of usernames or ids -> convert only to ids fur further processing. + Converted lists are flattened, e.g. user id is unique within all keys. Bear in mind roles keys are optional, + if missing, it means that we do not want to do any changes there. + + Deprecated. Used only in legacy PUT /v1/project endpoint for project access replacement. :Example: >>> parse_project_access_update_request({"writersnames": ["john"], "readersnames": ["john, jack, bob.inactive"]}) - {"writers": [1], "readers": [1,2], "invalid_usernames": ["bob.inactive"], "invalid_ids":[]} + {"ProjectRole.WRITER": [1], "ProjectRole.READER": [2], "invalid_usernames": ["bob.inactive"], "invalid_ids":[]} >>> parse_project_access_update_request({"writers": [1], "readers": [1,2,3]}) - {"writers": [1], "readers": [1,2], "invalid_usernames": [], "invalid_ids":[3]"} + {"ProjectRole.WRITER": [1], "ProjectRole.READER": [2], "invalid_usernames": [], "invalid_ids":[3]"} """ + resp = {} parsed_access = {} names = set( access.get("ownersnames", []) @@ -137,9 +146,23 @@ def parse_project_access_update_request(access: Dict) -> Dict: # use legacy option elif key in access: parsed_access[key] = [id for id in access.get(key) if id in valid_ids] - parsed_access["invalid_usernames"] = list(names.difference(valid_usernames)) - parsed_access["invalid_ids"] = list(ids.difference(valid_ids)) - return parsed_access + + # remove 'inheritance', prepare final map for direct assignments + processed_ids = [] + for key in ("owners", "writers", "editors", "readers"): + # we might not want to modify all roles + if key not in parsed_access: + continue + role = ProjectRole(key[:-1]) + resp[role] = [] + for user_id in parsed_access.get(key): + if user_id not in processed_ids: + resp[role].append(user_id) + processed_ids.append(user_id) + + resp["invalid_usernames"] = list(names.difference(valid_usernames)) + resp["invalid_ids"] = list(ids.difference(valid_ids)) + return resp @auth_required @@ -501,15 +524,13 @@ def get_projects_by_names(): # noqa: E501 Project.workspace_id == workspace.id, Project.name == name ).first() if result: - # FIXME - # user_ids = ( - # result.access.owners + result.access.writers + result.access.readers - # ) - # users_map = { - # u.id: u.username - # for u in User.query.filter(User.id.in_(set(user_ids))).all() - # } - users_map = None + users_map = { + u.id: u.username + for u in User.query.select_from(ProjectUser) + .join(User) + .filter(ProjectUser.project_id == result.id) + .all() + } workspaces_map = {workspace.id: workspace.name} ctx = {"users_map": users_map, "workspaces_map": workspaces_map} results[project] = ProjectListSchema(context=ctx).dump(result) @@ -535,19 +556,19 @@ def get_projects_by_uuids(uuids): # noqa: E501 if len(proj_ids) > 10: abort(400, "Too many projects") - user_ids = [] - ws_ids = [] projects = ( projects_query(ProjectPermissions.Read, as_admin=False) .filter(Project.id.in_(proj_ids)) .all() ) - for p in projects: - # FIXME - # user_ids.extend(p.access.owners + p.access.writers + p.access.readers) - ws_ids.append(p.workspace_id) + ws_ids = set([p.workspace_id for p in projects]) + projects_ids = [p.id for p in projects] users_map = { - u.id: u.username for u in User.query.filter(User.id.in_(set(user_ids))).all() + u.id: u.username + for u in User.query.select_from(ProjectUser) + .join(User) + .filter(ProjectUser.project_id.in_(projects_ids)) + .all() } workspaces_map = {w.id: w.name for w in current_app.ws_handler.get_by_ids(ws_ids)} ctx = {"users_map": users_map, "workspaces_map": workspaces_map} @@ -622,17 +643,16 @@ def get_paginated_projects( only_public, ) result = projects.paginate(page, per_page).items - total = projects.paginate(page, per_page).total + total = projects.paginate().total # create user map id:username passed to project schema to minimize queries to db - user_ids = [] - for p in result: - # FIXME - # user_ids.extend(p.access.owners + p.access.writers + p.access.readers) - pass - + projects_ids = [p.id for p in result] users_map = { - u.id: u.username for u in User.query.filter(User.id.in_(set(user_ids))).all() + u.id: u.username + for u in User.query.select_from(ProjectUser) + .join(User) + .filter(ProjectUser.project_id.in_(projects_ids)) + .all() } ws_ids = [p.workspace_id for p in projects] workspaces_map = {w.id: w.name for w in current_app.ws_handler.get_by_ids(ws_ids)} @@ -659,9 +679,11 @@ def update_project(namespace, project_name): # noqa: E501 # pylint: disable=W0 :rtype: ProjectDetail """ project = require_project(namespace, project_name, ProjectPermissions.Update) - access = request.json.get("access", {}) - - id_diffs, error = current_app.ws_handler.update_project_members(project, access) + parsed_access = parse_project_access_update_request(request.json.get("access", {})) + # get set of modified user_ids and possible (custom) errors + id_diffs, error = current_app.ws_handler.update_project_members( + project, parsed_access + ) if not id_diffs and error: # nothing was done but there are errors diff --git a/server/mergin/sync/schemas.py b/server/mergin/sync/schemas.py index 3c480ad2..5ad0629e 100644 --- a/server/mergin/sync/schemas.py +++ b/server/mergin/sync/schemas.py @@ -9,16 +9,25 @@ from .files import ProjectFileSchema, FileSchema from .permissions import ProjectPermissions -from .models import Project, ProjectVersion, AccessRequest, FileHistory, PushChangeType +from .models import ( + Project, + ProjectVersion, + AccessRequest, + FileHistory, + PushChangeType, + ProjectRole, +) from ..app import DateTimeWithZ, ma from ..auth.models import User class ProjectAccessSchema(ma.SQLAlchemyAutoSchema): - owners = fields.List(fields.Integer()) - writers = fields.List(fields.Integer()) - editors = fields.List(fields.Integer()) - readers = fields.List(fields.Integer()) + """Schema for legacy response with user arrays""" + + owners = fields.Function(lambda obj: obj.members_by_role(ProjectRole.OWNER)) + writers = fields.Function(lambda obj: obj.members_by_role(ProjectRole.WRITER)) + editors = fields.Function(lambda obj: obj.members_by_role(ProjectRole.EDITOR)) + readers = fields.Function(lambda obj: obj.members_by_role(ProjectRole.READER)) public = fields.Boolean() @post_dump @@ -106,7 +115,7 @@ class ProjectSchemaForVersion(ma.SQLAlchemyAutoSchema): uploads = fields.Method("_uploads") name = fields.Function(lambda obj: obj.project.name) namespace = fields.Function(lambda obj: obj.project.workspace.name) - access = fields.Method("_access") + access = fields.Function(lambda obj: ProjectAccessSchema().dump(obj.project)) permissions = fields.Method("_permissions") disk_usage = fields.Method("_disk_usage") files = fields.Nested(ProjectFileSchema(), many=True) @@ -124,9 +133,6 @@ def _role(self, obj): def _uploads(self, obj): return [u.id for u in obj.project.uploads.all()] - def _access(self, obj): - return ProjectAccessSchema().dump(obj.project) - def _permissions(self, obj): return project_user_permissions(obj.project) @@ -156,7 +162,7 @@ class Meta: class ProjectSchema(ma.SQLAlchemyAutoSchema): id = fields.UUID() files = fields.Nested(ProjectFileSchema(), many=True) - access = fields.Nested(ProjectAccessSchema()) + access = fields.Function(lambda obj: ProjectAccessSchema().dump(obj)) permissions = fields.Function(project_user_permissions) version = fields.Function(lambda obj: ProjectVersion.to_v_name(obj.latest_version)) namespace = fields.Function(lambda obj: obj.workspace.name) @@ -185,7 +191,7 @@ class ProjectListSchema(ma.SQLAlchemyAutoSchema): id = fields.UUID() name = fields.Str() namespace = fields.Method("get_workspace_name") - access = fields.Nested(ProjectAccessSchema()) + access = fields.Function(lambda obj: ProjectAccessSchema().dump(obj)) permissions = fields.Function(project_user_permissions) version = fields.Function(lambda obj: ProjectVersion.to_v_name(obj.latest_version)) updated = fields.Method("get_updated") diff --git a/server/mergin/sync/workspace.py b/server/mergin/sync/workspace.py index b587004e..be0d6ac5 100644 --- a/server/mergin/sync/workspace.py +++ b/server/mergin/sync/workspace.py @@ -5,8 +5,7 @@ from datetime import datetime, timedelta, timezone from typing import Dict, Tuple, Optional, Set, List from flask_login import current_user -from sqlalchemy import or_, and_, Column, literal, extract -from sqlalchemy.orm import joinedload +from sqlalchemy import Column, literal, extract from .errors import UpdateProjectAccessError from .models import ( @@ -17,7 +16,6 @@ ProjectUser, ) from .permissions import projects_query, ProjectPermissions -from .public_api_controller import parse_project_access_update_request from ..app import db from ..auth.models import User from ..config import Configuration @@ -166,8 +164,7 @@ def filter_projects( ): if only_public: projects = ( - Project.query.join(ProjectUser) - .filter(Project.storage_params.isnot(None)) + Project.query.filter(Project.storage_params.isnot(None)) .filter(Project.removed_at.is_(None)) .filter(Project.public.is_(True)) ) @@ -183,23 +180,17 @@ def filter_projects( if flag == "created": projects = projects.filter(Project.creator_id == user.id) if flag == "shared": - # check global read permissions + projects = projects.filter(Project.creator_id != user.id) + # check global read permissions or direct project permissions if workspace.user_has_permissions(user, "read"): - read_access_workspace_id = workspace.id + projects = projects.filter(Project.workspace_id == workspace.id) else: - read_access_workspace_id = None - projects = projects.filter( - or_( - and_( - ProjectUser.user_id == user.id, - Project.creator_id != user.id, - ), - and_( - Project.workspace_id == read_access_workspace_id, - Project.creator_id != user.id, - ), + subquery = ( + db.session.query(ProjectUser.project_id) + .filter(ProjectUser.user_id == user.id) + .subquery() ) - ) + projects = projects.filter(Project.id.in_(subquery)) if name: projects = projects.filter(Project.name.ilike("%{}%".format(name))) @@ -288,15 +279,13 @@ def update_project_members( ) -> Tuple[Set[int], Optional[UpdateProjectAccessError]]: """Update project members doing bulk access update""" error = None - parsed_access = parse_project_access_update_request(access) - # FIXME - id_diffs = set() - # id_diffs = project.access.bulk_update(parsed_access) + id_diffs = project.bulk_roles_update(access) db.session.add(project) db.session.commit() - if parsed_access.get("invalid_usernames") or parsed_access.get("invalid_ids"): + + if access.get("invalid_usernames") or access.get("invalid_ids"): error = UpdateProjectAccessError( - parsed_access["invalid_usernames"], parsed_access["invalid_ids"] + access["invalid_usernames"], access["invalid_ids"] ) return id_diffs, error diff --git a/server/mergin/tests/test_project_controller.py b/server/mergin/tests/test_project_controller.py index c1587867..21387c19 100644 --- a/server/mergin/tests/test_project_controller.py +++ b/server/mergin/tests/test_project_controller.py @@ -18,7 +18,6 @@ import re from flask_login import current_user -from unittest.mock import patch from pygeodiff import GeoDiff from flask import url_for, current_app import tempfile @@ -37,7 +36,8 @@ ProjectFilePath, ) from ..sync.files import ChangesSchema -from ..sync.schemas import ProjectListSchema +from ..sync.permissions import projects_query +from ..sync.schemas import ProjectListSchema, ProjectSchema from ..sync.utils import generate_checksum, is_versioned_file from ..auth.models import User, UserProfile @@ -176,7 +176,7 @@ def test_get_paginated_projects(client): resp_data = json.loads(resp.data) assert len(resp_data.get("projects")) == 10 assert resp_data.get("count") == 15 - assert "foo8" in resp_data.get("projects")[9]["name"] + assert "foo7" in resp_data.get("projects")[9]["name"] assert "v0" == resp_data.get("projects")[9]["version"] resp = client.get( @@ -199,7 +199,7 @@ def test_get_paginated_projects(client): ) resp_data = json.loads(resp.data) assert len(resp_data.get("projects")) == 5 - assert "foo13" in resp_data.get("projects")[-1]["name"] + assert "foo12" in resp_data.get("projects")[-1]["name"] # tests backward compatibility sort resp_alt = client.get( "/v1/project/paginated?page=2&per_page=10&order_by=namespace&descending=false" @@ -296,7 +296,7 @@ def test_get_paginated_projects(client): Configuration.GLOBAL_READ = False Configuration.GLOBAL_WRITE = False Configuration.GLOBAL_ADMIN = False - # add new user and let him create one project + # add new user and let him create one project (total 15+1) user2 = add_user("user2", "ilovemergin") assert not test_workspace.user_has_permissions(user2, "read") create_project("created", test_workspace, user2) @@ -317,6 +317,8 @@ def test_get_paginated_projects(client): # make user reader of all projects Configuration.GLOBAL_READ = True + resp = client.get("/v1/project/paginated?page=1&per_page=10&flag=created") + assert resp.json["count"] == 1 resp = client.get("/v1/project/paginated?page=1&per_page=20&flag=shared") resp_data = json.loads(resp.data) assert resp.status_code == 200 @@ -628,6 +630,7 @@ def test_update_project(client): project = Project.query.filter_by( name=test_project, workspace_id=test_workspace_id ).first() + creator = User.query.get(project.creator_id) # need for private project project.public = False db.session.add(project) @@ -641,16 +644,7 @@ def test_update_project(client): db.session.commit() # add tests user as reader to project - data = { - "access": { - "readers": [ - u.id - for u in project.project_users - if u.role == ProjectRole.READER.value - ] - + [test_user.id] - } - } + data = {"access": {"readers": [test_user.id]}} resp = client.put( "/v1/project/{}/{}".format(test_workspace_name, test_project), data=json.dumps(data), @@ -660,13 +654,7 @@ def test_update_project(client): assert project.get_role(test_user.id) is ProjectRole.READER # add tests user as writer to project - current_writers = [ - u.id for u in project.project_users if u.role == ProjectRole.WRITER.value - ] - writers = [ - u.username for u in User.query.filter(User.id.in_(current_writers)).all() - ] - data = {"access": {"writersnames": writers + [test_user.username]}} + data = {"access": {"writersnames": [test_user.username]}} resp = client.put( "/v1/project/{}/{}".format(test_workspace_name, test_project), data=json.dumps(data), @@ -674,6 +662,7 @@ def test_update_project(client): ) assert resp.status_code == 200 assert project.get_role(test_user.id) is ProjectRole.WRITER + assert project.get_role(creator.id) is ProjectRole.OWNER # try to remove project creator from owners data = {"access": {"owners": [test_user.id]}} @@ -683,15 +672,10 @@ def test_update_project(client): headers=json_headers, ) assert resp.status_code == 200 + assert not project.get_role(creator.id) # try to add non-existing user - current_readers = [ - u.id for u in project.project_users if u.role == ProjectRole.READER.value - ] - readers = [ - user.username for user in User.query.filter(User.id.in_(current_readers)).all() - ] - data = {"access": {"readersnames": readers + ["not-found-user"]}} + data = {"access": {"readersnames": ["not-found-user"]}} resp = client.put( f"/v1/project/{test_workspace_name}/{test_project}", data=json.dumps(data), @@ -703,18 +687,16 @@ def test_update_project(client): assert resp.json["invalid_usernames"] == ["not-found-user"] # try to add non-existing user plus make some valid update -> only partial success - current_readers = [ - u.id for u in project.project_users if u.role == ProjectRole.READER.value - ] - readers = [ - user.username for user in User.query.filter(User.id.in_(current_readers)).all() + current_users = [u.user_id for u in project.project_users] + usernames = [ + user.username for user in User.query.filter(User.id.in_(current_users)).all() ] data = { "access": { - "readersnames": readers + ["not-found-user"], - "editorsnames": readers, - "writersnames": readers, - "ownersnames": readers, + "readersnames": ["not-found-user"], + "editorsnames": usernames + [creator.username], + "writersnames": usernames + [creator.username], + "ownersnames": usernames, } } resp = client.put( @@ -724,6 +706,8 @@ def test_update_project(client): ) assert resp.status_code == 207 assert resp.json["code"] == "UpdateProjectAccessError" + assert project.get_role(test_user.id) is ProjectRole.OWNER + assert project.get_role(creator.id) is ProjectRole.WRITER # login as a new project owner and check permissions login(client, test_user.username, "tester")