From b989a2f7ef5f27a1dc624bb6136d05faa688bb60 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Sat, 7 Feb 2026 01:00:42 +0000 Subject: [PATCH 1/3] feat: authorisation, first design --- .../migrations/versions/add_authorization.py | 71 ++++++++ server/osa/application/api/rest/app.py | 16 +- server/osa/application/api/v1/errors.py | 7 + server/osa/application/api/v1/routes/admin.py | 95 ++++++++++ server/osa/application/api/v1/routes/auth.py | 11 +- server/osa/domain/auth/command/assign_role.py | 51 ++++++ server/osa/domain/auth/command/login.py | 6 + server/osa/domain/auth/command/revoke_role.py | 36 ++++ server/osa/domain/auth/command/token.py | 5 + server/osa/domain/auth/model/principal.py | 26 +++ server/osa/domain/auth/model/role.py | 17 ++ .../osa/domain/auth/model/role_assignment.py | 49 ++++++ .../osa/domain/auth/port/role_repository.py | 33 ++++ .../osa/domain/auth/query/get_user_roles.py | 56 ++++++ .../osa/domain/auth/service/authorization.py | 49 ++++++ server/osa/domain/auth/util/di/provider.py | 62 +++++++ .../osa/domain/deposition/command/create.py | 5 + .../domain/deposition/command/delete_files.py | 5 + .../osa/domain/deposition/command/submit.py | 5 + .../osa/domain/deposition/command/update.py | 5 + .../osa/domain/deposition/command/upload.py | 6 + .../osa/domain/deposition/model/aggregate.py | 2 + .../domain/shared/authorization/__init__.py | 0 .../osa/domain/shared/authorization/action.py | 57 ++++++ .../shared/authorization/authorized_repo.py | 40 +++++ .../domain/shared/authorization/guarded.py | 41 +++++ .../osa/domain/shared/authorization/policy.py | 83 +++++++++ .../domain/shared/authorization/policy_set.py | 164 ++++++++++++++++++ .../domain/shared/authorization/startup.py | 75 ++++++++ server/osa/domain/shared/command.py | 86 ++++++++- server/osa/domain/shared/query.py | 88 +++++++++- server/osa/infrastructure/auth/di.py | 7 + .../infrastructure/auth/role_repository.py | 73 ++++++++ .../persistence/mappers/deposition.py | 5 + .../osa/infrastructure/persistence/tables.py | 19 ++ .../tests/unit/domain/deposition/__init__.py | 0 .../test_deposition_service_auth.py | 131 ++++++++++++++ .../domain/shared/authorization/__init__.py | 0 .../shared/authorization/test_auth_gate.py | 141 +++++++++++++++ .../authorization/test_authorization_audit.py | 63 +++++++ .../shared/authorization/test_guarded.py | 64 +++++++ .../shared/authorization/test_policy.py | 90 ++++++++++ .../shared/authorization/test_policy_set.py | 113 ++++++++++++ .../authorization/test_role_hierarchy.py | 60 +++++++ .../authorization/test_startup_validation.py | 103 +++++++++++ 45 files changed, 2106 insertions(+), 15 deletions(-) create mode 100644 server/migrations/versions/add_authorization.py create mode 100644 server/osa/application/api/v1/routes/admin.py create mode 100644 server/osa/domain/auth/command/assign_role.py create mode 100644 server/osa/domain/auth/command/revoke_role.py create mode 100644 server/osa/domain/auth/model/principal.py create mode 100644 server/osa/domain/auth/model/role.py create mode 100644 server/osa/domain/auth/model/role_assignment.py create mode 100644 server/osa/domain/auth/port/role_repository.py create mode 100644 server/osa/domain/auth/query/get_user_roles.py create mode 100644 server/osa/domain/auth/service/authorization.py create mode 100644 server/osa/domain/shared/authorization/__init__.py create mode 100644 server/osa/domain/shared/authorization/action.py create mode 100644 server/osa/domain/shared/authorization/authorized_repo.py create mode 100644 server/osa/domain/shared/authorization/guarded.py create mode 100644 server/osa/domain/shared/authorization/policy.py create mode 100644 server/osa/domain/shared/authorization/policy_set.py create mode 100644 server/osa/domain/shared/authorization/startup.py create mode 100644 server/osa/infrastructure/auth/role_repository.py create mode 100644 server/tests/unit/domain/deposition/__init__.py create mode 100644 server/tests/unit/domain/deposition/test_deposition_service_auth.py create mode 100644 server/tests/unit/domain/shared/authorization/__init__.py create mode 100644 server/tests/unit/domain/shared/authorization/test_auth_gate.py create mode 100644 server/tests/unit/domain/shared/authorization/test_authorization_audit.py create mode 100644 server/tests/unit/domain/shared/authorization/test_guarded.py create mode 100644 server/tests/unit/domain/shared/authorization/test_policy.py create mode 100644 server/tests/unit/domain/shared/authorization/test_policy_set.py create mode 100644 server/tests/unit/domain/shared/authorization/test_role_hierarchy.py create mode 100644 server/tests/unit/domain/shared/authorization/test_startup_validation.py diff --git a/server/migrations/versions/add_authorization.py b/server/migrations/versions/add_authorization.py new file mode 100644 index 0000000..0a9977b --- /dev/null +++ b/server/migrations/versions/add_authorization.py @@ -0,0 +1,71 @@ +"""add_authorization + +Add role_assignments table and owner_id column to depositions. + +Revision ID: add_authorization +Revises: add_auth_tables +Create Date: 2026-02-06 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "add_authorization" +down_revision: Union[str, Sequence[str], None] = "add_auth_tables" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add authorization tables and columns.""" + # ROLE ASSIGNMENTS TABLE + op.create_table( + "role_assignments", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("role", sa.String(32), nullable=False), + sa.Column("assigned_by", sa.String(), nullable=False), + sa.Column("assigned_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["assigned_by"], + ["users.id"], + ), + sa.UniqueConstraint("user_id", "role", name="uq_role_assignments_user_role"), + ) + op.create_index("ix_role_assignments_user_id", "role_assignments", ["user_id"]) + + # ADD owner_id TO DEPOSITIONS (nullable initially for existing data) + op.add_column( + "depositions", + sa.Column("owner_id", sa.String(), nullable=True), + ) + op.create_foreign_key( + "fk_depositions_owner_id", + "depositions", + "users", + ["owner_id"], + ["id"], + ) + op.create_index("idx_depositions_owner_id", "depositions", ["owner_id"]) + + +def downgrade() -> None: + """Remove authorization tables and columns.""" + # DEPOSITIONS owner_id + op.drop_index("idx_depositions_owner_id", table_name="depositions") + op.drop_constraint("fk_depositions_owner_id", "depositions", type_="foreignkey") + op.drop_column("depositions", "owner_id") + + # ROLE ASSIGNMENTS + op.drop_index("ix_role_assignments_user_id", table_name="role_assignments") + op.drop_table("role_assignments") diff --git a/server/osa/application/api/rest/app.py b/server/osa/application/api/rest/app.py index 735e836..e9d6fef 100644 --- a/server/osa/application/api/rest/app.py +++ b/server/osa/application/api/rest/app.py @@ -6,9 +6,19 @@ from fastapi.responses import JSONResponse from osa.application.api.v1.errors import map_osa_error -from osa.application.api.v1.routes import auth, events, health, records, search, stats, validation +from osa.application.api.v1.routes import ( + admin, + auth, + events, + health, + records, + search, + stats, + validation, +) from osa.application.di import create_container from osa.config import Config, configure_logging +from osa.domain.shared.authorization.startup import validate_all_handlers from osa.domain.shared.error import OSAError from osa.infrastructure.event.worker import WorkerPool from osa.infrastructure.source.discovery import validate_sources_at_startup @@ -42,6 +52,9 @@ def create_app() -> FastAPI: # Validate source configs at startup (fail fast with clear errors) validate_sources_at_startup(config.sources) + # Validate all handlers have authorization declarations (fail fast) + validate_all_handlers() + app_instance = FastAPI( title=config.server.name, description=config.server.description, @@ -59,6 +72,7 @@ def create_app() -> FastAPI: # Register v1 routes with /api/v1 prefix app_instance.include_router(health.router, prefix="/api/v1") + app_instance.include_router(admin.router, prefix="/api/v1") app_instance.include_router(auth.router, prefix="/api/v1") app_instance.include_router(events.router, prefix="/api/v1") app_instance.include_router(records.router, prefix="/api/v1") diff --git a/server/osa/application/api/v1/errors.py b/server/osa/application/api/v1/errors.py index 00a31ec..8a9d4cf 100644 --- a/server/osa/application/api/v1/errors.py +++ b/server/osa/application/api/v1/errors.py @@ -49,6 +49,13 @@ def map_osa_error(error: OSAError) -> HTTPException: status_code = DOMAIN_ERROR_STATUS_MAP.get(type(error), 400) if isinstance(error, ValidationError) and error.field is not None: detail["field"] = error.field + # Distinguish 401 (unauthenticated) from 403 (unauthorized) + if isinstance(error, AuthorizationError) and error.code == "missing_token": + return HTTPException( + status_code=401, + detail=detail, + headers={"WWW-Authenticate": "Bearer"}, + ) return HTTPException(status_code=status_code, detail=detail) # Fallback for unknown OSAError subclasses diff --git a/server/osa/application/api/v1/routes/admin.py b/server/osa/application/api/v1/routes/admin.py new file mode 100644 index 0000000..66c112a --- /dev/null +++ b/server/osa/application/api/v1/routes/admin.py @@ -0,0 +1,95 @@ +"""Admin routes for role management.""" + +from dishka.integrations.fastapi import DishkaRoute, FromDishka +from fastapi import APIRouter, Response +from pydantic import BaseModel + +from osa.domain.auth.command.assign_role import ( + AssignRole, + AssignRoleHandler, +) +from osa.domain.auth.command.revoke_role import ( + RevokeRole, + RevokeRoleHandler, +) +from osa.domain.auth.query.get_user_roles import ( + GetUserRoles, + GetUserRolesHandler, +) + +router = APIRouter(prefix="/admin", tags=["Admin"], route_class=DishkaRoute) + + +class AssignRoleRequest(BaseModel): + """Request body for assigning a role.""" + + role: str + + +class RoleAssignmentResponse(BaseModel): + """Response for a single role assignment.""" + + id: str + user_id: str + role: str + assigned_by: str + assigned_at: str + + +class RoleAssignmentListResponse(BaseModel): + """Response listing role assignments.""" + + roles: list[RoleAssignmentResponse] + + +@router.get("/users/{user_id}/roles", response_model=RoleAssignmentListResponse) +async def list_user_roles( + user_id: str, + handler: FromDishka[GetUserRolesHandler], +) -> RoleAssignmentListResponse: + """List all roles assigned to a user. Requires SuperAdmin role.""" + result = await handler.run(GetUserRoles(user_id=user_id)) + return RoleAssignmentListResponse( + roles=[ + RoleAssignmentResponse( + id=r.id, + user_id=r.user_id, + role=r.role, + assigned_by=r.assigned_by, + assigned_at=r.assigned_at.isoformat(), + ) + for r in result.roles + ] + ) + + +@router.post( + "/users/{user_id}/roles", + response_model=RoleAssignmentResponse, + status_code=201, +) +async def assign_role( + user_id: str, + body: AssignRoleRequest, + handler: FromDishka[AssignRoleHandler], +) -> RoleAssignmentResponse: + """Assign a role to a user. Requires SuperAdmin role.""" + result = await handler.run(AssignRole(user_id=user_id, role=body.role)) + return RoleAssignmentResponse( + id=result.id, + user_id=result.user_id, + role=result.role, + assigned_by=result.assigned_by, + assigned_at=result.assigned_at.isoformat(), + ) + + +@router.delete("/users/{user_id}/roles/{role}", status_code=204) +async def revoke_role( + user_id: str, + role: str, + handler: FromDishka[RevokeRoleHandler], +) -> Response: + """Revoke a role from a user. Requires SuperAdmin role.""" + await handler.run(RevokeRole(user_id=user_id, role=role)) + return Response(status_code=204) diff --git a/server/osa/application/api/v1/routes/auth.py b/server/osa/application/api/v1/routes/auth.py index de22b0d..d7de536 100644 --- a/server/osa/application/api/v1/routes/auth.py +++ b/server/osa/application/api/v1/routes/auth.py @@ -25,6 +25,7 @@ ) from osa.domain.auth.model.value import CurrentUser from osa.domain.auth.port.provider_registry import ProviderRegistry +from osa.domain.auth.port.role_repository import RoleAssignmentRepository from osa.domain.auth.service.auth import AuthService from osa.domain.auth.service.token import TokenService from osa.domain.shared.error import InvalidStateError @@ -62,12 +63,13 @@ class LogoutResponse(BaseModel): class UserResponse(BaseModel): - """Response containing user info.""" + """Response containing user info with roles.""" id: str display_name: str | None provider: str external_id: str + roles: list[str] @router.get("/login") @@ -259,8 +261,9 @@ async def logout( async def get_me( current_user: FromDishka[CurrentUser], auth_service: FromDishka[AuthService], + role_repo: FromDishka[RoleAssignmentRepository], ) -> UserResponse: - """Get current authenticated user information.""" + """Get current authenticated user information with roles.""" user = await auth_service.get_user_by_id(current_user.user_id) if user is None: @@ -269,9 +272,13 @@ async def get_me( detail={"code": "user_not_found", "message": "User not found"}, ) + assignments = await role_repo.get_by_user_id(current_user.user_id) + roles = [a.role.name.lower() for a in assignments] + return UserResponse( id=str(user.id), display_name=user.display_name, provider=current_user.identity.provider, external_id=current_user.identity.external_id, + roles=roles, ) diff --git a/server/osa/domain/auth/command/assign_role.py b/server/osa/domain/auth/command/assign_role.py new file mode 100644 index 0000000..7be4b43 --- /dev/null +++ b/server/osa/domain/auth/command/assign_role.py @@ -0,0 +1,51 @@ +"""AssignRole command and handler.""" + +from datetime import datetime +from uuid import UUID + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import UserId +from osa.domain.auth.service.authorization import AuthorizationService +from osa.domain.shared.authorization.policy import requires_role +from osa.domain.shared.command import Command, CommandHandler, Result + + +class AssignRole(Command): + """Command to assign a role to a user.""" + + user_id: str # UUID as string from API + role: str # Role name from API + + +class AssignRoleResult(Result): + """Result containing the created role assignment.""" + + id: str + user_id: str + role: str + assigned_by: str + assigned_at: datetime + + +class AssignRoleHandler(CommandHandler[AssignRole, AssignRoleResult]): + __auth__ = requires_role(Role.SUPERADMIN) + _principal: Principal | None = None + authorization_service: AuthorizationService + + async def run(self, cmd: AssignRole) -> AssignRoleResult: + assert self._principal is not None # Guaranteed by __auth__ gate + + assignment = await self.authorization_service.assign_role( + user_id=UserId(UUID(cmd.user_id)), + role=Role[cmd.role.upper()], + assigned_by=self._principal.user_id, + ) + + return AssignRoleResult( + id=str(assignment.id), + user_id=str(assignment.user_id), + role=assignment.role.name.lower(), + assigned_by=str(assignment.assigned_by), + assigned_at=assignment.assigned_at, + ) diff --git a/server/osa/domain/auth/command/login.py b/server/osa/domain/auth/command/login.py index a54d423..de46dc3 100644 --- a/server/osa/domain/auth/command/login.py +++ b/server/osa/domain/auth/command/login.py @@ -7,6 +7,8 @@ from osa.domain.auth.port.provider_registry import ProviderRegistry from osa.domain.auth.service.auth import AuthService from osa.domain.auth.service.token import TokenService +from typing import ClassVar + from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.error import NotFoundError from osa.domain.shared.event import EventId @@ -16,6 +18,8 @@ class InitiateLogin(Command): """Command to start OAuth login flow.""" + __public__: ClassVar[bool] = True + callback_url: str # OAuth callback URL (where IdP redirects after auth) final_redirect_uri: str # Where to redirect user after OAuth completes provider: str @@ -59,6 +63,8 @@ async def run(self, cmd: InitiateLogin) -> InitiateLoginResult: class CompleteOAuth(Command): """Command to complete OAuth flow with authorization code.""" + __public__: ClassVar[bool] = True + code: str callback_url: str # Must match the one used in authorization provider: str # The identity provider name (from verified state) diff --git a/server/osa/domain/auth/command/revoke_role.py b/server/osa/domain/auth/command/revoke_role.py new file mode 100644 index 0000000..4534bd7 --- /dev/null +++ b/server/osa/domain/auth/command/revoke_role.py @@ -0,0 +1,36 @@ +"""RevokeRole command and handler.""" + +from uuid import UUID + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import UserId +from osa.domain.auth.service.authorization import AuthorizationService +from osa.domain.shared.authorization.policy import requires_role +from osa.domain.shared.command import Command, CommandHandler, Result + + +class RevokeRole(Command): + """Command to revoke a role from a user.""" + + user_id: str # UUID as string from API + role: str # Role name from API + + +class RevokeRoleResult(Result): + """Empty result for successful revocation.""" + + pass + + +class RevokeRoleHandler(CommandHandler[RevokeRole, RevokeRoleResult]): + __auth__ = requires_role(Role.SUPERADMIN) + _principal: Principal | None = None + authorization_service: AuthorizationService + + async def run(self, cmd: RevokeRole) -> RevokeRoleResult: + await self.authorization_service.revoke_role( + user_id=UserId(UUID(cmd.user_id)), + role=Role[cmd.role.upper()], + ) + return RevokeRoleResult() diff --git a/server/osa/domain/auth/command/token.py b/server/osa/domain/auth/command/token.py index 876c800..dc0fd27 100644 --- a/server/osa/domain/auth/command/token.py +++ b/server/osa/domain/auth/command/token.py @@ -1,6 +1,7 @@ """Token commands for refresh and logout operations.""" from dataclasses import dataclass +from typing import ClassVar from uuid import uuid4 from osa.domain.auth.event import UserLoggedOut @@ -14,6 +15,8 @@ class RefreshTokens(Command): """Command to refresh access token using refresh token.""" + __public__: ClassVar[bool] = True + refresh_token: str @@ -48,6 +51,8 @@ async def run(self, cmd: RefreshTokens) -> RefreshTokensResult: class Logout(Command): """Command to logout and revoke refresh token family.""" + __public__: ClassVar[bool] = True + refresh_token: str diff --git a/server/osa/domain/auth/model/principal.py b/server/osa/domain/auth/model/principal.py new file mode 100644 index 0000000..c2bac92 --- /dev/null +++ b/server/osa/domain/auth/model/principal.py @@ -0,0 +1,26 @@ +"""Principal — authenticated identity with roles, resolved per-request.""" + +from dataclasses import dataclass + +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId + + +@dataclass(frozen=True) +class Principal: + """The authenticated identity of the current requester. + + Resolved per-request from JWT + role lookup. Immutable after creation. + """ + + user_id: UserId + identity: ProviderIdentity + roles: frozenset[Role] + + def has_role(self, role: Role) -> bool: + """Check if any assigned role >= the given role (hierarchy comparison).""" + return any(r >= role for r in self.roles) + + def has_any_role(self, *roles: Role) -> bool: + """Check if any assigned role satisfies any of the given roles.""" + return any(self.has_role(r) for r in roles) diff --git a/server/osa/domain/auth/model/role.py b/server/osa/domain/auth/model/role.py new file mode 100644 index 0000000..8c1e105 --- /dev/null +++ b/server/osa/domain/auth/model/role.py @@ -0,0 +1,17 @@ +"""Role hierarchy for authorization.""" + +from enum import IntEnum + + +class Role(IntEnum): + """Hierarchical roles with numeric ordering. + + Higher values inherit all permissions of lower values. + Gaps allow future role insertion without renumbering. + """ + + PUBLIC = 0 + DEPOSITOR = 10 + CURATOR = 20 + ADMIN = 30 + SUPERADMIN = 40 diff --git a/server/osa/domain/auth/model/role_assignment.py b/server/osa/domain/auth/model/role_assignment.py new file mode 100644 index 0000000..72f6971 --- /dev/null +++ b/server/osa/domain/auth/model/role_assignment.py @@ -0,0 +1,49 @@ +"""RoleAssignment entity — tracks user-role associations.""" + +from datetime import UTC, datetime +from uuid import UUID, uuid4 + +from pydantic import RootModel + +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import UserId +from osa.domain.shared.model.entity import Entity + + +class RoleAssignmentId(RootModel[UUID]): + """Unique identifier for a RoleAssignment.""" + + @classmethod + def generate(cls) -> "RoleAssignmentId": + return cls(uuid4()) + + def __str__(self) -> str: + return str(self.root) + + def __hash__(self) -> int: + return hash(self.root) + + +class RoleAssignment(Entity): + """Association between a user and a role, managed by superadmins.""" + + id: RoleAssignmentId + user_id: UserId + role: Role + assigned_by: UserId + assigned_at: datetime + + @classmethod + def create( + cls, + user_id: UserId, + role: Role, + assigned_by: UserId, + ) -> "RoleAssignment": + return cls( + id=RoleAssignmentId.generate(), + user_id=user_id, + role=role, + assigned_by=assigned_by, + assigned_at=datetime.now(UTC), + ) diff --git a/server/osa/domain/auth/port/role_repository.py b/server/osa/domain/auth/port/role_repository.py new file mode 100644 index 0000000..88b7eff --- /dev/null +++ b/server/osa/domain/auth/port/role_repository.py @@ -0,0 +1,33 @@ +"""Repository port for RoleAssignment persistence.""" + +from abc import abstractmethod +from typing import Protocol + +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.role_assignment import RoleAssignment +from osa.domain.auth.model.value import UserId +from osa.domain.shared.port import Port + + +class RoleAssignmentRepository(Port, Protocol): + """Repository for RoleAssignment entity persistence.""" + + @abstractmethod + async def get_by_user_id(self, user_id: UserId) -> list[RoleAssignment]: + """Get all role assignments for a user.""" + ... + + @abstractmethod + async def save(self, assignment: RoleAssignment) -> None: + """Save a role assignment.""" + ... + + @abstractmethod + async def delete(self, user_id: UserId, role: Role) -> bool: + """Delete a role assignment. Returns True if deleted, False if not found.""" + ... + + @abstractmethod + async def get(self, user_id: UserId, role: Role) -> RoleAssignment | None: + """Get a specific role assignment.""" + ... diff --git a/server/osa/domain/auth/query/get_user_roles.py b/server/osa/domain/auth/query/get_user_roles.py new file mode 100644 index 0000000..c51bd21 --- /dev/null +++ b/server/osa/domain/auth/query/get_user_roles.py @@ -0,0 +1,56 @@ +"""GetUserRoles query and handler.""" + +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import UserId +from osa.domain.auth.service.authorization import AuthorizationService +from osa.domain.shared.authorization.policy import requires_role +from osa.domain.shared.query import Query, QueryHandler +from osa.domain.shared.query import Result as QueryResult + + +class GetUserRoles(Query): + """Query to get all roles assigned to a user.""" + + user_id: str # UUID as string from API + + +class RoleAssignmentDTO(BaseModel): + id: str + user_id: str + role: str + assigned_by: str + assigned_at: datetime + + +class GetUserRolesResult(QueryResult): + roles: list[RoleAssignmentDTO] + + +class GetUserRolesHandler(QueryHandler[GetUserRoles, GetUserRolesResult]): + __auth__ = requires_role(Role.SUPERADMIN) + _principal: Principal | None = None + authorization_service: AuthorizationService + + async def run(self, cmd: GetUserRoles) -> GetUserRolesResult: + assignments = await self.authorization_service.list_roles( + user_id=UserId(UUID(cmd.user_id)), + ) + + return GetUserRolesResult( + roles=[ + RoleAssignmentDTO( + id=str(a.id), + user_id=str(a.user_id), + role=a.role.name.lower(), + assigned_by=str(a.assigned_by), + assigned_at=a.assigned_at, + ) + for a in assignments + ] + ) diff --git a/server/osa/domain/auth/service/authorization.py b/server/osa/domain/auth/service/authorization.py new file mode 100644 index 0000000..4bde1a8 --- /dev/null +++ b/server/osa/domain/auth/service/authorization.py @@ -0,0 +1,49 @@ +"""Authorization service — role assignment management.""" + +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.role_assignment import RoleAssignment +from osa.domain.auth.model.value import UserId +from osa.domain.auth.port.role_repository import RoleAssignmentRepository +from osa.domain.shared.error import ConflictError, NotFoundError +from osa.domain.shared.service import Service + + +class AuthorizationService(Service): + """Manages role assignments for users.""" + + _role_repo: RoleAssignmentRepository + + async def assign_role( + self, + user_id: UserId, + role: Role, + assigned_by: UserId, + ) -> RoleAssignment: + """Assign a role to a user. Raises ConflictError if already assigned.""" + existing = await self._role_repo.get(user_id, role) + if existing is not None: + raise ConflictError( + f"Role {role.name} already assigned to user {user_id}", + code="role_already_assigned", + ) + + assignment = RoleAssignment.create( + user_id=user_id, + role=role, + assigned_by=assigned_by, + ) + await self._role_repo.save(assignment) + return assignment + + async def revoke_role(self, user_id: UserId, role: Role) -> None: + """Revoke a role from a user. Raises NotFoundError if not assigned.""" + deleted = await self._role_repo.delete(user_id, role) + if not deleted: + raise NotFoundError( + f"Role {role.name} not assigned to user {user_id}", + code="role_not_found", + ) + + async def list_roles(self, user_id: UserId) -> list[RoleAssignment]: + """List all role assignments for a user.""" + return await self._role_repo.get_by_user_id(user_id) diff --git a/server/osa/domain/auth/util/di/provider.py b/server/osa/domain/auth/util/di/provider.py index 8cce9b2..846d746 100644 --- a/server/osa/domain/auth/util/di/provider.py +++ b/server/osa/domain/auth/util/di/provider.py @@ -1,5 +1,6 @@ """DI provider for auth domain.""" +import logging from uuid import UUID import jwt @@ -8,23 +9,32 @@ from starlette.requests import Request from osa.config import Config +from osa.domain.auth.command.assign_role import AssignRoleHandler from osa.domain.auth.command.login import ( CompleteOAuthHandler, InitiateLoginHandler, ) +from osa.domain.auth.command.revoke_role import RevokeRoleHandler from osa.domain.auth.command.token import LogoutHandler, RefreshTokensHandler +from osa.domain.auth.query.get_user_roles import GetUserRolesHandler +from osa.domain.auth.model.principal import Principal from osa.domain.auth.model.value import CurrentUser, ProviderIdentity, UserId from osa.domain.auth.port.repository import ( IdentityRepository, RefreshTokenRepository, UserRepository, ) +from osa.domain.auth.port.role_repository import RoleAssignmentRepository from osa.domain.auth.service.auth import AuthService +from osa.domain.auth.service.authorization import AuthorizationService from osa.domain.auth.service.token import TokenService +from osa.domain.shared.authorization.policy_set import POLICY_SET, PolicySet from osa.domain.shared.outbox import Outbox from osa.util.di.base import Provider from osa.util.di.scope import Scope +logger = logging.getLogger(__name__) + class AuthProvider(Provider): """DI provider for auth domain services and handlers.""" @@ -36,6 +46,14 @@ class AuthProvider(Provider): complete_oauth_handler = provide(CompleteOAuthHandler, scope=Scope.UOW) refresh_tokens_handler = provide(RefreshTokensHandler, scope=Scope.UOW) logout_handler = provide(LogoutHandler, scope=Scope.UOW) + assign_role_handler = provide(AssignRoleHandler, scope=Scope.UOW) + revoke_role_handler = provide(RevokeRoleHandler, scope=Scope.UOW) + + # Query Handlers + get_user_roles_handler = provide(GetUserRolesHandler, scope=Scope.UOW) + + # Services + authorization_service = provide(AuthorizationService, scope=Scope.UOW) @provide(scope=Scope.UOW) def get_token_service(self, config: Config) -> TokenService: @@ -102,3 +120,47 @@ def get_current_user( detail={"code": "invalid_token", "message": "Invalid token"}, headers={"WWW-Authenticate": "Bearer"}, ) from e + + @provide(scope=Scope.UOW) + async def get_principal( + self, + request: Request, + token_service: TokenService, + role_repo: RoleAssignmentRepository, + ) -> Principal | None: + """Resolve Principal from JWT + role lookup. + + Returns None for anonymous requests (no JWT / invalid JWT). + This allows public endpoints to work without authentication. + """ + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + return None + + token = auth_header[7:] # Remove "Bearer " prefix + + try: + payload = token_service.validate_access_token(token) + except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): + return None + + user_id = UserId(UUID(payload["sub"])) + + # Lookup roles from DB + assignments = await role_repo.get_by_user_id(user_id) + roles = frozenset(a.role for a in assignments) + + return Principal( + user_id=user_id, + identity=ProviderIdentity( + provider=payload["provider"], + external_id=payload["external_id"], + ), + roles=roles, + ) + + @provide(scope=Scope.APP) + def get_policy_set(self) -> PolicySet: + """Provide the global PolicySet singleton. Validates coverage at startup.""" + POLICY_SET.validate_coverage() + return POLICY_SET diff --git a/server/osa/domain/deposition/command/create.py b/server/osa/domain/deposition/command/create.py index 0fd73e1..415d67c 100644 --- a/server/osa/domain/deposition/command/create.py +++ b/server/osa/domain/deposition/command/create.py @@ -2,7 +2,10 @@ import logfire +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role from osa.domain.deposition.service.deposition import DepositionService +from osa.domain.shared.authorization.policy import requires_role from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.model.srn import DepositionSRN @@ -16,6 +19,8 @@ class DepositionCreated(Result): class CreateDepositionHandler(CommandHandler[CreateDeposition, DepositionCreated]): + __auth__ = requires_role(Role.DEPOSITOR) + _principal: Principal | None = None deposition_service: DepositionService async def run(self, cmd: CreateDeposition) -> DepositionCreated: diff --git a/server/osa/domain/deposition/command/delete_files.py b/server/osa/domain/deposition/command/delete_files.py index 675c90f..87be419 100644 --- a/server/osa/domain/deposition/command/delete_files.py +++ b/server/osa/domain/deposition/command/delete_files.py @@ -1,6 +1,9 @@ import logfire +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role from osa.domain.deposition.port import DepositionRepository, StoragePort +from osa.domain.shared.authorization.policy import requires_role from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.model.srn import DepositionSRN @@ -14,6 +17,8 @@ class DepositionFilesDeleted(Result): class DeleteDepositionFilesHandler(CommandHandler[DeleteDepositionFiles, DepositionFilesDeleted]): + __auth__ = requires_role(Role.DEPOSITOR) + _principal: Principal | None = None repository: DepositionRepository storage: StoragePort diff --git a/server/osa/domain/deposition/command/submit.py b/server/osa/domain/deposition/command/submit.py index d7412f4..8a9f9d6 100644 --- a/server/osa/domain/deposition/command/submit.py +++ b/server/osa/domain/deposition/command/submit.py @@ -1,8 +1,11 @@ import logfire from uuid import uuid4 +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role from osa.domain.deposition.event.submitted import DepositionSubmittedEvent from osa.domain.deposition.service.deposition import DepositionService +from osa.domain.shared.authorization.policy import requires_role from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.event import EventId from osa.domain.shared.model.srn import DepositionSRN @@ -18,6 +21,8 @@ class DepositionSubmitted(Result): class SubmitDepositionHandler(CommandHandler[SubmitDeposition, DepositionSubmitted]): + __auth__ = requires_role(Role.DEPOSITOR) + _principal: Principal | None = None deposition_service: DepositionService outbox: Outbox diff --git a/server/osa/domain/deposition/command/update.py b/server/osa/domain/deposition/command/update.py index 680afff..7316cf3 100644 --- a/server/osa/domain/deposition/command/update.py +++ b/server/osa/domain/deposition/command/update.py @@ -1,6 +1,9 @@ import logfire +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role from osa.domain.deposition.service.deposition import DepositionService +from osa.domain.shared.authorization.policy import requires_role from osa.domain.shared.command import Command, CommandHandler, Result @@ -11,6 +14,8 @@ class DepositionUpdated(Result): ... class UpdateDepositionHandler(CommandHandler[UpdateDeposition, DepositionUpdated]): + __auth__ = requires_role(Role.DEPOSITOR) + _principal: Principal | None = None deposition_service: DepositionService async def run(self, cmd: UpdateDeposition) -> DepositionUpdated: diff --git a/server/osa/domain/deposition/command/upload.py b/server/osa/domain/deposition/command/upload.py index f0d0712..be29b9e 100644 --- a/server/osa/domain/deposition/command/upload.py +++ b/server/osa/domain/deposition/command/upload.py @@ -2,6 +2,9 @@ import logfire +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.shared.authorization.policy import requires_role from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.model.srn import DepositionSRN @@ -17,6 +20,9 @@ class FileUploaded(Result): class UploadFileHandler(CommandHandler[UploadFile, FileUploaded]): + __auth__ = requires_role(Role.DEPOSITOR) + _principal: Principal | None = None + async def run(self, cmd: UploadFile) -> FileUploaded: with logfire.span("UploadFile"): # TODO: Implement actual file storage logic diff --git a/server/osa/domain/deposition/model/aggregate.py b/server/osa/domain/deposition/model/aggregate.py index 1cae2d9..32834fc 100644 --- a/server/osa/domain/deposition/model/aggregate.py +++ b/server/osa/domain/deposition/model/aggregate.py @@ -1,5 +1,6 @@ from typing import Any, Generic, TypeVar +from osa.domain.auth.model.value import UserId from osa.domain.deposition.model.value import DepositionFile, DepositionStatus from osa.domain.shared.model.aggregate import Aggregate from osa.domain.shared.model.srn import DepositionSRN, RecordSRN @@ -14,6 +15,7 @@ class Deposition(Aggregate, Generic[T]): files: list[DepositionFile] = [] record_srn: RecordSRN | None = None provenance: dict[str, Any] = {} # Source info, provenance tracking + owner_id: UserId | None = None def remove_all_files(self) -> None: self.files = [] diff --git a/server/osa/domain/shared/authorization/__init__.py b/server/osa/domain/shared/authorization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/osa/domain/shared/authorization/action.py b/server/osa/domain/shared/authorization/action.py new file mode 100644 index 0000000..6b04841 --- /dev/null +++ b/server/osa/domain/shared/authorization/action.py @@ -0,0 +1,57 @@ +"""Authorization actions — all operations subject to access control.""" + +from enum import StrEnum + + +class Action(StrEnum): + """Structured enum of all authorization-relevant operations.""" + + # Depositions + DEPOSITION_CREATE = "deposition:create" + DEPOSITION_READ = "deposition:read" + DEPOSITION_UPDATE = "deposition:update" + DEPOSITION_SUBMIT = "deposition:submit" + DEPOSITION_DELETE = "deposition:delete" + + # Curation + DEPOSITION_APPROVE = "deposition:approve" + DEPOSITION_REJECT = "deposition:reject" + + # Registry — Schemas + SCHEMA_READ = "schema:read" + SCHEMA_CREATE = "schema:create" + SCHEMA_UPDATE = "schema:update" + SCHEMA_DELETE = "schema:delete" + + # Registry — Traits + TRAIT_READ = "trait:read" + TRAIT_CREATE = "trait:create" + TRAIT_UPDATE = "trait:update" + TRAIT_DELETE = "trait:delete" + + # Registry — Conventions + CONVENTION_READ = "convention:read" + CONVENTION_CREATE = "convention:create" + CONVENTION_UPDATE = "convention:update" + CONVENTION_DELETE = "convention:delete" + + # Registry — Vocabularies + VOCABULARY_READ = "vocabulary:read" + VOCABULARY_CREATE = "vocabulary:create" + VOCABULARY_UPDATE = "vocabulary:update" + VOCABULARY_DELETE = "vocabulary:delete" + + # Records (read-only after publication) + RECORD_READ = "record:read" + + # Search + SEARCH_QUERY = "search:query" + + # Validation + VALIDATION_CREATE = "validation:create" + VALIDATION_READ = "validation:read" + + # Administration + ROLE_ASSIGN = "role:assign" + ROLE_REVOKE = "role:revoke" + ROLE_READ = "role:read" diff --git a/server/osa/domain/shared/authorization/authorized_repo.py b/server/osa/domain/shared/authorization/authorized_repo.py new file mode 100644 index 0000000..12a1025 --- /dev/null +++ b/server/osa/domain/shared/authorization/authorized_repo.py @@ -0,0 +1,40 @@ +"""AuthorizedRepo — wraps a raw repository, returns Guarded[T] from get().""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Generic, TypeVar + +from osa.domain.shared.authorization.guarded import Guarded +from osa.domain.shared.error import NotFoundError + +if TYPE_CHECKING: + from osa.domain.auth.model.principal import Principal + from osa.domain.shared.authorization.policy_set import PolicySet + +T = TypeVar("T") +ID = TypeVar("ID") + + +class AuthorizedRepo(Generic[T, ID]): + """Wraps a raw repository and returns Guarded[T] from get(). + + Used by services that need to enforce authorization on loaded resources. + Event handlers and background workers should use the raw repository directly. + """ + + def __init__( + self, + inner: Any, + principal: "Principal", + policy_set: "PolicySet", + ) -> None: + self._inner = inner + self._principal = principal + self._policy_set = policy_set + + async def get(self, id: ID) -> Guarded[T]: + """Load a resource and wrap it in Guarded[T].""" + resource = await self._inner.get(id) + if resource is None: + raise NotFoundError(f"Resource not found: {id}") + return Guarded(resource, self._principal, self._policy_set) diff --git a/server/osa/domain/shared/authorization/guarded.py b/server/osa/domain/shared/authorization/guarded.py new file mode 100644 index 0000000..06f87d6 --- /dev/null +++ b/server/osa/domain/shared/authorization/guarded.py @@ -0,0 +1,41 @@ +"""Guarded[T] — generic wrapper forcing explicit authorization check.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +from osa.domain.shared.authorization.action import Action + +if TYPE_CHECKING: + from osa.domain.auth.model.principal import Principal + from osa.domain.shared.authorization.policy_set import PolicySet + +T = TypeVar("T") + + +class Guarded(Generic[T]): + """Wraps a loaded domain resource, forcing an explicit authorization check. + + The ONLY way to access the inner resource is via `.check(action)`. + No attribute proxy — accessing attributes on Guarded raises AttributeError. + """ + + __slots__ = ("_resource", "_principal", "_policy_set") + + def __init__( + self, + resource: T, + principal: Principal, + policy_set: PolicySet, + ) -> None: + self._resource = resource + self._principal = principal + self._policy_set = policy_set + + def check(self, action: Action) -> T: + """Evaluate authorization and return the unwrapped resource. + + Raises AuthorizationError if access is denied. + """ + self._policy_set.guard(self._principal, action, self._resource) + return self._resource diff --git a/server/osa/domain/shared/authorization/policy.py b/server/osa/domain/shared/authorization/policy.py new file mode 100644 index 0000000..946e8cd --- /dev/null +++ b/server/osa/domain/shared/authorization/policy.py @@ -0,0 +1,83 @@ +"""Composable policy types for handler-level authorization gates.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from osa.domain.auth.model.principal import Principal + from osa.domain.auth.model.role import Role + + +class Policy(ABC): + """Base class for composable authorization policies. + + Policies are evaluated at the handler level as a coarse pre-filter + (role check only, no resource loaded yet). + """ + + @abstractmethod + def evaluate(self, principal: "Principal") -> bool: + """Return True if principal satisfies this policy.""" + ... + + def __and__(self, other: Policy) -> AllOf: + return AllOf(policies=(self, other)) + + def __or__(self, other: Policy) -> AnyOf: + return AnyOf(policies=(self, other)) + + def __invert__(self) -> Not: + return Not(policy=self) + + +@dataclass(frozen=True) +class RequiresRole(Policy): + """Policy that checks principal has at least the given role (hierarchy).""" + + role: "Role" + + def evaluate(self, principal: "Principal") -> bool: + return principal.has_role(self.role) + + +@dataclass(frozen=True) +class AllOf(Policy): + """Policy that requires ALL sub-policies to pass.""" + + policies: tuple[Policy, ...] + + def evaluate(self, principal: "Principal") -> bool: + return all(p.evaluate(principal) for p in self.policies) + + +@dataclass(frozen=True) +class AnyOf(Policy): + """Policy that requires at least ONE sub-policy to pass.""" + + policies: tuple[Policy, ...] + + def evaluate(self, principal: "Principal") -> bool: + return any(p.evaluate(principal) for p in self.policies) + + +@dataclass(frozen=True) +class Not(Policy): + """Policy that inverts another policy.""" + + policy: Policy + + def evaluate(self, principal: "Principal") -> bool: + return not self.policy.evaluate(principal) + + +def requires_role(role: "Role") -> RequiresRole: + """Factory: policy requiring at least the given role.""" + return RequiresRole(role=role) + + +def requires_any_role(*roles: "Role") -> AnyOf: + """Factory: policy requiring at least one of the given roles.""" + return AnyOf(policies=tuple(RequiresRole(role=r) for r in roles)) diff --git a/server/osa/domain/shared/authorization/policy_set.py b/server/osa/domain/shared/authorization/policy_set.py new file mode 100644 index 0000000..bb6e7bc --- /dev/null +++ b/server/osa/domain/shared/authorization/policy_set.py @@ -0,0 +1,164 @@ +"""PolicySet — declarative authorization rules and the Relationship enum. + +Contains PolicyRule, Relationship, allow() constructor, and the POLICY_SET constant. +This is the single source of truth for all "who can do what on which resource" rules. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from enum import StrEnum +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from osa.domain.auth.model.principal import Principal + +from osa.domain.shared.authorization.action import Action + +logger = logging.getLogger(__name__) + + +class Relationship(StrEnum): + """Relationships between a principal and a resource.""" + + OWNER = "owner" + + +@dataclass(frozen=True) +class PolicyRule: + """A single authorization rule in the policy set.""" + + action: Action + role: "Role | None" = None + relationship: Relationship | None = None + + +def allow( + action: Action, + *, + role: "Role | None" = None, + relationship: Relationship | None = None, +) -> PolicyRule: + """Convenience constructor for a policy rule.""" + return PolicyRule(action=action, role=role, relationship=relationship) + + +# Import Role here (after PolicyRule is defined) to avoid circular imports +from osa.domain.auth.model.role import Role # noqa: E402 + + +class PolicySet: + """Declarative set of all authorization rules. + + Evaluation: for a given action, rules are tried in order. + First match wins (allow). No match means deny. + """ + + def __init__(self, rules: list[PolicyRule]) -> None: + self._rules = rules + self._by_action: dict[Action, list[PolicyRule]] = {} + for rule in rules: + self._by_action.setdefault(rule.action, []).append(rule) + + def guard( + self, + principal: "Principal | None", + action: Action, + resource: Any = None, + ) -> None: + """Raise AuthorizationError if no rule allows this access.""" + from osa.domain.shared.error import AuthorizationError + + principal_id = str(principal.user_id) if principal else "anonymous" + + rules = self._by_action.get(action, []) + for rule in rules: + if self._matches(rule, principal, resource): + logger.info( + "Authorization allowed: principal=%s action=%s", + principal_id, + action, + ) + return + + logger.warning( + "Authorization denied: principal=%s action=%s", + principal_id, + action, + ) + raise AuthorizationError(f"Access denied: {action}", code="access_denied") + + def _matches( + self, + rule: PolicyRule, + principal: "Principal | None", + resource: Any, + ) -> bool: + # Public rule (no role required) + if rule.role is None: + return True + # Must be authenticated + if principal is None: + return False + # Role hierarchy check + if not principal.has_role(rule.role): + return False + # Relationship check (if required) + if rule.relationship == Relationship.OWNER: + owner_id = getattr(resource, "owner_id", None) + if owner_id is None or owner_id != principal.user_id: + return False + return True + + def validate_coverage(self) -> None: + """Startup check: every Action enum member must have at least one rule.""" + from osa.domain.shared.error import ConfigurationError + + covered = {r.action for r in self._rules} + missing = set(Action) - covered + if missing: + raise ConfigurationError(f"Actions without policy rules: {missing}") + + +POLICY_SET = PolicySet( + [ + # Public reads (no auth required) + allow(Action.RECORD_READ), + allow(Action.SEARCH_QUERY), + allow(Action.SCHEMA_READ), + allow(Action.TRAIT_READ), + allow(Action.CONVENTION_READ), + allow(Action.VOCABULARY_READ), + allow(Action.VALIDATION_READ), + # Depositions (ownership-scoped) + allow(Action.DEPOSITION_CREATE, role=Role.DEPOSITOR), + allow(Action.DEPOSITION_READ, role=Role.DEPOSITOR, relationship=Relationship.OWNER), + allow(Action.DEPOSITION_UPDATE, role=Role.DEPOSITOR, relationship=Relationship.OWNER), + allow(Action.DEPOSITION_SUBMIT, role=Role.DEPOSITOR, relationship=Relationship.OWNER), + allow(Action.DEPOSITION_DELETE, role=Role.DEPOSITOR, relationship=Relationship.OWNER), + # Curators can read all depositions (no ownership required) + allow(Action.DEPOSITION_READ, role=Role.CURATOR), + allow(Action.DEPOSITION_APPROVE, role=Role.CURATOR), + allow(Action.DEPOSITION_REJECT, role=Role.CURATOR), + # Registry (admin-only writes) + allow(Action.SCHEMA_CREATE, role=Role.ADMIN), + allow(Action.SCHEMA_UPDATE, role=Role.ADMIN), + allow(Action.SCHEMA_DELETE, role=Role.ADMIN), + allow(Action.TRAIT_CREATE, role=Role.ADMIN), + allow(Action.TRAIT_UPDATE, role=Role.ADMIN), + allow(Action.TRAIT_DELETE, role=Role.ADMIN), + allow(Action.CONVENTION_CREATE, role=Role.ADMIN), + allow(Action.CONVENTION_UPDATE, role=Role.ADMIN), + allow(Action.CONVENTION_DELETE, role=Role.ADMIN), + allow(Action.VOCABULARY_CREATE, role=Role.ADMIN), + allow(Action.VOCABULARY_UPDATE, role=Role.ADMIN), + allow(Action.VOCABULARY_DELETE, role=Role.ADMIN), + # Validation + allow(Action.VALIDATION_CREATE, role=Role.DEPOSITOR), + # Administration (superadmin-only) + allow(Action.ROLE_ASSIGN, role=Role.SUPERADMIN), + allow(Action.ROLE_REVOKE, role=Role.SUPERADMIN), + allow(Action.ROLE_READ, role=Role.SUPERADMIN), + ] +) diff --git a/server/osa/domain/shared/authorization/startup.py b/server/osa/domain/shared/authorization/startup.py new file mode 100644 index 0000000..f179aff --- /dev/null +++ b/server/osa/domain/shared/authorization/startup.py @@ -0,0 +1,75 @@ +"""Startup validation for handler authorization declarations.""" + +import logging + +from osa.domain.shared.command import CommandHandler +from osa.domain.shared.error import ConfigurationError +from osa.domain.shared.query import QueryHandler + +logger = logging.getLogger(__name__) + + +def _get_command_or_query_type(handler_cls: type) -> type | None: + """Extract the Command/Query type from a handler's generic bases.""" + from typing import get_args, get_origin + + for base in getattr(handler_cls, "__orig_bases__", []): + origin = get_origin(base) + if origin is None: + continue + name = getattr(origin, "__name__", "") + if name in ("CommandHandler", "QueryHandler"): + args = get_args(base) + if args and isinstance(args[0], type): + return args[0] + return None + + +def _check_handler_class(handler_cls: type, dto_cls: type | None = None) -> None: + """Check a single handler class for __auth__ declaration. + + Raises ConfigurationError if the handler lacks __auth__ and its DTO is not __public__. + """ + if dto_cls is None: + dto_cls = _get_command_or_query_type(handler_cls) + + # If DTO is public, no __auth__ needed + if dto_cls is not None and getattr(dto_cls, "__public__", False): + return + + # Check for __auth__ + if not hasattr(handler_cls, "__auth__") or getattr(handler_cls, "__auth__") is None: + raise ConfigurationError( + f"Handler {handler_cls.__name__} has no __auth__ declaration " + f"and its command/query is not __public__" + ) + + +def validate_all_handlers() -> None: + """Scan all registered CommandHandler and QueryHandler subclasses. + + Raises ConfigurationError listing all handlers missing __auth__ declarations. + """ + violations: list[str] = [] + + for handler_cls in CommandHandler.__subclasses__(): + dto_cls = _get_command_or_query_type(handler_cls) + try: + _check_handler_class(handler_cls, dto_cls) + except ConfigurationError as e: + violations.append(str(e)) + + for handler_cls in QueryHandler.__subclasses__(): + dto_cls = _get_command_or_query_type(handler_cls) + try: + _check_handler_class(handler_cls, dto_cls) + except ConfigurationError as e: + violations.append(str(e)) + + if violations: + raise ConfigurationError( + f"Authorization validation failed for {len(violations)} handler(s):\n" + + "\n".join(f" - {v}" for v in violations) + ) + + logger.info("Authorization startup validation passed for all handlers") diff --git a/server/osa/domain/shared/command.py b/server/osa/domain/shared/command.py index 948acaf..d3ed8e8 100644 --- a/server/osa/domain/shared/command.py +++ b/server/osa/domain/shared/command.py @@ -1,11 +1,16 @@ +"""Command and CommandHandler base classes with authorization gate.""" + from abc import ABCMeta, abstractmethod +from collections.abc import Callable, Coroutine from dataclasses import dataclass -from typing import Generic, TypeVar, dataclass_transform +from functools import wraps +from typing import Any, ClassVar, Generic, TypeVar, dataclass_transform from pydantic import BaseModel -class Command(BaseModel): ... +class Command(BaseModel): + __public__: ClassVar[bool] = False class Result(BaseModel): ... @@ -14,20 +19,89 @@ class Result(BaseModel): ... C = TypeVar("C", bound=Command) R = TypeVar("R", bound=Result) +# Unbound async handler method: (self, cmd) -> Coroutine -> Result +_HandlerMethod = Callable[..., Coroutine[Any, Any, Any]] + + +def _get_command_type(cls: type) -> type[Command] | None: + """Extract the Command type C from CommandHandler[C, R] in class bases.""" + from typing import get_args, get_origin + + for base in getattr(cls, "__orig_bases__", []): + origin = get_origin(base) + if origin is not None and getattr(origin, "__name__", None) == "CommandHandler": + args = get_args(base) + if args and isinstance(args[0], type) and issubclass(args[0], Command): + return args[0] + return None + + +def _wrap_run_with_auth(cls: type, original_run: _HandlerMethod) -> _HandlerMethod: + """Wrap the run() method with __auth__ policy evaluation.""" + + @wraps(original_run) + async def auth_wrapped_run(self: Any, cmd: Any) -> Any: + from osa.domain.shared.error import AuthorizationError + + # Check if the command type is public + cmd_type = type(cmd) + if getattr(cmd_type, "__public__", False): + return await original_run(self, cmd) + + # Non-public: check auth + auth_policy = getattr(type(self), "__auth__", None) + if auth_policy is None: + from osa.domain.shared.error import ConfigurationError + + raise ConfigurationError( + f"Handler {type(self).__name__} has no __auth__ declaration " + f"and its command is not __public__" + ) + + principal = getattr(self, "_principal", None) + if principal is None: + raise AuthorizationError( + "Authentication required", + code="missing_token", + ) + + if not auth_policy.evaluate(principal): + raise AuthorizationError( + f"Access denied: insufficient role for {type(self).__name__}", + code="access_denied", + ) + + return await original_run(self, cmd) + + return auth_wrapped_run + @dataclass_transform() class _CommandHandlerMeta(ABCMeta): - """Metaclass that combines ABC with auto-dataclass for subclasses.""" + """Metaclass that combines ABC with auto-dataclass and __auth__ gate for subclasses.""" - def __new__(mcs, name: str, bases: tuple, namespace: dict): + def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]): cls = super().__new__(mcs, name, bases, namespace) if any(isinstance(b, mcs) for b in bases): - return dataclass(cls) + cls = dataclass(cls) + + # Wrap run() with auth gate if __auth__ is declared or command is not public + original_run = cls.__dict__.get("run") + if original_run is not None: + wrapped = _wrap_run_with_auth(cls, original_run) + cls.run = wrapped + return cls class CommandHandler(Generic[C, R], metaclass=_CommandHandlerMeta): - """Base class for command handlers. Subclasses are automatically dataclasses.""" + """Base class for command handlers. Subclasses are automatically dataclasses. + + Declare __auth__ to enforce role-based access: + class MyHandler(CommandHandler[MyCmd, MyResult]): + __auth__ = requires_role(Role.ADMIN) + _principal: Principal | None = None + """ @abstractmethod async def run(self, cmd: C) -> R: ... diff --git a/server/osa/domain/shared/query.py b/server/osa/domain/shared/query.py index 4bc9093..d4c800a 100644 --- a/server/osa/domain/shared/query.py +++ b/server/osa/domain/shared/query.py @@ -1,18 +1,94 @@ -from abc import ABC, abstractmethod -from typing import Generic, TypeVar +"""Query and QueryHandler base classes with authorization gate.""" + +from abc import ABCMeta, abstractmethod +from collections.abc import Callable, Coroutine +from dataclasses import dataclass +from functools import wraps +from typing import Any, ClassVar, Generic, TypeVar, dataclass_transform + from pydantic import BaseModel -class Query(BaseModel, ABC): ... +class Query(BaseModel): + __public__: ClassVar[bool] = False -class Result(BaseModel, ABC): ... +class Result(BaseModel): ... C = TypeVar("C", bound=Query) R = TypeVar("R", bound=Result) +# Unbound async handler method: (self, cmd) -> Coroutine -> Result +_HandlerMethod = Callable[..., Coroutine[Any, Any, Any]] + + +def _wrap_query_run_with_auth(cls: type, original_run: _HandlerMethod) -> _HandlerMethod: + """Wrap the run() method with __auth__ policy evaluation.""" + + @wraps(original_run) + async def auth_wrapped_run(self: Any, cmd: Any) -> Any: + from osa.domain.shared.error import AuthorizationError + + # Check if the query type is public + cmd_type = type(cmd) + if getattr(cmd_type, "__public__", False): + return await original_run(self, cmd) + + # Non-public: check auth + auth_policy = getattr(type(self), "__auth__", None) + if auth_policy is None: + from osa.domain.shared.error import ConfigurationError + + raise ConfigurationError( + f"Handler {type(self).__name__} has no __auth__ declaration " + f"and its query is not __public__" + ) + + principal = getattr(self, "_principal", None) + if principal is None: + raise AuthorizationError( + "Authentication required", + code="missing_token", + ) + + if not auth_policy.evaluate(principal): + raise AuthorizationError( + f"Access denied: insufficient role for {type(self).__name__}", + code="access_denied", + ) + + return await original_run(self, cmd) + + return auth_wrapped_run + + +@dataclass_transform() +class _QueryHandlerMeta(ABCMeta): + """Metaclass that combines ABC with auto-dataclass and __auth__ gate for subclasses.""" + + def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]): + cls = super().__new__(mcs, name, bases, namespace) + if any(isinstance(b, mcs) for b in bases): + cls = dataclass(cls) + + # Wrap run() with auth gate + original_run = cls.__dict__.get("run") + if original_run is not None: + wrapped = _wrap_query_run_with_auth(cls, original_run) + cls.run = wrapped + + return cls + + +class QueryHandler(Generic[C, R], metaclass=_QueryHandlerMeta): + """Base class for query handlers. Subclasses are automatically dataclasses. + + Declare __auth__ to enforce role-based access: + class MyHandler(QueryHandler[MyQuery, MyResult]): + __auth__ = requires_role(Role.ADMIN) + _principal: Principal | None = None + """ -class QueryHandler(ABC, Generic[C, R]): @abstractmethod - def run(self, cmd: C) -> R: ... + async def run(self, cmd: C) -> R: ... diff --git a/server/osa/infrastructure/auth/di.py b/server/osa/infrastructure/auth/di.py index 40a354e..8b43724 100644 --- a/server/osa/infrastructure/auth/di.py +++ b/server/osa/infrastructure/auth/di.py @@ -11,8 +11,10 @@ RefreshTokenRepository, UserRepository, ) +from osa.domain.auth.port.role_repository import RoleAssignmentRepository from osa.infrastructure.auth.orcid import OrcidIdentityProvider from osa.infrastructure.auth.provider_registry import InMemoryProviderRegistry +from osa.infrastructure.auth.role_repository import PostgresRoleAssignmentRepository from osa.infrastructure.persistence.repository.auth import ( PostgresIdentityRepository, PostgresRefreshTokenRepository, @@ -49,6 +51,11 @@ class AuthInfraProvider(Provider): scope=Scope.UOW, provides=RefreshTokenRepository, ) + role_assignment_repo = provide( + PostgresRoleAssignmentRepository, + scope=Scope.UOW, + provides=RoleAssignmentRepository, + ) @provide(scope=Scope.APP) def get_auth_http_client(self) -> httpx.AsyncClient: diff --git a/server/osa/infrastructure/auth/role_repository.py b/server/osa/infrastructure/auth/role_repository.py new file mode 100644 index 0000000..0dad962 --- /dev/null +++ b/server/osa/infrastructure/auth/role_repository.py @@ -0,0 +1,73 @@ +"""PostgreSQL implementation of RoleAssignmentRepository.""" + +from uuid import UUID + +from sqlalchemy import delete, insert, select +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.role_assignment import RoleAssignment, RoleAssignmentId +from osa.domain.auth.model.value import UserId +from osa.domain.auth.port.role_repository import RoleAssignmentRepository +from osa.infrastructure.persistence.tables import role_assignments_table + + +def _row_to_role_assignment(row: dict) -> RoleAssignment: + """Convert a database row to a RoleAssignment model.""" + return RoleAssignment( + id=RoleAssignmentId(UUID(row["id"])), + user_id=UserId(UUID(row["user_id"])), + role=Role[row["role"].upper()], + assigned_by=UserId(UUID(row["assigned_by"])), + assigned_at=row["assigned_at"], + ) + + +def _role_assignment_to_dict(assignment: RoleAssignment) -> dict: + """Convert a RoleAssignment model to a database row dict.""" + return { + "id": str(assignment.id), + "user_id": str(assignment.user_id), + "role": assignment.role.name.lower(), + "assigned_by": str(assignment.assigned_by), + "assigned_at": assignment.assigned_at, + } + + +class PostgresRoleAssignmentRepository(RoleAssignmentRepository): + """PostgreSQL implementation of RoleAssignmentRepository.""" + + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def get_by_user_id(self, user_id: UserId) -> list[RoleAssignment]: + stmt = select(role_assignments_table).where( + role_assignments_table.c.user_id == str(user_id) + ) + result = await self.session.execute(stmt) + rows = result.mappings().all() + return [_row_to_role_assignment(dict(row)) for row in rows] + + async def save(self, assignment: RoleAssignment) -> None: + assignment_dict = _role_assignment_to_dict(assignment) + stmt = insert(role_assignments_table).values(**assignment_dict) + await self.session.execute(stmt) + await self.session.flush() + + async def delete(self, user_id: UserId, role: Role) -> bool: + stmt = delete(role_assignments_table).where( + role_assignments_table.c.user_id == str(user_id), + role_assignments_table.c.role == role.name.lower(), + ) + result = await self.session.execute(stmt) + await self.session.flush() + return result.rowcount > 0 + + async def get(self, user_id: UserId, role: Role) -> RoleAssignment | None: + stmt = select(role_assignments_table).where( + role_assignments_table.c.user_id == str(user_id), + role_assignments_table.c.role == role.name.lower(), + ) + result = await self.session.execute(stmt) + row = result.mappings().first() + return _row_to_role_assignment(dict(row)) if row else None diff --git a/server/osa/infrastructure/persistence/mappers/deposition.py b/server/osa/infrastructure/persistence/mappers/deposition.py index 5d051de..78eff36 100644 --- a/server/osa/infrastructure/persistence/mappers/deposition.py +++ b/server/osa/infrastructure/persistence/mappers/deposition.py @@ -1,5 +1,7 @@ from typing import Any +from uuid import UUID +from osa.domain.auth.model.value import UserId from osa.domain.deposition.model.aggregate import Deposition from osa.domain.deposition.model.value import DepositionFile, DepositionStatus from osa.domain.shared.model.srn import DepositionSRN, RecordSRN @@ -15,6 +17,7 @@ def row_to_deposition(row: dict[str, Any]) -> Deposition[dict[str, Any]]: files = [DepositionFile(**f) for f in files_data] record_id = row.get("record_id") + owner_id_raw = row.get("owner_id") return Deposition( srn=DepositionSRN.parse(row["srn"]), @@ -23,6 +26,7 @@ def row_to_deposition(row: dict[str, Any]) -> Deposition[dict[str, Any]]: files=files, provenance=row.get("provenance", {}), record_srn=RecordSRN.parse(record_id) if record_id else None, + owner_id=UserId(UUID(owner_id_raw)) if owner_id_raw else None, ) @@ -35,4 +39,5 @@ def deposition_to_dict(dep: Deposition) -> dict[str, Any]: "files": [f.model_dump(mode="json") for f in dep.files], "provenance": dep.provenance, "record_id": str(dep.record_srn) if dep.record_srn else None, + "owner_id": str(dep.owner_id) if dep.owner_id else None, } diff --git a/server/osa/infrastructure/persistence/tables.py b/server/osa/infrastructure/persistence/tables.py index f22ee64..240ca7a 100644 --- a/server/osa/infrastructure/persistence/tables.py +++ b/server/osa/infrastructure/persistence/tables.py @@ -30,11 +30,13 @@ Column("provenance", JSON, nullable=False), Column("files", JSON, nullable=False), Column("record_id", String, nullable=True), + Column("owner_id", String, ForeignKey("users.id"), nullable=True), Column("created_at", DateTime(timezone=True), nullable=False), Column("updated_at", DateTime(timezone=True), nullable=False), ) Index("idx_depositions_record_id", depositions_table.c.record_id) +Index("idx_depositions_owner_id", depositions_table.c.owner_id) # ============================================================================ @@ -175,3 +177,20 @@ Index("ix_refresh_tokens_user_id", refresh_tokens_table.c.user_id) Index("ix_refresh_tokens_token_hash", refresh_tokens_table.c.token_hash) Index("ix_refresh_tokens_family_id", refresh_tokens_table.c.family_id) + + +# ============================================================================ +# ROLE ASSIGNMENTS TABLE (Authorization) +# ============================================================================ +role_assignments_table = Table( + "role_assignments", + metadata, + Column("id", String, primary_key=True), # UUID as string + Column("user_id", String, ForeignKey("users.id", ondelete="CASCADE"), nullable=False), + Column("role", String(32), nullable=False), + Column("assigned_by", String, ForeignKey("users.id"), nullable=False), + Column("assigned_at", DateTime(timezone=True), nullable=False), + UniqueConstraint("user_id", "role", name="uq_role_assignments_user_role"), +) + +Index("ix_role_assignments_user_id", role_assignments_table.c.user_id) diff --git a/server/tests/unit/domain/deposition/__init__.py b/server/tests/unit/domain/deposition/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/domain/deposition/test_deposition_service_auth.py b/server/tests/unit/domain/deposition/test_deposition_service_auth.py new file mode 100644 index 0000000..355b751 --- /dev/null +++ b/server/tests/unit/domain/deposition/test_deposition_service_auth.py @@ -0,0 +1,131 @@ +"""Tests for DepositionService authorization — T029. + +Tests that: +- Owner can read/update/submit their own deposition via Guarded[Deposition] +- Non-owner depositor is denied access to another's deposition +- owner_id is set from principal at creation time +""" + +import pytest + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.deposition.model.aggregate import Deposition +from osa.domain.deposition.model.value import DepositionStatus +from osa.domain.shared.authorization.action import Action +from osa.domain.shared.authorization.guarded import Guarded +from osa.domain.shared.authorization.policy_set import POLICY_SET +from osa.domain.shared.error import AuthorizationError +from osa.domain.shared.model.srn import DepositionSRN + + +def _make_principal( + user_id: UserId | None = None, roles: frozenset[Role] | None = None +) -> Principal: + return Principal( + user_id=user_id or UserId.generate(), + identity=ProviderIdentity(provider="test", external_id="test-ext"), + roles=roles or frozenset({Role.DEPOSITOR}), + ) + + +def _make_deposition(owner_id: UserId) -> Deposition: + return Deposition( + srn=DepositionSRN.parse("urn:osa:localhost:dep:00000000-0000-0000-0000-000000000001"), + status=DepositionStatus.DRAFT, + metadata={}, + owner_id=owner_id, + ) + + +class TestDepositionOwnership: + """Guarded[Deposition] enforces ownership rules via POLICY_SET.""" + + def test_owner_can_read_own_deposition(self) -> None: + owner = _make_principal() + dep = _make_deposition(owner_id=owner.user_id) + guarded = Guarded(dep, owner, POLICY_SET) + + result = guarded.check(Action.DEPOSITION_READ) + assert result is dep + + def test_owner_can_update_own_deposition(self) -> None: + owner = _make_principal() + dep = _make_deposition(owner_id=owner.user_id) + guarded = Guarded(dep, owner, POLICY_SET) + + result = guarded.check(Action.DEPOSITION_UPDATE) + assert result is dep + + def test_owner_can_submit_own_deposition(self) -> None: + owner = _make_principal() + dep = _make_deposition(owner_id=owner.user_id) + guarded = Guarded(dep, owner, POLICY_SET) + + result = guarded.check(Action.DEPOSITION_SUBMIT) + assert result is dep + + def test_non_owner_depositor_cannot_read_others_deposition(self) -> None: + owner = _make_principal() + other = _make_principal() + dep = _make_deposition(owner_id=owner.user_id) + guarded = Guarded(dep, other, POLICY_SET) + + with pytest.raises(AuthorizationError): + guarded.check(Action.DEPOSITION_READ) + + def test_non_owner_depositor_cannot_update_others_deposition(self) -> None: + owner = _make_principal() + other = _make_principal() + dep = _make_deposition(owner_id=owner.user_id) + guarded = Guarded(dep, other, POLICY_SET) + + with pytest.raises(AuthorizationError): + guarded.check(Action.DEPOSITION_UPDATE) + + def test_non_owner_depositor_cannot_submit_others_deposition(self) -> None: + owner = _make_principal() + other = _make_principal() + dep = _make_deposition(owner_id=owner.user_id) + guarded = Guarded(dep, other, POLICY_SET) + + with pytest.raises(AuthorizationError): + guarded.check(Action.DEPOSITION_SUBMIT) + + def test_curator_can_read_any_deposition(self) -> None: + """Curators can read all depositions without ownership.""" + owner = _make_principal() + curator = _make_principal(roles=frozenset({Role.CURATOR})) + dep = _make_deposition(owner_id=owner.user_id) + guarded = Guarded(dep, curator, POLICY_SET) + + result = guarded.check(Action.DEPOSITION_READ) + assert result is dep + + def test_admin_can_read_any_deposition(self) -> None: + """Admins inherit curator permissions via role hierarchy.""" + owner = _make_principal() + admin = _make_principal(roles=frozenset({Role.ADMIN})) + dep = _make_deposition(owner_id=owner.user_id) + guarded = Guarded(dep, admin, POLICY_SET) + + result = guarded.check(Action.DEPOSITION_READ) + assert result is dep + + +class TestDepositionOwnerIdAssignment: + """Deposition aggregate tracks owner_id.""" + + def test_deposition_has_owner_id_field(self) -> None: + owner_id = UserId.generate() + dep = _make_deposition(owner_id=owner_id) + assert dep.owner_id == owner_id + + def test_deposition_owner_id_defaults_to_none(self) -> None: + dep = Deposition( + srn=DepositionSRN.parse("urn:osa:localhost:dep:00000000-0000-0000-0000-000000000001"), + status=DepositionStatus.DRAFT, + metadata={}, + ) + assert dep.owner_id is None diff --git a/server/tests/unit/domain/shared/authorization/__init__.py b/server/tests/unit/domain/shared/authorization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/domain/shared/authorization/test_auth_gate.py b/server/tests/unit/domain/shared/authorization/test_auth_gate.py new file mode 100644 index 0000000..cdd1763 --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_auth_gate.py @@ -0,0 +1,141 @@ +"""Tests for handler __auth__ gate: T013 — metaclass wraps run() with auth check.""" + +import pytest + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.authorization.policy import requires_role +from osa.domain.shared.command import Command, CommandHandler, Result +from osa.domain.shared.error import AuthorizationError, ConfigurationError +from osa.domain.shared.query import Query, QueryHandler +from osa.domain.shared.query import Result as QueryResult + + +def _make_principal(roles: frozenset[Role]) -> Principal: + return Principal( + user_id=UserId.generate(), + identity=ProviderIdentity(provider="test", external_id="test-ext"), + roles=roles, + ) + + +# --- Test command DTOs --- + + +class AdminOnlyCommand(Command): + value: str = "test" + + +class AdminOnlyResult(Result): + value: str + + +class PublicCommand(Command): + __public__: bool = True # ClassVar-like, signals public access + value: str = "test" + + +class PublicResult(Result): + value: str + + +# --- Test handlers --- + + +class AdminOnlyHandler(CommandHandler[AdminOnlyCommand, AdminOnlyResult]): + __auth__ = requires_role(Role.ADMIN) + _principal: Principal | None = None + + async def run(self, cmd: AdminOnlyCommand) -> AdminOnlyResult: + return AdminOnlyResult(value=cmd.value) + + +class PublicHandler(CommandHandler[PublicCommand, PublicResult]): + _principal: Principal | None = None + + async def run(self, cmd: PublicCommand) -> PublicResult: + return PublicResult(value=cmd.value) + + +class UnprotectedCommand(Command): + value: str = "test" + + +class UnprotectedResult(Result): + value: str + + +class UnprotectedHandler(CommandHandler[UnprotectedCommand, UnprotectedResult]): + async def run(self, cmd: UnprotectedCommand) -> UnprotectedResult: + return UnprotectedResult(value=cmd.value) + + +class UnprotectedQuery(Query): + value: str = "test" + + +class UnprotectedQueryResult(QueryResult): + value: str + + +class UnprotectedQueryHandler(QueryHandler[UnprotectedQuery, UnprotectedQueryResult]): + async def run(self, cmd: UnprotectedQuery) -> UnprotectedQueryResult: + return UnprotectedQueryResult(value=cmd.value) + + +# --- Tests --- + + +class TestAuthGateOnCommandHandler: + @pytest.mark.asyncio + async def test_admin_handler_rejects_depositor(self) -> None: + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + handler = AdminOnlyHandler(_principal=depositor) + + with pytest.raises(AuthorizationError): + await handler.run(AdminOnlyCommand(value="test")) + + @pytest.mark.asyncio + async def test_admin_handler_allows_admin(self) -> None: + admin = _make_principal(frozenset({Role.ADMIN})) + handler = AdminOnlyHandler(_principal=admin) + + result = await handler.run(AdminOnlyCommand(value="hello")) + assert result.value == "hello" + + @pytest.mark.asyncio + async def test_admin_handler_rejects_none_principal(self) -> None: + handler = AdminOnlyHandler(_principal=None) + + with pytest.raises(AuthorizationError): + await handler.run(AdminOnlyCommand(value="test")) + + @pytest.mark.asyncio + async def test_public_handler_skips_check(self) -> None: + handler = PublicHandler(_principal=None) + + result = await handler.run(PublicCommand(value="public")) + assert result.value == "public" + + @pytest.mark.asyncio + async def test_public_handler_works_with_principal(self) -> None: + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + handler = PublicHandler(_principal=depositor) + + result = await handler.run(PublicCommand(value="public")) + assert result.value == "public" + + @pytest.mark.asyncio + async def test_unprotected_command_handler_raises_configuration_error(self) -> None: + handler = UnprotectedHandler() + + with pytest.raises(ConfigurationError, match="UnprotectedHandler"): + await handler.run(UnprotectedCommand(value="test")) + + @pytest.mark.asyncio + async def test_unprotected_query_handler_raises_configuration_error(self) -> None: + handler = UnprotectedQueryHandler() + + with pytest.raises(ConfigurationError, match="UnprotectedQueryHandler"): + await handler.run(UnprotectedQuery(value="test")) diff --git a/server/tests/unit/domain/shared/authorization/test_authorization_audit.py b/server/tests/unit/domain/shared/authorization/test_authorization_audit.py new file mode 100644 index 0000000..afb00cd --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_authorization_audit.py @@ -0,0 +1,63 @@ +"""Tests for authorization audit logging — T058. + +Tests that PolicySet.guard() emits structured log entries for allow and deny decisions. +""" + +import logging + +import pytest + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.authorization.action import Action +from osa.domain.shared.authorization.policy_set import POLICY_SET +from osa.domain.shared.error import AuthorizationError + + +def _make_principal(roles: frozenset[Role], user_id: UserId | None = None) -> Principal: + return Principal( + user_id=user_id or UserId.generate(), + identity=ProviderIdentity(provider="test", external_id="test-ext"), + roles=roles, + ) + + +class TestAuthorizationAuditLogging: + def test_guard_logs_allow(self, caplog: pytest.LogCaptureFixture) -> None: + """Successful authorization should emit an info-level log.""" + principal = _make_principal(frozenset({Role.ADMIN})) + + with caplog.at_level(logging.DEBUG, logger="osa.domain.shared.authorization.policy_set"): + POLICY_SET.guard(principal, Action.SCHEMA_CREATE) + + # Should have an allow log entry + allow_messages = [r for r in caplog.records if "allowed" in r.message.lower()] + assert len(allow_messages) >= 1 + record = allow_messages[0] + assert str(principal.user_id) in record.message + assert Action.SCHEMA_CREATE in record.message + + def test_guard_logs_deny(self, caplog: pytest.LogCaptureFixture) -> None: + """Failed authorization should emit a warning-level log.""" + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + + with caplog.at_level(logging.DEBUG, logger="osa.domain.shared.authorization.policy_set"): + with pytest.raises(AuthorizationError): + POLICY_SET.guard(depositor, Action.SCHEMA_CREATE) + + # Should have a deny log entry + deny_messages = [r for r in caplog.records if "denied" in r.message.lower()] + assert len(deny_messages) >= 1 + record = deny_messages[0] + assert str(depositor.user_id) in record.message + assert Action.SCHEMA_CREATE in record.message + + def test_guard_logs_deny_for_anonymous(self, caplog: pytest.LogCaptureFixture) -> None: + """Authorization denial for anonymous users should also be logged.""" + with caplog.at_level(logging.DEBUG, logger="osa.domain.shared.authorization.policy_set"): + with pytest.raises(AuthorizationError): + POLICY_SET.guard(None, Action.DEPOSITION_CREATE) + + deny_messages = [r for r in caplog.records if "denied" in r.message.lower()] + assert len(deny_messages) >= 1 diff --git a/server/tests/unit/domain/shared/authorization/test_guarded.py b/server/tests/unit/domain/shared/authorization/test_guarded.py new file mode 100644 index 0000000..0e3cac1 --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_guarded.py @@ -0,0 +1,64 @@ +"""Tests for Guarded[T]: T009 — generic authorization wrapper.""" + +import pytest + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.authorization.action import Action +from osa.domain.shared.authorization.guarded import Guarded +from osa.domain.shared.authorization.policy_set import POLICY_SET +from osa.domain.shared.error import AuthorizationError + + +def _make_principal( + roles: frozenset[Role], + user_id: UserId | None = None, +) -> Principal: + uid = user_id or UserId.generate() + return Principal( + user_id=uid, + identity=ProviderIdentity(provider="test", external_id="test-ext"), + roles=roles, + ) + + +class _FakeDeposition: + def __init__(self, owner_id: UserId, status: str = "draft") -> None: + self.owner_id = owner_id + self.status = status + + +class TestGuardedCheck: + def test_check_returns_unwrapped_resource_on_success(self) -> None: + user_id = UserId.generate() + principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) + dep = _FakeDeposition(owner_id=user_id) + + guarded = Guarded(dep, principal, POLICY_SET) + result = guarded.check(Action.DEPOSITION_READ) + + assert result is dep + assert result.status == "draft" + + def test_check_raises_on_failure(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + dep = _FakeDeposition(owner_id=UserId.generate()) # different owner + + guarded = Guarded(dep, principal, POLICY_SET) + + with pytest.raises(AuthorizationError): + guarded.check(Action.DEPOSITION_SUBMIT) + + +class TestGuardedNoProxy: + def test_no_attribute_access_proxy(self) -> None: + user_id = UserId.generate() + principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) + dep = _FakeDeposition(owner_id=user_id) + + guarded = Guarded(dep, principal, POLICY_SET) + + # Accessing .status on Guarded[Deposition] should raise AttributeError + with pytest.raises(AttributeError): + _ = guarded.status # type: ignore[attr-defined] diff --git a/server/tests/unit/domain/shared/authorization/test_policy.py b/server/tests/unit/domain/shared/authorization/test_policy.py new file mode 100644 index 0000000..f1f22ca --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_policy.py @@ -0,0 +1,90 @@ +"""Tests for Policy composition: T011 — composable handler-level policies.""" + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.authorization.policy import ( + AllOf, + RequiresRole, + requires_any_role, + requires_role, +) + + +def _make_principal(roles: frozenset[Role]) -> Principal: + return Principal( + user_id=UserId.generate(), + identity=ProviderIdentity(provider="test", external_id="test-ext"), + roles=roles, + ) + + +class TestRequiresRole: + def test_admin_policy_denies_depositor(self) -> None: + policy = requires_role(Role.ADMIN) + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + + assert policy.evaluate(depositor) is False + + def test_admin_policy_allows_admin(self) -> None: + policy = requires_role(Role.ADMIN) + admin = _make_principal(frozenset({Role.ADMIN})) + + assert policy.evaluate(admin) is True + + def test_admin_policy_allows_superadmin(self) -> None: + policy = requires_role(Role.ADMIN) + superadmin = _make_principal(frozenset({Role.SUPERADMIN})) + + assert policy.evaluate(superadmin) is True + + +class TestPolicyComposition: + def test_or_operator(self) -> None: + policy = requires_role(Role.ADMIN) | requires_role(Role.CURATOR) + curator = _make_principal(frozenset({Role.CURATOR})) + + assert policy.evaluate(curator) is True + + def test_and_operator(self) -> None: + # Both must pass — depositor + curator would fail an AllOf(admin, curator) + policy = requires_role(Role.ADMIN) & requires_role(Role.CURATOR) + admin = _make_principal(frozenset({Role.ADMIN})) + + # Admin >= Curator, so both should pass via hierarchy + assert policy.evaluate(admin) is True + + def test_not_operator(self) -> None: + policy = ~requires_role(Role.ADMIN) + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + + assert policy.evaluate(depositor) is True + + def test_not_inverts(self) -> None: + policy = ~requires_role(Role.ADMIN) + admin = _make_principal(frozenset({Role.ADMIN})) + + assert policy.evaluate(admin) is False + + +class TestAllOf: + def test_all_must_pass(self) -> None: + policy = AllOf(policies=(RequiresRole(Role.DEPOSITOR), RequiresRole(Role.CURATOR))) + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + + # Depositor < Curator, so second check fails + assert policy.evaluate(depositor) is False + + +class TestRequiresAnyRole: + def test_any_role_works(self) -> None: + policy = requires_any_role(Role.ADMIN, Role.CURATOR) + curator = _make_principal(frozenset({Role.CURATOR})) + + assert policy.evaluate(curator) is True + + def test_any_role_fails_when_none_match(self) -> None: + policy = requires_any_role(Role.ADMIN, Role.SUPERADMIN) + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + + assert policy.evaluate(depositor) is False diff --git a/server/tests/unit/domain/shared/authorization/test_policy_set.py b/server/tests/unit/domain/shared/authorization/test_policy_set.py new file mode 100644 index 0000000..238b937 --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_policy_set.py @@ -0,0 +1,113 @@ +"""Tests for PolicySet: T007 — declarative authorization rules.""" + +import pytest + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.authorization.action import Action +from osa.domain.shared.authorization.policy_set import POLICY_SET +from osa.domain.shared.error import AuthorizationError + + +def _make_principal( + roles: frozenset[Role], + user_id: UserId | None = None, +) -> Principal: + uid = user_id or UserId.generate() + return Principal( + user_id=uid, + identity=ProviderIdentity(provider="test", external_id="test-ext"), + roles=roles, + ) + + +class _FakeResource: + """Fake resource with owner_id for testing ownership checks.""" + + def __init__(self, owner_id: UserId) -> None: + self.owner_id = owner_id + + +class TestPolicySetOwnership: + def test_owner_can_submit_own_deposition(self) -> None: + user_id = UserId.generate() + principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) + resource = _FakeResource(owner_id=user_id) + + # Should not raise + POLICY_SET.guard(principal, Action.DEPOSITION_SUBMIT, resource) + + def test_non_owner_denied_submit(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) # different user + + with pytest.raises(AuthorizationError): + POLICY_SET.guard(principal, Action.DEPOSITION_SUBMIT, resource) + + def test_owner_can_read_own_deposition(self) -> None: + user_id = UserId.generate() + principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) + resource = _FakeResource(owner_id=user_id) + + POLICY_SET.guard(principal, Action.DEPOSITION_READ, resource) + + def test_non_owner_depositor_denied_read(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + + with pytest.raises(AuthorizationError): + POLICY_SET.guard(principal, Action.DEPOSITION_READ, resource) + + +class TestPolicySetRoles: + def test_admin_reads_any_deposition(self) -> None: + principal = _make_principal(frozenset({Role.ADMIN})) + resource = _FakeResource(owner_id=UserId.generate()) + + # Admin >= Curator, so curator read rule (no ownership) should match + POLICY_SET.guard(principal, Action.DEPOSITION_READ, resource) + + def test_curator_reads_any_deposition(self) -> None: + principal = _make_principal(frozenset({Role.CURATOR})) + resource = _FakeResource(owner_id=UserId.generate()) + + POLICY_SET.guard(principal, Action.DEPOSITION_READ, resource) + + def test_depositor_cannot_create_schema(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + + with pytest.raises(AuthorizationError): + POLICY_SET.guard(principal, Action.SCHEMA_CREATE) + + def test_admin_can_create_schema(self) -> None: + principal = _make_principal(frozenset({Role.ADMIN})) + + POLICY_SET.guard(principal, Action.SCHEMA_CREATE) + + +class TestPolicySetPublic: + def test_public_user_reads_records(self) -> None: + # No principal (anonymous) + POLICY_SET.guard(None, Action.RECORD_READ) + + def test_public_user_can_search(self) -> None: + POLICY_SET.guard(None, Action.SEARCH_QUERY) + + def test_public_user_reads_schemas(self) -> None: + POLICY_SET.guard(None, Action.SCHEMA_READ) + + +class TestPolicySetCoverage: + def test_validate_coverage_passes(self) -> None: + # Should not raise — all actions should be covered + POLICY_SET.validate_coverage() + + def test_validate_coverage_catches_missing(self) -> None: + from osa.domain.shared.authorization.policy_set import PolicySet, allow + from osa.domain.shared.error import ConfigurationError + + # Incomplete policy set + incomplete = PolicySet([allow(Action.RECORD_READ)]) + with pytest.raises(ConfigurationError): + incomplete.validate_coverage() diff --git a/server/tests/unit/domain/shared/authorization/test_role_hierarchy.py b/server/tests/unit/domain/shared/authorization/test_role_hierarchy.py new file mode 100644 index 0000000..117b695 --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_role_hierarchy.py @@ -0,0 +1,60 @@ +"""Tests for Role hierarchy: T012 — numeric hierarchy comparison.""" + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId + + +class TestRoleHierarchy: + def test_admin_ge_curator(self) -> None: + assert Role.ADMIN >= Role.CURATOR + + def test_depositor_lt_admin(self) -> None: + assert not (Role.DEPOSITOR >= Role.ADMIN) + + def test_superadmin_gt_all(self) -> None: + for role in Role: + if role != Role.SUPERADMIN: + assert Role.SUPERADMIN > role + + def test_public_is_lowest(self) -> None: + for role in Role: + assert role >= Role.PUBLIC + + def test_ordering(self) -> None: + assert Role.PUBLIC < Role.DEPOSITOR < Role.CURATOR < Role.ADMIN < Role.SUPERADMIN + + +class TestPrincipalHasRole: + def test_has_role_uses_hierarchy(self) -> None: + principal = Principal( + user_id=UserId.generate(), + identity=ProviderIdentity(provider="test", external_id="ext"), + roles=frozenset({Role.ADMIN}), + ) + + # Admin >= Curator, so has_role(CURATOR) should be True + assert principal.has_role(Role.CURATOR) is True + assert principal.has_role(Role.ADMIN) is True + assert principal.has_role(Role.SUPERADMIN) is False + + def test_has_role_depositor(self) -> None: + principal = Principal( + user_id=UserId.generate(), + identity=ProviderIdentity(provider="test", external_id="ext"), + roles=frozenset({Role.DEPOSITOR}), + ) + + assert principal.has_role(Role.DEPOSITOR) is True + assert principal.has_role(Role.CURATOR) is False + assert principal.has_role(Role.ADMIN) is False + + def test_has_any_role(self) -> None: + principal = Principal( + user_id=UserId.generate(), + identity=ProviderIdentity(provider="test", external_id="ext"), + roles=frozenset({Role.CURATOR}), + ) + + assert principal.has_any_role(Role.ADMIN, Role.CURATOR) is True + assert principal.has_any_role(Role.SUPERADMIN) is False diff --git a/server/tests/unit/domain/shared/authorization/test_startup_validation.py b/server/tests/unit/domain/shared/authorization/test_startup_validation.py new file mode 100644 index 0000000..1bd738f --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_startup_validation.py @@ -0,0 +1,103 @@ +"""Tests for startup validation of handler __auth__ declarations — T036. + +Tests that all handlers either declare __auth__ or their command/query is __public__. +""" + +import pytest + +from osa.domain.auth.model.principal import Principal +from osa.domain.shared.authorization.policy import requires_role +from osa.domain.auth.model.role import Role +from osa.domain.shared.command import Command, CommandHandler, Result +from osa.domain.shared.error import ConfigurationError +from osa.domain.shared.query import Query, QueryHandler +from osa.domain.shared.query import Result as QueryResult + + +def validate_handlers() -> None: + """Scan all registered handler subclasses for __auth__ declarations. + + Raises ConfigurationError if any handler is missing __auth__ + and its command/query is not __public__. + """ + from osa.domain.shared.authorization.startup import validate_all_handlers + + validate_all_handlers() + + +class TestStartupValidation: + def test_validation_catches_missing_auth_on_command_handler(self) -> None: + """A CommandHandler without __auth__ on a non-public command should fail startup.""" + + class UnprotectedCommand(Command): + pass + + class UnprotectedResult(Result): + pass + + class UnprotectedHandler(CommandHandler[UnprotectedCommand, UnprotectedResult]): + async def run(self, cmd: UnprotectedCommand) -> UnprotectedResult: + return UnprotectedResult() + + from osa.domain.shared.authorization.startup import _check_handler_class + + with pytest.raises(ConfigurationError, match="UnprotectedHandler"): + _check_handler_class(UnprotectedHandler, UnprotectedCommand) + + def test_validation_passes_for_protected_handler(self) -> None: + """A handler with __auth__ should pass validation.""" + + class ProtectedCommand(Command): + pass + + class ProtectedResult(Result): + pass + + class ProtectedHandler(CommandHandler[ProtectedCommand, ProtectedResult]): + __auth__ = requires_role(Role.ADMIN) + _principal: Principal | None = None + + async def run(self, cmd: ProtectedCommand) -> ProtectedResult: + return ProtectedResult() + + from osa.domain.shared.authorization.startup import _check_handler_class + + # Should not raise + _check_handler_class(ProtectedHandler, ProtectedCommand) + + def test_validation_passes_for_public_command(self) -> None: + """A handler for a __public__ command should pass even without __auth__.""" + from typing import ClassVar + + class PublicCommand(Command): + __public__: ClassVar[bool] = True + + class PublicResult(Result): + pass + + class PublicHandler(CommandHandler[PublicCommand, PublicResult]): + async def run(self, cmd: PublicCommand) -> PublicResult: + return PublicResult() + + from osa.domain.shared.authorization.startup import _check_handler_class + + # Should not raise + _check_handler_class(PublicHandler, PublicCommand) + + def test_validation_catches_missing_auth_on_query_handler(self) -> None: + """A QueryHandler without __auth__ on a non-public query should fail.""" + + class UnprotectedQuery(Query): + pass + + class UnprotectedQueryResult(QueryResult): + pass + + class UnprotectedQueryHandler(QueryHandler[UnprotectedQuery, UnprotectedQueryResult]): + async def run(self, cmd: UnprotectedQuery) -> UnprotectedQueryResult: + return UnprotectedQueryResult() + + from osa.domain.shared.authorization.startup import _check_handler_class + + with pytest.raises(ConfigurationError, match="UnprotectedQueryHandler"): + _check_handler_class(UnprotectedQueryHandler, UnprotectedQuery) From 273c73eb69ec5e951a1db65409559bd81b0e79cc Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Sat, 7 Feb 2026 02:20:49 +0000 Subject: [PATCH 2/3] feat: redesign authorization with Gate hierarchy and repo decorators Replace the over-engineered PolicySet/Guarded/Action system with two clean layers: handler-level Gate checks (__auth__ = public() or at_least(Role.X)) and resource-level repo decorators (@reads/@writes). - Add Gate base class with Public and AtLeast subclasses - Add Identity hierarchy (Anonymous, System, Principal) - Add @reads/@writes decorators for resource-level checks on repos - Rename OAuth Identity entity to LinkedAccount (avoid name conflict) - Workers inject System() identity via DI context (bypasses all checks) - Delete Action enum, PolicySet, Guarded[T], Policy composables --- server/osa/domain/auth/command/assign_role.py | 10 +- server/osa/domain/auth/command/login.py | 21 ++- server/osa/domain/auth/command/revoke_role.py | 6 +- server/osa/domain/auth/command/token.py | 10 +- server/osa/domain/auth/model/__init__.py | 8 +- server/osa/domain/auth/model/identity.py | 70 +++----- .../osa/domain/auth/model/linked_account.py | 49 ++++++ server/osa/domain/auth/model/principal.py | 6 +- server/osa/domain/auth/port/__init__.py | 4 +- server/osa/domain/auth/port/repository.py | 22 +-- .../osa/domain/auth/query/get_user_roles.py | 6 +- server/osa/domain/auth/service/auth.py | 58 +++---- server/osa/domain/auth/util/di/provider.py | 38 ++-- .../osa/domain/deposition/command/create.py | 6 +- .../domain/deposition/command/delete_files.py | 6 +- .../osa/domain/deposition/command/submit.py | 6 +- .../osa/domain/deposition/command/update.py | 6 +- .../osa/domain/deposition/command/upload.py | 6 +- .../osa/domain/shared/authorization/action.py | 57 ------ .../shared/authorization/authorized_repo.py | 40 ----- .../domain/shared/authorization/decorators.py | 46 +++++ .../osa/domain/shared/authorization/gate.py | 42 +++++ .../domain/shared/authorization/guarded.py | 41 ----- .../osa/domain/shared/authorization/policy.py | 83 --------- .../domain/shared/authorization/policy_set.py | 164 ------------------ .../domain/shared/authorization/resource.py | 111 ++++++++++++ .../domain/shared/authorization/startup.py | 43 +---- server/osa/domain/shared/command.py | 82 ++++----- server/osa/domain/shared/query.py | 67 +++---- server/osa/infrastructure/auth/di.py | 10 +- server/osa/infrastructure/event/worker.py | 13 +- .../persistence/repository/auth.py | 54 +++--- .../persistence/repository/deposition.py | 26 ++- .../unit/domain/auth/test_auth_service.py | 52 +++--- .../unit/domain/auth/test_command_handlers.py | 10 +- .../test_deposition_service_auth.py | 131 -------------- .../shared/authorization/test_auth_gate.py | 36 ++-- .../authorization/test_authorization_audit.py | 63 ------- .../shared/authorization/test_decorators.py | 147 ++++++++++++++++ .../domain/shared/authorization/test_gate.py | 47 +++++ .../shared/authorization/test_guarded.py | 64 ------- .../shared/authorization/test_identity.py | 40 +++++ .../shared/authorization/test_policy.py | 90 ---------- .../shared/authorization/test_policy_set.py | 113 ------------ .../authorization/test_resource_check.py | 139 +++++++++++++++ .../authorization/test_role_hierarchy.py | 6 +- .../authorization/test_startup_validation.py | 64 ++++--- 47 files changed, 993 insertions(+), 1226 deletions(-) create mode 100644 server/osa/domain/auth/model/linked_account.py delete mode 100644 server/osa/domain/shared/authorization/action.py delete mode 100644 server/osa/domain/shared/authorization/authorized_repo.py create mode 100644 server/osa/domain/shared/authorization/decorators.py create mode 100644 server/osa/domain/shared/authorization/gate.py delete mode 100644 server/osa/domain/shared/authorization/guarded.py delete mode 100644 server/osa/domain/shared/authorization/policy.py delete mode 100644 server/osa/domain/shared/authorization/policy_set.py create mode 100644 server/osa/domain/shared/authorization/resource.py delete mode 100644 server/tests/unit/domain/deposition/test_deposition_service_auth.py delete mode 100644 server/tests/unit/domain/shared/authorization/test_authorization_audit.py create mode 100644 server/tests/unit/domain/shared/authorization/test_decorators.py create mode 100644 server/tests/unit/domain/shared/authorization/test_gate.py delete mode 100644 server/tests/unit/domain/shared/authorization/test_guarded.py create mode 100644 server/tests/unit/domain/shared/authorization/test_identity.py delete mode 100644 server/tests/unit/domain/shared/authorization/test_policy.py delete mode 100644 server/tests/unit/domain/shared/authorization/test_policy_set.py create mode 100644 server/tests/unit/domain/shared/authorization/test_resource_check.py diff --git a/server/osa/domain/auth/command/assign_role.py b/server/osa/domain/auth/command/assign_role.py index 7be4b43..725831a 100644 --- a/server/osa/domain/auth/command/assign_role.py +++ b/server/osa/domain/auth/command/assign_role.py @@ -7,7 +7,7 @@ from osa.domain.auth.model.role import Role from osa.domain.auth.model.value import UserId from osa.domain.auth.service.authorization import AuthorizationService -from osa.domain.shared.authorization.policy import requires_role +from osa.domain.shared.authorization.gate import at_least from osa.domain.shared.command import Command, CommandHandler, Result @@ -29,17 +29,15 @@ class AssignRoleResult(Result): class AssignRoleHandler(CommandHandler[AssignRole, AssignRoleResult]): - __auth__ = requires_role(Role.SUPERADMIN) - _principal: Principal | None = None + __auth__ = at_least(Role.SUPERADMIN) + principal: Principal authorization_service: AuthorizationService async def run(self, cmd: AssignRole) -> AssignRoleResult: - assert self._principal is not None # Guaranteed by __auth__ gate - assignment = await self.authorization_service.assign_role( user_id=UserId(UUID(cmd.user_id)), role=Role[cmd.role.upper()], - assigned_by=self._principal.user_id, + assigned_by=self.principal.user_id, ) return AssignRoleResult( diff --git a/server/osa/domain/auth/command/login.py b/server/osa/domain/auth/command/login.py index de46dc3..ebf2a40 100644 --- a/server/osa/domain/auth/command/login.py +++ b/server/osa/domain/auth/command/login.py @@ -7,8 +7,7 @@ from osa.domain.auth.port.provider_registry import ProviderRegistry from osa.domain.auth.service.auth import AuthService from osa.domain.auth.service.token import TokenService -from typing import ClassVar - +from osa.domain.shared.authorization.gate import public from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.error import NotFoundError from osa.domain.shared.event import EventId @@ -18,8 +17,6 @@ class InitiateLogin(Command): """Command to start OAuth login flow.""" - __public__: ClassVar[bool] = True - callback_url: str # OAuth callback URL (where IdP redirects after auth) final_redirect_uri: str # Where to redirect user after OAuth completes provider: str @@ -35,6 +32,8 @@ class InitiateLoginResult(Result): class InitiateLoginHandler(CommandHandler[InitiateLogin, InitiateLoginResult]): """Handler for InitiateLogin command.""" + __auth__ = public() + provider_registry: ProviderRegistry token_service: TokenService @@ -63,8 +62,6 @@ async def run(self, cmd: InitiateLogin) -> InitiateLoginResult: class CompleteOAuth(Command): """Command to complete OAuth flow with authorization code.""" - __public__: ClassVar[bool] = True - code: str callback_url: str # Must match the one used in authorization provider: str # The identity provider name (from verified state) @@ -86,6 +83,8 @@ class CompleteOAuthResult(Result): class CompleteOAuthHandler(CommandHandler[CompleteOAuth, CompleteOAuthResult]): """Handler for CompleteOAuth command.""" + __auth__ = public() + auth_service: AuthService provider_registry: ProviderRegistry token_service: TokenService @@ -101,7 +100,7 @@ async def run(self, cmd: CompleteOAuth) -> CompleteOAuthResult: code="unknown_provider", ) - user, identity, access_token, refresh_token = await self.auth_service.complete_oauth( + user, linked_account, access_token, refresh_token = await self.auth_service.complete_oauth( provider=identity_provider, code=cmd.code, redirect_uri=cmd.callback_url, @@ -112,16 +111,16 @@ async def run(self, cmd: CompleteOAuth) -> CompleteOAuthResult: UserAuthenticated( id=EventId(uuid4()), user_id=str(user.id), - provider=identity.provider, - external_id=identity.external_id, + provider=linked_account.provider, + external_id=linked_account.external_id, ) ) return CompleteOAuthResult( user_id=str(user.id), display_name=user.display_name, - provider=identity.provider, - external_id=identity.external_id, + provider=linked_account.provider, + external_id=linked_account.external_id, access_token=access_token, refresh_token=refresh_token, expires_in=self.token_service.access_token_expire_seconds, diff --git a/server/osa/domain/auth/command/revoke_role.py b/server/osa/domain/auth/command/revoke_role.py index 4534bd7..b577841 100644 --- a/server/osa/domain/auth/command/revoke_role.py +++ b/server/osa/domain/auth/command/revoke_role.py @@ -6,7 +6,7 @@ from osa.domain.auth.model.role import Role from osa.domain.auth.model.value import UserId from osa.domain.auth.service.authorization import AuthorizationService -from osa.domain.shared.authorization.policy import requires_role +from osa.domain.shared.authorization.gate import at_least from osa.domain.shared.command import Command, CommandHandler, Result @@ -24,8 +24,8 @@ class RevokeRoleResult(Result): class RevokeRoleHandler(CommandHandler[RevokeRole, RevokeRoleResult]): - __auth__ = requires_role(Role.SUPERADMIN) - _principal: Principal | None = None + __auth__ = at_least(Role.SUPERADMIN) + principal: Principal authorization_service: AuthorizationService async def run(self, cmd: RevokeRole) -> RevokeRoleResult: diff --git a/server/osa/domain/auth/command/token.py b/server/osa/domain/auth/command/token.py index dc0fd27..08dd2f0 100644 --- a/server/osa/domain/auth/command/token.py +++ b/server/osa/domain/auth/command/token.py @@ -1,12 +1,12 @@ """Token commands for refresh and logout operations.""" from dataclasses import dataclass -from typing import ClassVar from uuid import uuid4 from osa.domain.auth.event import UserLoggedOut from osa.domain.auth.service.auth import AuthService from osa.domain.auth.service.token import TokenService +from osa.domain.shared.authorization.gate import public from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.event import EventId from osa.domain.shared.outbox import Outbox @@ -15,8 +15,6 @@ class RefreshTokens(Command): """Command to refresh access token using refresh token.""" - __public__: ClassVar[bool] = True - refresh_token: str @@ -32,6 +30,8 @@ class RefreshTokensResult(Result): class RefreshTokensHandler(CommandHandler[RefreshTokens, RefreshTokensResult]): """Handler for RefreshTokens command.""" + __auth__ = public() + auth_service: AuthService token_service: TokenService @@ -51,8 +51,6 @@ async def run(self, cmd: RefreshTokens) -> RefreshTokensResult: class Logout(Command): """Command to logout and revoke refresh token family.""" - __public__: ClassVar[bool] = True - refresh_token: str @@ -66,6 +64,8 @@ class LogoutResult(Result): class LogoutHandler(CommandHandler[Logout, LogoutResult]): """Handler for Logout command.""" + __auth__ = public() + auth_service: AuthService outbox: Outbox diff --git a/server/osa/domain/auth/model/__init__.py b/server/osa/domain/auth/model/__init__.py index e59c5bf..5850c4f 100644 --- a/server/osa/domain/auth/model/__init__.py +++ b/server/osa/domain/auth/model/__init__.py @@ -1,16 +1,22 @@ """Auth domain models.""" -from .identity import Identity +from .identity import Anonymous, Identity, System +from .linked_account import LinkedAccount +from .principal import Principal from .token import RefreshToken from .user import User from .value import IdentityId, OrcidId, RefreshTokenId, TokenFamilyId, UserId __all__ = [ + "Anonymous", "Identity", "IdentityId", + "LinkedAccount", "OrcidId", + "Principal", "RefreshToken", "RefreshTokenId", + "System", "TokenFamilyId", "User", "UserId", diff --git a/server/osa/domain/auth/model/identity.py b/server/osa/domain/auth/model/identity.py index f86090b..66e0dad 100644 --- a/server/osa/domain/auth/model/identity.py +++ b/server/osa/domain/auth/model/identity.py @@ -1,46 +1,24 @@ -"""Identity entity for the auth domain.""" - -from datetime import UTC, datetime -from typing import Any - -from osa.domain.auth.model.value import IdentityId, UserId -from osa.domain.shared.model.entity import Entity - - -class Identity(Entity): - """A link between a User and an external identity provider. - - Examples: - - ORCiD: provider="orcid", external_id="0000-0001-2345-6789" - - SAML: provider="saml:university.edu", external_id="jdoe@university.edu" - - Invariants: - - `(provider, external_id)` is globally unique - - `user_id` is immutable after creation - - `provider` and `external_id` are immutable after creation - """ - - id: IdentityId - user_id: UserId - provider: str - external_id: str - metadata: dict[str, Any] | None = None # Provider-specific data (name, email) - created_at: datetime - - @classmethod - def create( - cls, - user_id: UserId, - provider: str, - external_id: str, - metadata: dict[str, Any] | None = None, - ) -> "Identity": - """Create a new identity link.""" - return cls( - id=IdentityId.generate(), - user_id=user_id, - provider=provider, - external_id=external_id, - metadata=metadata, - created_at=datetime.now(UTC), - ) +"""Identity hierarchy — base types for all request identities.""" + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Identity: + """Base for all request identities.""" + + pass + + +@dataclass(frozen=True) +class Anonymous(Identity): + """Unauthenticated request.""" + + pass + + +@dataclass(frozen=True) +class System(Identity): + """Internal worker/background process. Bypasses resource checks.""" + + pass diff --git a/server/osa/domain/auth/model/linked_account.py b/server/osa/domain/auth/model/linked_account.py new file mode 100644 index 0000000..a8e6140 --- /dev/null +++ b/server/osa/domain/auth/model/linked_account.py @@ -0,0 +1,49 @@ +"""LinkedAccount entity for the auth domain. + +Links a User to an external identity provider (e.g. ORCiD, SAML). +""" + +from datetime import UTC, datetime +from typing import Any + +from osa.domain.auth.model.value import IdentityId, UserId +from osa.domain.shared.model.entity import Entity + + +class LinkedAccount(Entity): + """A link between a User and an external identity provider. + + Examples: + - ORCiD: provider="orcid", external_id="0000-0001-2345-6789" + - SAML: provider="saml:university.edu", external_id="jdoe@university.edu" + + Invariants: + - `(provider, external_id)` is globally unique + - `user_id` is immutable after creation + - `provider` and `external_id` are immutable after creation + """ + + id: IdentityId + user_id: UserId + provider: str + external_id: str + metadata: dict[str, Any] | None = None # Provider-specific data (name, email) + created_at: datetime + + @classmethod + def create( + cls, + user_id: UserId, + provider: str, + external_id: str, + metadata: dict[str, Any] | None = None, + ) -> "LinkedAccount": + """Create a new identity link.""" + return cls( + id=IdentityId.generate(), + user_id=user_id, + provider=provider, + external_id=external_id, + metadata=metadata, + created_at=datetime.now(UTC), + ) diff --git a/server/osa/domain/auth/model/principal.py b/server/osa/domain/auth/model/principal.py index c2bac92..8767bd9 100644 --- a/server/osa/domain/auth/model/principal.py +++ b/server/osa/domain/auth/model/principal.py @@ -2,19 +2,21 @@ from dataclasses import dataclass +from osa.domain.auth.model.identity import Identity from osa.domain.auth.model.role import Role from osa.domain.auth.model.value import ProviderIdentity, UserId @dataclass(frozen=True) -class Principal: +class Principal(Identity): """The authenticated identity of the current requester. Resolved per-request from JWT + role lookup. Immutable after creation. + Subclasses Identity so it can be used wherever Identity is expected. """ user_id: UserId - identity: ProviderIdentity + provider_identity: ProviderIdentity roles: frozenset[Role] def has_role(self, role: Role) -> bool: diff --git a/server/osa/domain/auth/port/__init__.py b/server/osa/domain/auth/port/__init__.py index 9f67041..e51c51c 100644 --- a/server/osa/domain/auth/port/__init__.py +++ b/server/osa/domain/auth/port/__init__.py @@ -1,12 +1,12 @@ """Auth domain ports.""" from .identity_provider import IdentityInfo, IdentityProvider -from .repository import IdentityRepository, RefreshTokenRepository, UserRepository +from .repository import LinkedAccountRepository, RefreshTokenRepository, UserRepository __all__ = [ "IdentityInfo", "IdentityProvider", - "IdentityRepository", + "LinkedAccountRepository", "RefreshTokenRepository", "UserRepository", ] diff --git a/server/osa/domain/auth/port/repository.py b/server/osa/domain/auth/port/repository.py index a2ca4a0..2c1667e 100644 --- a/server/osa/domain/auth/port/repository.py +++ b/server/osa/domain/auth/port/repository.py @@ -3,7 +3,7 @@ from abc import abstractmethod from typing import Protocol -from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.linked_account import LinkedAccount from osa.domain.auth.model.token import RefreshToken from osa.domain.auth.model.user import User from osa.domain.auth.model.value import ( @@ -29,29 +29,29 @@ async def save(self, user: User) -> None: ... -class IdentityRepository(Port, Protocol): - """Repository for Identity entity persistence.""" +class LinkedAccountRepository(Port, Protocol): + """Repository for LinkedAccount entity persistence.""" @abstractmethod - async def get(self, identity_id: IdentityId) -> Identity | None: - """Get an identity by ID.""" + async def get(self, identity_id: IdentityId) -> LinkedAccount | None: + """Get a linked account by ID.""" ... @abstractmethod async def get_by_provider_and_external_id( self, provider: str, external_id: str - ) -> Identity | None: - """Get an identity by provider and external ID.""" + ) -> LinkedAccount | None: + """Get a linked account by provider and external ID.""" ... @abstractmethod - async def get_by_user_id(self, user_id: UserId) -> list[Identity]: - """Get all identities for a user.""" + async def get_by_user_id(self, user_id: UserId) -> list[LinkedAccount]: + """Get all linked accounts for a user.""" ... @abstractmethod - async def save(self, identity: Identity) -> None: - """Save an identity.""" + async def save(self, linked_account: LinkedAccount) -> None: + """Save a linked account.""" ... diff --git a/server/osa/domain/auth/query/get_user_roles.py b/server/osa/domain/auth/query/get_user_roles.py index c51bd21..3b6c85d 100644 --- a/server/osa/domain/auth/query/get_user_roles.py +++ b/server/osa/domain/auth/query/get_user_roles.py @@ -9,7 +9,7 @@ from osa.domain.auth.model.role import Role from osa.domain.auth.model.value import UserId from osa.domain.auth.service.authorization import AuthorizationService -from osa.domain.shared.authorization.policy import requires_role +from osa.domain.shared.authorization.gate import at_least from osa.domain.shared.query import Query, QueryHandler from osa.domain.shared.query import Result as QueryResult @@ -33,8 +33,8 @@ class GetUserRolesResult(QueryResult): class GetUserRolesHandler(QueryHandler[GetUserRoles, GetUserRolesResult]): - __auth__ = requires_role(Role.SUPERADMIN) - _principal: Principal | None = None + __auth__ = at_least(Role.SUPERADMIN) + principal: Principal authorization_service: AuthorizationService async def run(self, cmd: GetUserRoles) -> GetUserRolesResult: diff --git a/server/osa/domain/auth/service/auth.py b/server/osa/domain/auth/service/auth.py index cf1c25b..75e6d9d 100644 --- a/server/osa/domain/auth/service/auth.py +++ b/server/osa/domain/auth/service/auth.py @@ -2,13 +2,13 @@ import logging -from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.linked_account import LinkedAccount from osa.domain.auth.model.token import RefreshToken from osa.domain.auth.model.user import User from osa.domain.auth.model.value import ProviderIdentity, TokenFamilyId, UserId from osa.domain.auth.port.identity_provider import IdentityInfo, IdentityProvider from osa.domain.auth.port.repository import ( - IdentityRepository, + LinkedAccountRepository, RefreshTokenRepository, UserRepository, ) @@ -29,7 +29,7 @@ class AuthService(Service): """ _user_repo: UserRepository - _identity_repo: IdentityRepository + _linked_account_repo: LinkedAccountRepository _refresh_token_repo: RefreshTokenRepository _token_service: TokenService _outbox: Outbox @@ -57,7 +57,7 @@ async def complete_oauth( provider: IdentityProvider, code: str, redirect_uri: str, - ) -> tuple[User, Identity, str, str]: + ) -> tuple[User, LinkedAccount, str, str]: """Complete OAuth flow and issue tokens. Args: @@ -66,25 +66,25 @@ async def complete_oauth( redirect_uri: Must match the one used in authorization Returns: - Tuple of (user, identity, access_token, refresh_token) + Tuple of (user, linked_account, access_token, refresh_token) """ # Exchange code for identity info identity_info = await provider.exchange_code(code, redirect_uri) - # Find or create user and identity - user, identity = await self._find_or_create_user(identity_info) + # Find or create user and linked account + user, linked_account = await self._find_or_create_user(identity_info) # Create tokens - access_token, refresh_token = await self._create_tokens(user, identity) + access_token, refresh_token = await self._create_tokens(user, linked_account) logger.info( "User authenticated: user_id=%s, provider=%s, external_id=%s", user.id, - identity.provider, - identity.external_id, + linked_account.provider, + linked_account.external_id, ) - return user, identity, access_token, refresh_token + return user, linked_account, access_token, refresh_token async def refresh_tokens( self, @@ -197,10 +197,10 @@ async def get_primary_identity(self, user_id: UserId) -> ProviderIdentity | None this could be extended to support multiple identities with a designated primary. """ - identities = await self._identity_repo.get_by_user_id(user_id) - if not identities: + accounts = await self._linked_account_repo.get_by_user_id(user_id) + if not accounts: return None - first = identities[0] + first = accounts[0] return ProviderIdentity(provider=first.provider, external_id=first.external_id) async def get_user_id_from_refresh_token(self, raw_token: str) -> UserId | None: @@ -216,32 +216,32 @@ async def get_user_id_from_refresh_token(self, raw_token: str) -> UserId | None: stored = await self._refresh_token_repo.get_by_token_hash(token_hash) return stored.user_id if stored else None - async def _find_or_create_user(self, identity_info: IdentityInfo) -> tuple[User, Identity]: + async def _find_or_create_user(self, identity_info: IdentityInfo) -> tuple[User, LinkedAccount]: """Find existing user by identity or create new one.""" - # Check if identity already exists - existing_identity = await self._identity_repo.get_by_provider_and_external_id( + # Check if linked account already exists + existing = await self._linked_account_repo.get_by_provider_and_external_id( identity_info.provider, identity_info.external_id ) - if existing_identity: + if existing: # User exists, return them - user = await self._user_repo.get(existing_identity.user_id) + user = await self._user_repo.get(existing.user_id) if user is None: - # Orphaned identity - shouldn't happen with CASCADE - raise RuntimeError(f"Identity exists without user: {existing_identity.id}") - return user, existing_identity + # Orphaned linked account - shouldn't happen with CASCADE + raise RuntimeError(f"LinkedAccount exists without user: {existing.id}") + return user, existing - # Create new user and identity + # Create new user and linked account user = User.create(display_name=identity_info.display_name) await self._user_repo.save(user) - identity = Identity.create( + linked_account = LinkedAccount.create( user_id=user.id, provider=identity_info.provider, external_id=identity_info.external_id, metadata=identity_info.raw_data, ) - await self._identity_repo.save(identity) + await self._linked_account_repo.save(linked_account) logger.info( "New user created: user_id=%s, provider=%s", @@ -249,9 +249,9 @@ async def _find_or_create_user(self, identity_info: IdentityInfo) -> tuple[User, identity_info.provider, ) - return user, identity + return user, linked_account - async def _create_tokens(self, user: User, identity: Identity) -> tuple[str, str]: + async def _create_tokens(self, user: User, linked_account: LinkedAccount) -> tuple[str, str]: """Create access and refresh tokens for a user.""" # Create refresh token raw_token, token_hash = self._token_service.create_refresh_token() @@ -265,8 +265,8 @@ async def _create_tokens(self, user: User, identity: Identity) -> tuple[str, str # Create access token provider_identity = ProviderIdentity( - provider=identity.provider, - external_id=identity.external_id, + provider=linked_account.provider, + external_id=linked_account.external_id, ) access_token = self._token_service.create_access_token( user_id=user.id, diff --git a/server/osa/domain/auth/util/di/provider.py b/server/osa/domain/auth/util/di/provider.py index 846d746..efa7bf1 100644 --- a/server/osa/domain/auth/util/di/provider.py +++ b/server/osa/domain/auth/util/di/provider.py @@ -16,19 +16,19 @@ ) from osa.domain.auth.command.revoke_role import RevokeRoleHandler from osa.domain.auth.command.token import LogoutHandler, RefreshTokensHandler -from osa.domain.auth.query.get_user_roles import GetUserRolesHandler +from osa.domain.auth.model.identity import Anonymous, Identity from osa.domain.auth.model.principal import Principal from osa.domain.auth.model.value import CurrentUser, ProviderIdentity, UserId from osa.domain.auth.port.repository import ( - IdentityRepository, + LinkedAccountRepository, RefreshTokenRepository, UserRepository, ) from osa.domain.auth.port.role_repository import RoleAssignmentRepository +from osa.domain.auth.query.get_user_roles import GetUserRolesHandler from osa.domain.auth.service.auth import AuthService from osa.domain.auth.service.authorization import AuthorizationService from osa.domain.auth.service.token import TokenService -from osa.domain.shared.authorization.policy_set import POLICY_SET, PolicySet from osa.domain.shared.outbox import Outbox from osa.util.di.base import Provider from osa.util.di.scope import Scope @@ -64,7 +64,7 @@ def get_token_service(self, config: Config) -> TokenService: def get_auth_service( self, user_repo: UserRepository, - identity_repo: IdentityRepository, + linked_account_repo: LinkedAccountRepository, refresh_token_repo: RefreshTokenRepository, token_service: TokenService, outbox: Outbox, @@ -72,7 +72,7 @@ def get_auth_service( """Provide AuthService.""" return AuthService( _user_repo=user_repo, - _identity_repo=identity_repo, + _linked_account_repo=linked_account_repo, _refresh_token_repo=refresh_token_repo, _token_service=token_service, _outbox=outbox, @@ -122,27 +122,26 @@ def get_current_user( ) from e @provide(scope=Scope.UOW) - async def get_principal( + async def get_identity( self, request: Request, token_service: TokenService, role_repo: RoleAssignmentRepository, - ) -> Principal | None: - """Resolve Principal from JWT + role lookup. + ) -> Identity: + """Resolve Identity from JWT + role lookup. - Returns None for anonymous requests (no JWT / invalid JWT). - This allows public endpoints to work without authentication. + Returns Anonymous for unauthenticated requests, Principal for authenticated. """ auth_header = request.headers.get("Authorization") if not auth_header or not auth_header.startswith("Bearer "): - return None + return Anonymous() token = auth_header[7:] # Remove "Bearer " prefix try: payload = token_service.validate_access_token(token) except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): - return None + return Anonymous() user_id = UserId(UUID(payload["sub"])) @@ -152,15 +151,18 @@ async def get_principal( return Principal( user_id=user_id, - identity=ProviderIdentity( + provider_identity=ProviderIdentity( provider=payload["provider"], external_id=payload["external_id"], ), roles=roles, ) - @provide(scope=Scope.APP) - def get_policy_set(self) -> PolicySet: - """Provide the global PolicySet singleton. Validates coverage at startup.""" - POLICY_SET.validate_coverage() - return POLICY_SET + @provide(scope=Scope.UOW) + def get_principal(self, identity: Identity) -> Principal: + """Extract Principal from Identity. Raises if not authenticated.""" + from osa.domain.shared.error import AuthorizationError + + if isinstance(identity, Principal): + return identity + raise AuthorizationError("Authentication required", code="missing_token") diff --git a/server/osa/domain/deposition/command/create.py b/server/osa/domain/deposition/command/create.py index 415d67c..c344a88 100644 --- a/server/osa/domain/deposition/command/create.py +++ b/server/osa/domain/deposition/command/create.py @@ -5,7 +5,7 @@ from osa.domain.auth.model.principal import Principal from osa.domain.auth.model.role import Role from osa.domain.deposition.service.deposition import DepositionService -from osa.domain.shared.authorization.policy import requires_role +from osa.domain.shared.authorization.gate import at_least from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.model.srn import DepositionSRN @@ -19,8 +19,8 @@ class DepositionCreated(Result): class CreateDepositionHandler(CommandHandler[CreateDeposition, DepositionCreated]): - __auth__ = requires_role(Role.DEPOSITOR) - _principal: Principal | None = None + __auth__ = at_least(Role.DEPOSITOR) + principal: Principal deposition_service: DepositionService async def run(self, cmd: CreateDeposition) -> DepositionCreated: diff --git a/server/osa/domain/deposition/command/delete_files.py b/server/osa/domain/deposition/command/delete_files.py index 87be419..1f4f0e4 100644 --- a/server/osa/domain/deposition/command/delete_files.py +++ b/server/osa/domain/deposition/command/delete_files.py @@ -3,7 +3,7 @@ from osa.domain.auth.model.principal import Principal from osa.domain.auth.model.role import Role from osa.domain.deposition.port import DepositionRepository, StoragePort -from osa.domain.shared.authorization.policy import requires_role +from osa.domain.shared.authorization.gate import at_least from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.model.srn import DepositionSRN @@ -17,8 +17,8 @@ class DepositionFilesDeleted(Result): class DeleteDepositionFilesHandler(CommandHandler[DeleteDepositionFiles, DepositionFilesDeleted]): - __auth__ = requires_role(Role.DEPOSITOR) - _principal: Principal | None = None + __auth__ = at_least(Role.DEPOSITOR) + principal: Principal repository: DepositionRepository storage: StoragePort diff --git a/server/osa/domain/deposition/command/submit.py b/server/osa/domain/deposition/command/submit.py index 8a9f9d6..88d32c4 100644 --- a/server/osa/domain/deposition/command/submit.py +++ b/server/osa/domain/deposition/command/submit.py @@ -5,7 +5,7 @@ from osa.domain.auth.model.role import Role from osa.domain.deposition.event.submitted import DepositionSubmittedEvent from osa.domain.deposition.service.deposition import DepositionService -from osa.domain.shared.authorization.policy import requires_role +from osa.domain.shared.authorization.gate import at_least from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.event import EventId from osa.domain.shared.model.srn import DepositionSRN @@ -21,8 +21,8 @@ class DepositionSubmitted(Result): class SubmitDepositionHandler(CommandHandler[SubmitDeposition, DepositionSubmitted]): - __auth__ = requires_role(Role.DEPOSITOR) - _principal: Principal | None = None + __auth__ = at_least(Role.DEPOSITOR) + principal: Principal deposition_service: DepositionService outbox: Outbox diff --git a/server/osa/domain/deposition/command/update.py b/server/osa/domain/deposition/command/update.py index 7316cf3..d0ee9b8 100644 --- a/server/osa/domain/deposition/command/update.py +++ b/server/osa/domain/deposition/command/update.py @@ -3,7 +3,7 @@ from osa.domain.auth.model.principal import Principal from osa.domain.auth.model.role import Role from osa.domain.deposition.service.deposition import DepositionService -from osa.domain.shared.authorization.policy import requires_role +from osa.domain.shared.authorization.gate import at_least from osa.domain.shared.command import Command, CommandHandler, Result @@ -14,8 +14,8 @@ class DepositionUpdated(Result): ... class UpdateDepositionHandler(CommandHandler[UpdateDeposition, DepositionUpdated]): - __auth__ = requires_role(Role.DEPOSITOR) - _principal: Principal | None = None + __auth__ = at_least(Role.DEPOSITOR) + principal: Principal deposition_service: DepositionService async def run(self, cmd: UpdateDeposition) -> DepositionUpdated: diff --git a/server/osa/domain/deposition/command/upload.py b/server/osa/domain/deposition/command/upload.py index be29b9e..5e358f8 100644 --- a/server/osa/domain/deposition/command/upload.py +++ b/server/osa/domain/deposition/command/upload.py @@ -4,7 +4,7 @@ from osa.domain.auth.model.principal import Principal from osa.domain.auth.model.role import Role -from osa.domain.shared.authorization.policy import requires_role +from osa.domain.shared.authorization.gate import at_least from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.model.srn import DepositionSRN @@ -20,8 +20,8 @@ class FileUploaded(Result): class UploadFileHandler(CommandHandler[UploadFile, FileUploaded]): - __auth__ = requires_role(Role.DEPOSITOR) - _principal: Principal | None = None + __auth__ = at_least(Role.DEPOSITOR) + principal: Principal async def run(self, cmd: UploadFile) -> FileUploaded: with logfire.span("UploadFile"): diff --git a/server/osa/domain/shared/authorization/action.py b/server/osa/domain/shared/authorization/action.py deleted file mode 100644 index 6b04841..0000000 --- a/server/osa/domain/shared/authorization/action.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Authorization actions — all operations subject to access control.""" - -from enum import StrEnum - - -class Action(StrEnum): - """Structured enum of all authorization-relevant operations.""" - - # Depositions - DEPOSITION_CREATE = "deposition:create" - DEPOSITION_READ = "deposition:read" - DEPOSITION_UPDATE = "deposition:update" - DEPOSITION_SUBMIT = "deposition:submit" - DEPOSITION_DELETE = "deposition:delete" - - # Curation - DEPOSITION_APPROVE = "deposition:approve" - DEPOSITION_REJECT = "deposition:reject" - - # Registry — Schemas - SCHEMA_READ = "schema:read" - SCHEMA_CREATE = "schema:create" - SCHEMA_UPDATE = "schema:update" - SCHEMA_DELETE = "schema:delete" - - # Registry — Traits - TRAIT_READ = "trait:read" - TRAIT_CREATE = "trait:create" - TRAIT_UPDATE = "trait:update" - TRAIT_DELETE = "trait:delete" - - # Registry — Conventions - CONVENTION_READ = "convention:read" - CONVENTION_CREATE = "convention:create" - CONVENTION_UPDATE = "convention:update" - CONVENTION_DELETE = "convention:delete" - - # Registry — Vocabularies - VOCABULARY_READ = "vocabulary:read" - VOCABULARY_CREATE = "vocabulary:create" - VOCABULARY_UPDATE = "vocabulary:update" - VOCABULARY_DELETE = "vocabulary:delete" - - # Records (read-only after publication) - RECORD_READ = "record:read" - - # Search - SEARCH_QUERY = "search:query" - - # Validation - VALIDATION_CREATE = "validation:create" - VALIDATION_READ = "validation:read" - - # Administration - ROLE_ASSIGN = "role:assign" - ROLE_REVOKE = "role:revoke" - ROLE_READ = "role:read" diff --git a/server/osa/domain/shared/authorization/authorized_repo.py b/server/osa/domain/shared/authorization/authorized_repo.py deleted file mode 100644 index 12a1025..0000000 --- a/server/osa/domain/shared/authorization/authorized_repo.py +++ /dev/null @@ -1,40 +0,0 @@ -"""AuthorizedRepo — wraps a raw repository, returns Guarded[T] from get().""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Generic, TypeVar - -from osa.domain.shared.authorization.guarded import Guarded -from osa.domain.shared.error import NotFoundError - -if TYPE_CHECKING: - from osa.domain.auth.model.principal import Principal - from osa.domain.shared.authorization.policy_set import PolicySet - -T = TypeVar("T") -ID = TypeVar("ID") - - -class AuthorizedRepo(Generic[T, ID]): - """Wraps a raw repository and returns Guarded[T] from get(). - - Used by services that need to enforce authorization on loaded resources. - Event handlers and background workers should use the raw repository directly. - """ - - def __init__( - self, - inner: Any, - principal: "Principal", - policy_set: "PolicySet", - ) -> None: - self._inner = inner - self._principal = principal - self._policy_set = policy_set - - async def get(self, id: ID) -> Guarded[T]: - """Load a resource and wrap it in Guarded[T].""" - resource = await self._inner.get(id) - if resource is None: - raise NotFoundError(f"Resource not found: {id}") - return Guarded(resource, self._principal, self._policy_set) diff --git a/server/osa/domain/shared/authorization/decorators.py b/server/osa/domain/shared/authorization/decorators.py new file mode 100644 index 0000000..03f2fa9 --- /dev/null +++ b/server/osa/domain/shared/authorization/decorators.py @@ -0,0 +1,46 @@ +"""Repository method decorators for resource-level authorization. + +@reads(check): After method returns, check the result (skip if None). +@writes(check): Before method runs, check the first resource arg. + +Both decorators access self._identity on the repo instance. +""" + +from collections.abc import Callable +from functools import wraps +from typing import Any + +from osa.domain.shared.authorization.resource import ResourceCheck + + +def reads(check: ResourceCheck) -> Callable: + """After method returns, evaluate the check on the result. + + If the result is None (not found), the check is skipped. + """ + + def decorator(fn: Callable) -> Callable: + @wraps(fn) + async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + result = await fn(self, *args, **kwargs) + if result is not None: + check.evaluate(self._identity, result) + return result + + return wrapper + + return decorator + + +def writes(check: ResourceCheck) -> Callable: + """Before method runs, evaluate the check on the first resource arg.""" + + def decorator(fn: Callable) -> Callable: + @wraps(fn) + async def wrapper(self: Any, resource: Any, *args: Any, **kwargs: Any) -> Any: + check.evaluate(self._identity, resource) + return await fn(self, resource, *args, **kwargs) + + return wrapper + + return decorator diff --git a/server/osa/domain/shared/authorization/gate.py b/server/osa/domain/shared/authorization/gate.py new file mode 100644 index 0000000..3c8c435 --- /dev/null +++ b/server/osa/domain/shared/authorization/gate.py @@ -0,0 +1,42 @@ +"""Handler-level authorization gates: public() and at_least(Role).""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from osa.domain.auth.model.role import Role + + +class Gate: + """Base for handler-level authorization gates. + + Every CommandHandler/QueryHandler must declare ``__auth__: ClassVar[Gate]``. + Subclasses define specific gate behaviors (public access, role checks, etc.). + """ + + +@dataclass(frozen=True) +class Public(Gate): + """No authentication required.""" + + +@dataclass(frozen=True) +class AtLeast(Gate): + """Gate that requires the principal to have at least the given role.""" + + role: "Role" + + +_PUBLIC = Public() + + +def public() -> Public: + """Mark a handler as publicly accessible (no auth required).""" + return _PUBLIC + + +def at_least(role: "Role") -> AtLeast: + """Mark a handler as requiring at least the given role.""" + return AtLeast(role=role) diff --git a/server/osa/domain/shared/authorization/guarded.py b/server/osa/domain/shared/authorization/guarded.py deleted file mode 100644 index 06f87d6..0000000 --- a/server/osa/domain/shared/authorization/guarded.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Guarded[T] — generic wrapper forcing explicit authorization check.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Generic, TypeVar - -from osa.domain.shared.authorization.action import Action - -if TYPE_CHECKING: - from osa.domain.auth.model.principal import Principal - from osa.domain.shared.authorization.policy_set import PolicySet - -T = TypeVar("T") - - -class Guarded(Generic[T]): - """Wraps a loaded domain resource, forcing an explicit authorization check. - - The ONLY way to access the inner resource is via `.check(action)`. - No attribute proxy — accessing attributes on Guarded raises AttributeError. - """ - - __slots__ = ("_resource", "_principal", "_policy_set") - - def __init__( - self, - resource: T, - principal: Principal, - policy_set: PolicySet, - ) -> None: - self._resource = resource - self._principal = principal - self._policy_set = policy_set - - def check(self, action: Action) -> T: - """Evaluate authorization and return the unwrapped resource. - - Raises AuthorizationError if access is denied. - """ - self._policy_set.guard(self._principal, action, self._resource) - return self._resource diff --git a/server/osa/domain/shared/authorization/policy.py b/server/osa/domain/shared/authorization/policy.py deleted file mode 100644 index 946e8cd..0000000 --- a/server/osa/domain/shared/authorization/policy.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Composable policy types for handler-level authorization gates.""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from osa.domain.auth.model.principal import Principal - from osa.domain.auth.model.role import Role - - -class Policy(ABC): - """Base class for composable authorization policies. - - Policies are evaluated at the handler level as a coarse pre-filter - (role check only, no resource loaded yet). - """ - - @abstractmethod - def evaluate(self, principal: "Principal") -> bool: - """Return True if principal satisfies this policy.""" - ... - - def __and__(self, other: Policy) -> AllOf: - return AllOf(policies=(self, other)) - - def __or__(self, other: Policy) -> AnyOf: - return AnyOf(policies=(self, other)) - - def __invert__(self) -> Not: - return Not(policy=self) - - -@dataclass(frozen=True) -class RequiresRole(Policy): - """Policy that checks principal has at least the given role (hierarchy).""" - - role: "Role" - - def evaluate(self, principal: "Principal") -> bool: - return principal.has_role(self.role) - - -@dataclass(frozen=True) -class AllOf(Policy): - """Policy that requires ALL sub-policies to pass.""" - - policies: tuple[Policy, ...] - - def evaluate(self, principal: "Principal") -> bool: - return all(p.evaluate(principal) for p in self.policies) - - -@dataclass(frozen=True) -class AnyOf(Policy): - """Policy that requires at least ONE sub-policy to pass.""" - - policies: tuple[Policy, ...] - - def evaluate(self, principal: "Principal") -> bool: - return any(p.evaluate(principal) for p in self.policies) - - -@dataclass(frozen=True) -class Not(Policy): - """Policy that inverts another policy.""" - - policy: Policy - - def evaluate(self, principal: "Principal") -> bool: - return not self.policy.evaluate(principal) - - -def requires_role(role: "Role") -> RequiresRole: - """Factory: policy requiring at least the given role.""" - return RequiresRole(role=role) - - -def requires_any_role(*roles: "Role") -> AnyOf: - """Factory: policy requiring at least one of the given roles.""" - return AnyOf(policies=tuple(RequiresRole(role=r) for r in roles)) diff --git a/server/osa/domain/shared/authorization/policy_set.py b/server/osa/domain/shared/authorization/policy_set.py deleted file mode 100644 index bb6e7bc..0000000 --- a/server/osa/domain/shared/authorization/policy_set.py +++ /dev/null @@ -1,164 +0,0 @@ -"""PolicySet — declarative authorization rules and the Relationship enum. - -Contains PolicyRule, Relationship, allow() constructor, and the POLICY_SET constant. -This is the single source of truth for all "who can do what on which resource" rules. -""" - -from __future__ import annotations - -import logging -from dataclasses import dataclass -from enum import StrEnum -from typing import TYPE_CHECKING, Any - -if TYPE_CHECKING: - from osa.domain.auth.model.principal import Principal - -from osa.domain.shared.authorization.action import Action - -logger = logging.getLogger(__name__) - - -class Relationship(StrEnum): - """Relationships between a principal and a resource.""" - - OWNER = "owner" - - -@dataclass(frozen=True) -class PolicyRule: - """A single authorization rule in the policy set.""" - - action: Action - role: "Role | None" = None - relationship: Relationship | None = None - - -def allow( - action: Action, - *, - role: "Role | None" = None, - relationship: Relationship | None = None, -) -> PolicyRule: - """Convenience constructor for a policy rule.""" - return PolicyRule(action=action, role=role, relationship=relationship) - - -# Import Role here (after PolicyRule is defined) to avoid circular imports -from osa.domain.auth.model.role import Role # noqa: E402 - - -class PolicySet: - """Declarative set of all authorization rules. - - Evaluation: for a given action, rules are tried in order. - First match wins (allow). No match means deny. - """ - - def __init__(self, rules: list[PolicyRule]) -> None: - self._rules = rules - self._by_action: dict[Action, list[PolicyRule]] = {} - for rule in rules: - self._by_action.setdefault(rule.action, []).append(rule) - - def guard( - self, - principal: "Principal | None", - action: Action, - resource: Any = None, - ) -> None: - """Raise AuthorizationError if no rule allows this access.""" - from osa.domain.shared.error import AuthorizationError - - principal_id = str(principal.user_id) if principal else "anonymous" - - rules = self._by_action.get(action, []) - for rule in rules: - if self._matches(rule, principal, resource): - logger.info( - "Authorization allowed: principal=%s action=%s", - principal_id, - action, - ) - return - - logger.warning( - "Authorization denied: principal=%s action=%s", - principal_id, - action, - ) - raise AuthorizationError(f"Access denied: {action}", code="access_denied") - - def _matches( - self, - rule: PolicyRule, - principal: "Principal | None", - resource: Any, - ) -> bool: - # Public rule (no role required) - if rule.role is None: - return True - # Must be authenticated - if principal is None: - return False - # Role hierarchy check - if not principal.has_role(rule.role): - return False - # Relationship check (if required) - if rule.relationship == Relationship.OWNER: - owner_id = getattr(resource, "owner_id", None) - if owner_id is None or owner_id != principal.user_id: - return False - return True - - def validate_coverage(self) -> None: - """Startup check: every Action enum member must have at least one rule.""" - from osa.domain.shared.error import ConfigurationError - - covered = {r.action for r in self._rules} - missing = set(Action) - covered - if missing: - raise ConfigurationError(f"Actions without policy rules: {missing}") - - -POLICY_SET = PolicySet( - [ - # Public reads (no auth required) - allow(Action.RECORD_READ), - allow(Action.SEARCH_QUERY), - allow(Action.SCHEMA_READ), - allow(Action.TRAIT_READ), - allow(Action.CONVENTION_READ), - allow(Action.VOCABULARY_READ), - allow(Action.VALIDATION_READ), - # Depositions (ownership-scoped) - allow(Action.DEPOSITION_CREATE, role=Role.DEPOSITOR), - allow(Action.DEPOSITION_READ, role=Role.DEPOSITOR, relationship=Relationship.OWNER), - allow(Action.DEPOSITION_UPDATE, role=Role.DEPOSITOR, relationship=Relationship.OWNER), - allow(Action.DEPOSITION_SUBMIT, role=Role.DEPOSITOR, relationship=Relationship.OWNER), - allow(Action.DEPOSITION_DELETE, role=Role.DEPOSITOR, relationship=Relationship.OWNER), - # Curators can read all depositions (no ownership required) - allow(Action.DEPOSITION_READ, role=Role.CURATOR), - allow(Action.DEPOSITION_APPROVE, role=Role.CURATOR), - allow(Action.DEPOSITION_REJECT, role=Role.CURATOR), - # Registry (admin-only writes) - allow(Action.SCHEMA_CREATE, role=Role.ADMIN), - allow(Action.SCHEMA_UPDATE, role=Role.ADMIN), - allow(Action.SCHEMA_DELETE, role=Role.ADMIN), - allow(Action.TRAIT_CREATE, role=Role.ADMIN), - allow(Action.TRAIT_UPDATE, role=Role.ADMIN), - allow(Action.TRAIT_DELETE, role=Role.ADMIN), - allow(Action.CONVENTION_CREATE, role=Role.ADMIN), - allow(Action.CONVENTION_UPDATE, role=Role.ADMIN), - allow(Action.CONVENTION_DELETE, role=Role.ADMIN), - allow(Action.VOCABULARY_CREATE, role=Role.ADMIN), - allow(Action.VOCABULARY_UPDATE, role=Role.ADMIN), - allow(Action.VOCABULARY_DELETE, role=Role.ADMIN), - # Validation - allow(Action.VALIDATION_CREATE, role=Role.DEPOSITOR), - # Administration (superadmin-only) - allow(Action.ROLE_ASSIGN, role=Role.SUPERADMIN), - allow(Action.ROLE_REVOKE, role=Role.SUPERADMIN), - allow(Action.ROLE_READ, role=Role.SUPERADMIN), - ] -) diff --git a/server/osa/domain/shared/authorization/resource.py b/server/osa/domain/shared/authorization/resource.py new file mode 100644 index 0000000..44c867a --- /dev/null +++ b/server/osa/domain/shared/authorization/resource.py @@ -0,0 +1,111 @@ +"""Resource-level authorization checks for repo decorators.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from osa.domain.auth.model.role import Role + + +class ResourceCheck(ABC): + """Base class for resource-level authorization checks. + + System identities bypass all checks. Anonymous identities are rejected. + Principal identities are checked via the abstract _check method. + """ + + def evaluate(self, identity: Any, resource: Any) -> None: + """Evaluate the check against the given identity and resource. + + Raises AuthorizationError if access is denied. + """ + from osa.domain.auth.model.identity import System + from osa.domain.auth.model.principal import Principal + from osa.domain.shared.error import AuthorizationError + + if isinstance(identity, System): + return # Workers bypass all resource checks + + if not isinstance(identity, Principal): + raise AuthorizationError("Authentication required", code="missing_token") + + self._check(identity, resource) + + @abstractmethod + def _check(self, principal: Any, resource: Any) -> None: + """Check authorization for an authenticated principal. + + Args: + principal: The authenticated Principal + resource: The domain resource being accessed + + Raises: + AuthorizationError: If principal is not authorized for this resource. + """ + ... + + def __or__(self, other: ResourceCheck) -> AnyOf: + return AnyOf(checks=(self, other)) + + +@dataclass(frozen=True) +class OwnerCheck(ResourceCheck): + """Check that the principal owns the resource (resource.owner_id == principal.user_id).""" + + def _check(self, principal: Any, resource: Any) -> None: + from osa.domain.shared.error import AuthorizationError + + owner_id = getattr(resource, "owner_id", None) + if owner_id is None or owner_id != principal.user_id: + raise AuthorizationError("Access denied: not resource owner", code="access_denied") + + +@dataclass(frozen=True) +class HasRole(ResourceCheck): + """Check that the principal has at least the given role.""" + + role: "Role" + + def _check(self, principal: Any, resource: Any) -> None: + from osa.domain.shared.error import AuthorizationError + + if not principal.has_role(self.role): + raise AuthorizationError( + f"Access denied: requires role {self.role.name}", + code="access_denied", + ) + + +@dataclass(frozen=True) +class AnyOf(ResourceCheck): + """Check that at least one of the sub-checks passes.""" + + checks: tuple[ResourceCheck, ...] + + def _check(self, principal: Any, resource: Any) -> None: + from osa.domain.shared.error import AuthorizationError + + for check in self.checks: + try: + check._check(principal, resource) + return # At least one passed + except AuthorizationError: + continue + + raise AuthorizationError("Access denied", code="access_denied") + + def __or__(self, other: ResourceCheck) -> AnyOf: + return AnyOf(checks=(*self.checks, other)) + + +def owner() -> OwnerCheck: + """Check that the principal owns the resource.""" + return OwnerCheck() + + +def has_role(role: "Role") -> HasRole: + """Check that the principal has at least the given role.""" + return HasRole(role=role) diff --git a/server/osa/domain/shared/authorization/startup.py b/server/osa/domain/shared/authorization/startup.py index f179aff..ab414fc 100644 --- a/server/osa/domain/shared/authorization/startup.py +++ b/server/osa/domain/shared/authorization/startup.py @@ -2,6 +2,7 @@ import logging +from osa.domain.shared.authorization.gate import Gate from osa.domain.shared.command import CommandHandler from osa.domain.shared.error import ConfigurationError from osa.domain.shared.query import QueryHandler @@ -9,40 +10,14 @@ logger = logging.getLogger(__name__) -def _get_command_or_query_type(handler_cls: type) -> type | None: - """Extract the Command/Query type from a handler's generic bases.""" - from typing import get_args, get_origin - - for base in getattr(handler_cls, "__orig_bases__", []): - origin = get_origin(base) - if origin is None: - continue - name = getattr(origin, "__name__", "") - if name in ("CommandHandler", "QueryHandler"): - args = get_args(base) - if args and isinstance(args[0], type): - return args[0] - return None - - -def _check_handler_class(handler_cls: type, dto_cls: type | None = None) -> None: +def _check_handler_class(handler_cls: type) -> None: """Check a single handler class for __auth__ declaration. - Raises ConfigurationError if the handler lacks __auth__ and its DTO is not __public__. + Every handler must have __auth__ set to a Gate instance. """ - if dto_cls is None: - dto_cls = _get_command_or_query_type(handler_cls) - - # If DTO is public, no __auth__ needed - if dto_cls is not None and getattr(dto_cls, "__public__", False): - return - - # Check for __auth__ - if not hasattr(handler_cls, "__auth__") or getattr(handler_cls, "__auth__") is None: - raise ConfigurationError( - f"Handler {handler_cls.__name__} has no __auth__ declaration " - f"and its command/query is not __public__" - ) + auth = getattr(handler_cls, "__auth__", None) + if not isinstance(auth, Gate): + raise ConfigurationError(f"Handler {handler_cls.__name__} has no __auth__ declaration") def validate_all_handlers() -> None: @@ -53,16 +28,14 @@ def validate_all_handlers() -> None: violations: list[str] = [] for handler_cls in CommandHandler.__subclasses__(): - dto_cls = _get_command_or_query_type(handler_cls) try: - _check_handler_class(handler_cls, dto_cls) + _check_handler_class(handler_cls) except ConfigurationError as e: violations.append(str(e)) for handler_cls in QueryHandler.__subclasses__(): - dto_cls = _get_command_or_query_type(handler_cls) try: - _check_handler_class(handler_cls, dto_cls) + _check_handler_class(handler_cls) except ConfigurationError as e: violations.append(str(e)) diff --git a/server/osa/domain/shared/command.py b/server/osa/domain/shared/command.py index d3ed8e8..f39dee3 100644 --- a/server/osa/domain/shared/command.py +++ b/server/osa/domain/shared/command.py @@ -1,16 +1,20 @@ """Command and CommandHandler base classes with authorization gate.""" +from __future__ import annotations + from abc import ABCMeta, abstractmethod from collections.abc import Callable, Coroutine from dataclasses import dataclass from functools import wraps -from typing import Any, ClassVar, Generic, TypeVar, dataclass_transform +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, dataclass_transform from pydantic import BaseModel +if TYPE_CHECKING: + from osa.domain.shared.authorization.gate import Gate + -class Command(BaseModel): - __public__: ClassVar[bool] = False +class Command(BaseModel): ... class Result(BaseModel): ... @@ -23,55 +27,43 @@ class Result(BaseModel): ... _HandlerMethod = Callable[..., Coroutine[Any, Any, Any]] -def _get_command_type(cls: type) -> type[Command] | None: - """Extract the Command type C from CommandHandler[C, R] in class bases.""" - from typing import get_args, get_origin - - for base in getattr(cls, "__orig_bases__", []): - origin = get_origin(base) - if origin is not None and getattr(origin, "__name__", None) == "CommandHandler": - args = get_args(base) - if args and isinstance(args[0], type) and issubclass(args[0], Command): - return args[0] - return None - - def _wrap_run_with_auth(cls: type, original_run: _HandlerMethod) -> _HandlerMethod: - """Wrap the run() method with __auth__ policy evaluation.""" + """Wrap the run() method with __auth__ gate evaluation.""" @wraps(original_run) async def auth_wrapped_run(self: Any, cmd: Any) -> Any: - from osa.domain.shared.error import AuthorizationError + from osa.domain.shared.authorization.gate import AtLeast, Gate, Public + from osa.domain.shared.error import AuthorizationError, ConfigurationError - # Check if the command type is public - cmd_type = type(cmd) - if getattr(cmd_type, "__public__", False): + auth_gate = getattr(type(self), "__auth__", None) + + if not isinstance(auth_gate, Gate): + raise ConfigurationError(f"Handler {type(self).__name__} has no __auth__ declaration") + + if isinstance(auth_gate, Public): return await original_run(self, cmd) - # Non-public: check auth - auth_policy = getattr(type(self), "__auth__", None) - if auth_policy is None: - from osa.domain.shared.error import ConfigurationError + if isinstance(auth_gate, AtLeast): + from osa.domain.auth.model.principal import Principal - raise ConfigurationError( - f"Handler {type(self).__name__} has no __auth__ declaration " - f"and its command is not __public__" - ) + principal = getattr(self, "principal", None) + if not isinstance(principal, Principal): + raise AuthorizationError( + "Authentication required", + code="missing_token", + ) - principal = getattr(self, "_principal", None) - if principal is None: - raise AuthorizationError( - "Authentication required", - code="missing_token", - ) + if not principal.has_role(auth_gate.role): + raise AuthorizationError( + f"Access denied: insufficient role for {type(self).__name__}", + code="access_denied", + ) - if not auth_policy.evaluate(principal): - raise AuthorizationError( - f"Access denied: insufficient role for {type(self).__name__}", - code="access_denied", - ) + return await original_run(self, cmd) - return await original_run(self, cmd) + raise ConfigurationError( # pragma: no cover — future gate types handled here + f"Handler {type(self).__name__} has unhandled __auth__ type: {type(auth_gate).__name__}" + ) return auth_wrapped_run @@ -85,7 +77,7 @@ def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]): if any(isinstance(b, mcs) for b in bases): cls = dataclass(cls) - # Wrap run() with auth gate if __auth__ is declared or command is not public + # Wrap run() with auth gate original_run = cls.__dict__.get("run") if original_run is not None: wrapped = _wrap_run_with_auth(cls, original_run) @@ -99,9 +91,11 @@ class CommandHandler(Generic[C, R], metaclass=_CommandHandlerMeta): Declare __auth__ to enforce role-based access: class MyHandler(CommandHandler[MyCmd, MyResult]): - __auth__ = requires_role(Role.ADMIN) - _principal: Principal | None = None + __auth__ = at_least(Role.ADMIN) + principal: Principal """ + __auth__: ClassVar[Gate] + @abstractmethod async def run(self, cmd: C) -> R: ... diff --git a/server/osa/domain/shared/query.py b/server/osa/domain/shared/query.py index d4c800a..fbf8c90 100644 --- a/server/osa/domain/shared/query.py +++ b/server/osa/domain/shared/query.py @@ -1,16 +1,20 @@ """Query and QueryHandler base classes with authorization gate.""" +from __future__ import annotations + from abc import ABCMeta, abstractmethod from collections.abc import Callable, Coroutine from dataclasses import dataclass from functools import wraps -from typing import Any, ClassVar, Generic, TypeVar, dataclass_transform +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, dataclass_transform from pydantic import BaseModel +if TYPE_CHECKING: + from osa.domain.shared.authorization.gate import Gate + -class Query(BaseModel): - __public__: ClassVar[bool] = False +class Query(BaseModel): ... class Result(BaseModel): ... @@ -24,41 +28,42 @@ class Result(BaseModel): ... def _wrap_query_run_with_auth(cls: type, original_run: _HandlerMethod) -> _HandlerMethod: - """Wrap the run() method with __auth__ policy evaluation.""" + """Wrap the run() method with __auth__ gate evaluation.""" @wraps(original_run) async def auth_wrapped_run(self: Any, cmd: Any) -> Any: - from osa.domain.shared.error import AuthorizationError + from osa.domain.shared.authorization.gate import AtLeast, Gate, Public + from osa.domain.shared.error import AuthorizationError, ConfigurationError + + auth_gate = getattr(type(self), "__auth__", None) - # Check if the query type is public - cmd_type = type(cmd) - if getattr(cmd_type, "__public__", False): + if not isinstance(auth_gate, Gate): + raise ConfigurationError(f"Handler {type(self).__name__} has no __auth__ declaration") + + if isinstance(auth_gate, Public): return await original_run(self, cmd) - # Non-public: check auth - auth_policy = getattr(type(self), "__auth__", None) - if auth_policy is None: - from osa.domain.shared.error import ConfigurationError + if isinstance(auth_gate, AtLeast): + from osa.domain.auth.model.principal import Principal - raise ConfigurationError( - f"Handler {type(self).__name__} has no __auth__ declaration " - f"and its query is not __public__" - ) + principal = getattr(self, "principal", None) + if not isinstance(principal, Principal): + raise AuthorizationError( + "Authentication required", + code="missing_token", + ) - principal = getattr(self, "_principal", None) - if principal is None: - raise AuthorizationError( - "Authentication required", - code="missing_token", - ) + if not principal.has_role(auth_gate.role): + raise AuthorizationError( + f"Access denied: insufficient role for {type(self).__name__}", + code="access_denied", + ) - if not auth_policy.evaluate(principal): - raise AuthorizationError( - f"Access denied: insufficient role for {type(self).__name__}", - code="access_denied", - ) + return await original_run(self, cmd) - return await original_run(self, cmd) + raise ConfigurationError( # pragma: no cover — future gate types handled here + f"Handler {type(self).__name__} has unhandled __auth__ type: {type(auth_gate).__name__}" + ) return auth_wrapped_run @@ -86,9 +91,11 @@ class QueryHandler(Generic[C, R], metaclass=_QueryHandlerMeta): Declare __auth__ to enforce role-based access: class MyHandler(QueryHandler[MyQuery, MyResult]): - __auth__ = requires_role(Role.ADMIN) - _principal: Principal | None = None + __auth__ = at_least(Role.ADMIN) + principal: Principal """ + __auth__: ClassVar[Gate] + @abstractmethod async def run(self, cmd: C) -> R: ... diff --git a/server/osa/infrastructure/auth/di.py b/server/osa/infrastructure/auth/di.py index 8b43724..1fadacd 100644 --- a/server/osa/infrastructure/auth/di.py +++ b/server/osa/infrastructure/auth/di.py @@ -7,7 +7,7 @@ from osa.domain.auth.port.identity_provider import IdentityProvider from osa.domain.auth.port.provider_registry import ProviderRegistry from osa.domain.auth.port.repository import ( - IdentityRepository, + LinkedAccountRepository, RefreshTokenRepository, UserRepository, ) @@ -16,7 +16,7 @@ from osa.infrastructure.auth.provider_registry import InMemoryProviderRegistry from osa.infrastructure.auth.role_repository import PostgresRoleAssignmentRepository from osa.infrastructure.persistence.repository.auth import ( - PostgresIdentityRepository, + PostgresLinkedAccountRepository, PostgresRefreshTokenRepository, PostgresUserRepository, ) @@ -41,10 +41,10 @@ class AuthInfraProvider(Provider): scope=Scope.UOW, provides=UserRepository, ) - identity_repo = provide( - PostgresIdentityRepository, + linked_account_repo = provide( + PostgresLinkedAccountRepository, scope=Scope.UOW, - provides=IdentityRepository, + provides=LinkedAccountRepository, ) refresh_token_repo = provide( PostgresRefreshTokenRepository, diff --git a/server/osa/infrastructure/event/worker.py b/server/osa/infrastructure/event/worker.py index fb5e758..858f06f 100644 --- a/server/osa/infrastructure/event/worker.py +++ b/server/osa/infrastructure/event/worker.py @@ -13,6 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from osa.application.event import ServerStarted +from osa.domain.auth.model.identity import Identity, System from osa.domain.shared.error import SkippedEvents from osa.domain.shared.event import ( EventHandler, @@ -183,8 +184,8 @@ async def _poll_once(self) -> bool: self._state.status = WorkerStatus.CLAIMING - # Claim and process within a UOW scope - async with self._container(scope=Scope.UOW) as scope: + # Claim and process within a UOW scope (System identity for workers) + async with self._container(scope=Scope.UOW, context={Identity: System()}) as scope: outbox = await scope.get(Outbox) session = await scope.get(AsyncSession) @@ -393,7 +394,7 @@ async def _emit_server_started(self) -> None: if self._container is None: return - async with self._container(scope=Scope.UOW) as scope: + async with self._container(scope=Scope.UOW, context={Identity: System()}) as scope: outbox = await scope.get(Outbox) await outbox.append(ServerStarted(id=EventId(uuid4()))) session = await scope.get(AsyncSession) @@ -440,7 +441,7 @@ async def _run_schedule(self, config: "ScheduleConfig") -> None: return try: - async with self._container(scope=Scope.UOW) as scope: + async with self._container(scope=Scope.UOW, context={Identity: System()}) as scope: schedule = await scope.get(config.schedule_type) await schedule.run(**config.params) session = await scope.get(AsyncSession) @@ -484,7 +485,9 @@ async def _run_stale_claim_cleanup(self) -> None: max_timeout = max(w.config.claim_timeout for w in self._workers) # Use a scoped outbox for cleanup - async with self._container(scope=Scope.UOW) as scope: + async with self._container( + scope=Scope.UOW, context={Identity: System()} + ) as scope: outbox = await scope.get(Outbox) session = await scope.get(AsyncSession) count = await outbox.reset_stale_claims(max_timeout) diff --git a/server/osa/infrastructure/persistence/repository/auth.py b/server/osa/infrastructure/persistence/repository/auth.py index ff46522..515e713 100644 --- a/server/osa/infrastructure/persistence/repository/auth.py +++ b/server/osa/infrastructure/persistence/repository/auth.py @@ -6,7 +6,7 @@ from sqlalchemy import insert, select, update from sqlalchemy.ext.asyncio import AsyncSession -from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.linked_account import LinkedAccount from osa.domain.auth.model.token import RefreshToken from osa.domain.auth.model.user import User from osa.domain.auth.model.value import ( @@ -16,7 +16,7 @@ UserId, ) from osa.domain.auth.port.repository import ( - IdentityRepository, + LinkedAccountRepository, RefreshTokenRepository, UserRepository, ) @@ -47,9 +47,9 @@ def _user_to_dict(user: User) -> dict: } -def _row_to_identity(row: dict) -> Identity: - """Convert a database row to an Identity model.""" - return Identity( +def _row_to_linked_account(row: dict) -> LinkedAccount: + """Convert a database row to a LinkedAccount model.""" + return LinkedAccount( id=IdentityId(UUID(row["id"])), user_id=UserId(UUID(row["user_id"])), provider=row["provider"], @@ -59,15 +59,15 @@ def _row_to_identity(row: dict) -> Identity: ) -def _identity_to_dict(identity: Identity) -> dict: - """Convert an Identity model to a database row dict.""" +def _linked_account_to_dict(account: LinkedAccount) -> dict: + """Convert a LinkedAccount model to a database row dict.""" return { - "id": str(identity.id), - "user_id": str(identity.user_id), - "provider": identity.provider, - "external_id": identity.external_id, - "metadata": identity.metadata, - "created_at": identity.created_at, + "id": str(account.id), + "user_id": str(account.user_id), + "provider": account.provider, + "external_id": account.external_id, + "metadata": account.metadata, + "created_at": account.created_at, } @@ -122,47 +122,47 @@ async def save(self, user: User) -> None: await self.session.flush() -class PostgresIdentityRepository(IdentityRepository): - """PostgreSQL implementation of IdentityRepository.""" +class PostgresLinkedAccountRepository(LinkedAccountRepository): + """PostgreSQL implementation of LinkedAccountRepository.""" def __init__(self, session: AsyncSession) -> None: self.session = session - async def get(self, identity_id: IdentityId) -> Identity | None: + async def get(self, identity_id: IdentityId) -> LinkedAccount | None: stmt = select(identities_table).where(identities_table.c.id == str(identity_id)) result = await self.session.execute(stmt) row = result.mappings().first() - return _row_to_identity(dict(row)) if row else None + return _row_to_linked_account(dict(row)) if row else None async def get_by_provider_and_external_id( self, provider: str, external_id: str - ) -> Identity | None: + ) -> LinkedAccount | None: stmt = select(identities_table).where( identities_table.c.provider == provider, identities_table.c.external_id == external_id, ) result = await self.session.execute(stmt) row = result.mappings().first() - return _row_to_identity(dict(row)) if row else None + return _row_to_linked_account(dict(row)) if row else None - async def get_by_user_id(self, user_id: UserId) -> list[Identity]: + async def get_by_user_id(self, user_id: UserId) -> list[LinkedAccount]: stmt = select(identities_table).where(identities_table.c.user_id == str(user_id)) result = await self.session.execute(stmt) rows = result.mappings().all() - return [_row_to_identity(dict(row)) for row in rows] + return [_row_to_linked_account(dict(row)) for row in rows] - async def save(self, identity: Identity) -> None: - identity_dict = _identity_to_dict(identity) - existing = await self.get(identity.id) + async def save(self, linked_account: LinkedAccount) -> None: + account_dict = _linked_account_to_dict(linked_account) + existing = await self.get(linked_account.id) if existing: stmt = ( update(identities_table) - .where(identities_table.c.id == str(identity.id)) - .values(**identity_dict) + .where(identities_table.c.id == str(linked_account.id)) + .values(**account_dict) ) else: - stmt = insert(identities_table).values(**identity_dict) + stmt = insert(identities_table).values(**account_dict) await self.session.execute(stmt) await self.session.flush() diff --git a/server/osa/infrastructure/persistence/repository/deposition.py b/server/osa/infrastructure/persistence/repository/deposition.py index 40c5ef5..5d585fa 100644 --- a/server/osa/infrastructure/persistence/repository/deposition.py +++ b/server/osa/infrastructure/persistence/repository/deposition.py @@ -1,8 +1,12 @@ from sqlalchemy import insert, select, update from sqlalchemy.ext.asyncio import AsyncSession +from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.role import Role from osa.domain.deposition.model.aggregate import Deposition from osa.domain.deposition.port.repository import DepositionRepository +from osa.domain.shared.authorization.decorators import reads, writes +from osa.domain.shared.authorization.resource import has_role, owner from osa.domain.shared.model.srn import DepositionSRN from osa.infrastructure.persistence.mappers.deposition import ( row_to_deposition, @@ -14,13 +18,25 @@ class PostgresDepositionRepository(DepositionRepository): """PostgreSQL implementation of DepositionRepository.""" - def __init__(self, session: AsyncSession) -> None: + def __init__(self, session: AsyncSession, identity: Identity) -> None: self.session = session + self._identity = identity + @reads(owner() | has_role(Role.CURATOR)) + async def get(self, srn: DepositionSRN) -> Deposition | None: + stmt = select(depositions_table).where(depositions_table.c.srn == str(srn)) + result = await self.session.execute(stmt) + row = result.mappings().first() + return row_to_deposition(dict(row)) if row else None + + @writes(owner()) async def save(self, deposition: Deposition) -> None: dep_dict = deposition_to_dict(deposition) - existing = await self.get(deposition.srn) + # Check if exists (bypass decorator for internal lookup) + stmt = select(depositions_table).where(depositions_table.c.srn == str(deposition.srn)) + result = await self.session.execute(stmt) + existing = result.mappings().first() if existing: stmt = ( @@ -33,9 +49,3 @@ async def save(self, deposition: Deposition) -> None: await self.session.execute(stmt) await self.session.flush() - - async def get(self, srn: DepositionSRN) -> Deposition | None: - stmt = select(depositions_table).where(depositions_table.c.srn == str(srn)) - result = await self.session.execute(stmt) - row = result.mappings().first() - return row_to_deposition(dict(row)) if row else None diff --git a/server/tests/unit/domain/auth/test_auth_service.py b/server/tests/unit/domain/auth/test_auth_service.py index bef7f98..28e6fd4 100644 --- a/server/tests/unit/domain/auth/test_auth_service.py +++ b/server/tests/unit/domain/auth/test_auth_service.py @@ -7,7 +7,7 @@ import pytest from osa.config import JwtConfig -from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.linked_account import LinkedAccount from osa.domain.auth.model.token import RefreshToken from osa.domain.auth.model.user import User from osa.domain.auth.model.value import IdentityId, RefreshTokenId, TokenFamilyId, UserId @@ -19,7 +19,7 @@ def make_auth_service( user_repo: AsyncMock | None = None, - identity_repo: AsyncMock | None = None, + linked_account_repo: AsyncMock | None = None, refresh_token_repo: AsyncMock | None = None, token_service: TokenService | None = None, outbox: AsyncMock | None = None, @@ -27,8 +27,8 @@ def make_auth_service( """Create an AuthService with mocked dependencies.""" if user_repo is None: user_repo = AsyncMock() - if identity_repo is None: - identity_repo = AsyncMock() + if linked_account_repo is None: + linked_account_repo = AsyncMock() if refresh_token_repo is None: refresh_token_repo = AsyncMock() if token_service is None: @@ -44,7 +44,7 @@ def make_auth_service( return AuthService( _user_repo=user_repo, - _identity_repo=identity_repo, + _linked_account_repo=linked_account_repo, _refresh_token_repo=refresh_token_repo, _token_service=token_service, _outbox=outbox, @@ -101,14 +101,16 @@ async def test_complete_oauth_creates_new_user(self): user_repo = AsyncMock() user_repo.get.return_value = None # No existing user - identity_repo = AsyncMock() - identity_repo.get_by_provider_and_external_id.return_value = None # No existing identity + linked_account_repo = AsyncMock() + linked_account_repo.get_by_provider_and_external_id.return_value = ( + None # No existing identity + ) refresh_token_repo = AsyncMock() service = make_auth_service( user_repo=user_repo, - identity_repo=identity_repo, + linked_account_repo=linked_account_repo, refresh_token_repo=refresh_token_repo, ) provider = make_identity_provider() @@ -119,9 +121,9 @@ async def test_complete_oauth_creates_new_user(self): redirect_uri="http://localhost/callback", ) - # Should create user and identity + # Should create user and linked account user_repo.save.assert_called_once() - identity_repo.save.assert_called_once() + linked_account_repo.save.assert_called_once() refresh_token_repo.save.assert_called_once() # Should return valid data @@ -140,7 +142,7 @@ async def test_complete_oauth_returns_existing_user(self): created_at=datetime.now(UTC), updated_at=None, ) - existing_identity = Identity( + existing_linked_account = LinkedAccount( id=IdentityId(uuid4()), user_id=existing_user.id, provider="orcid", @@ -152,14 +154,14 @@ async def test_complete_oauth_returns_existing_user(self): user_repo = AsyncMock() user_repo.get.return_value = existing_user - identity_repo = AsyncMock() - identity_repo.get_by_provider_and_external_id.return_value = existing_identity + linked_account_repo = AsyncMock() + linked_account_repo.get_by_provider_and_external_id.return_value = existing_linked_account refresh_token_repo = AsyncMock() service = make_auth_service( user_repo=user_repo, - identity_repo=identity_repo, + linked_account_repo=linked_account_repo, refresh_token_repo=refresh_token_repo, ) provider = make_identity_provider() @@ -170,13 +172,13 @@ async def test_complete_oauth_returns_existing_user(self): redirect_uri="http://localhost/callback", ) - # Should NOT create new user/identity + # Should NOT create new user/linked account user_repo.save.assert_not_called() - identity_repo.save.assert_not_called() + linked_account_repo.save.assert_not_called() # Should return existing user assert user.id == existing_user.id - assert identity.id == existing_identity.id + assert identity.id == existing_linked_account.id class TestAuthServiceRefreshTokens: @@ -191,7 +193,7 @@ async def test_refresh_tokens_issues_new_tokens(self): created_at=datetime.now(UTC), updated_at=None, ) - identity = Identity( + linked_account = LinkedAccount( id=IdentityId(uuid4()), user_id=user.id, provider="orcid", @@ -212,15 +214,15 @@ async def test_refresh_tokens_issues_new_tokens(self): user_repo = AsyncMock() user_repo.get.return_value = user - identity_repo = AsyncMock() - identity_repo.get_by_user_id.return_value = [identity] + linked_account_repo = AsyncMock() + linked_account_repo.get_by_user_id.return_value = [linked_account] refresh_token_repo = AsyncMock() refresh_token_repo.get_by_token_hash.return_value = old_token service = make_auth_service( user_repo=user_repo, - identity_repo=identity_repo, + linked_account_repo=linked_account_repo, refresh_token_repo=refresh_token_repo, ) @@ -245,7 +247,7 @@ async def test_refresh_tokens_revokes_old_token(self): created_at=datetime.now(UTC), updated_at=None, ) - identity = Identity( + linked_account = LinkedAccount( id=IdentityId(uuid4()), user_id=user.id, provider="orcid", @@ -266,15 +268,15 @@ async def test_refresh_tokens_revokes_old_token(self): user_repo = AsyncMock() user_repo.get.return_value = user - identity_repo = AsyncMock() - identity_repo.get_by_user_id.return_value = [identity] + linked_account_repo = AsyncMock() + linked_account_repo.get_by_user_id.return_value = [linked_account] refresh_token_repo = AsyncMock() refresh_token_repo.get_by_token_hash.return_value = old_token service = make_auth_service( user_repo=user_repo, - identity_repo=identity_repo, + linked_account_repo=linked_account_repo, refresh_token_repo=refresh_token_repo, ) diff --git a/server/tests/unit/domain/auth/test_command_handlers.py b/server/tests/unit/domain/auth/test_command_handlers.py index 9179e15..0dc60b0 100644 --- a/server/tests/unit/domain/auth/test_command_handlers.py +++ b/server/tests/unit/domain/auth/test_command_handlers.py @@ -20,7 +20,7 @@ RefreshTokensHandler, ) from osa.domain.auth.event import UserAuthenticated, UserLoggedOut -from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.linked_account import LinkedAccount from osa.domain.auth.model.user import User from osa.domain.auth.model.value import IdentityId, UserId from osa.domain.auth.service.token import TokenService @@ -128,7 +128,7 @@ async def test_run_emits_user_authenticated_event(self): created_at=datetime.now(UTC), updated_at=None, ) - identity = Identity( + linked_account = LinkedAccount( id=IdentityId(uuid4()), user_id=user.id, provider="orcid", @@ -140,7 +140,7 @@ async def test_run_emits_user_authenticated_event(self): auth_service = AsyncMock() auth_service.complete_oauth.return_value = ( user, - identity, + linked_account, "access-token", "refresh-token", ) @@ -181,7 +181,7 @@ async def test_run_returns_user_info_and_tokens(self): created_at=datetime.now(UTC), updated_at=None, ) - identity = Identity( + linked_account = LinkedAccount( id=IdentityId(uuid4()), user_id=user.id, provider="orcid", @@ -193,7 +193,7 @@ async def test_run_returns_user_info_and_tokens(self): auth_service = AsyncMock() auth_service.complete_oauth.return_value = ( user, - identity, + linked_account, "access-token", "refresh-token", ) diff --git a/server/tests/unit/domain/deposition/test_deposition_service_auth.py b/server/tests/unit/domain/deposition/test_deposition_service_auth.py deleted file mode 100644 index 355b751..0000000 --- a/server/tests/unit/domain/deposition/test_deposition_service_auth.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Tests for DepositionService authorization — T029. - -Tests that: -- Owner can read/update/submit their own deposition via Guarded[Deposition] -- Non-owner depositor is denied access to another's deposition -- owner_id is set from principal at creation time -""" - -import pytest - -from osa.domain.auth.model.principal import Principal -from osa.domain.auth.model.role import Role -from osa.domain.auth.model.value import ProviderIdentity, UserId -from osa.domain.deposition.model.aggregate import Deposition -from osa.domain.deposition.model.value import DepositionStatus -from osa.domain.shared.authorization.action import Action -from osa.domain.shared.authorization.guarded import Guarded -from osa.domain.shared.authorization.policy_set import POLICY_SET -from osa.domain.shared.error import AuthorizationError -from osa.domain.shared.model.srn import DepositionSRN - - -def _make_principal( - user_id: UserId | None = None, roles: frozenset[Role] | None = None -) -> Principal: - return Principal( - user_id=user_id or UserId.generate(), - identity=ProviderIdentity(provider="test", external_id="test-ext"), - roles=roles or frozenset({Role.DEPOSITOR}), - ) - - -def _make_deposition(owner_id: UserId) -> Deposition: - return Deposition( - srn=DepositionSRN.parse("urn:osa:localhost:dep:00000000-0000-0000-0000-000000000001"), - status=DepositionStatus.DRAFT, - metadata={}, - owner_id=owner_id, - ) - - -class TestDepositionOwnership: - """Guarded[Deposition] enforces ownership rules via POLICY_SET.""" - - def test_owner_can_read_own_deposition(self) -> None: - owner = _make_principal() - dep = _make_deposition(owner_id=owner.user_id) - guarded = Guarded(dep, owner, POLICY_SET) - - result = guarded.check(Action.DEPOSITION_READ) - assert result is dep - - def test_owner_can_update_own_deposition(self) -> None: - owner = _make_principal() - dep = _make_deposition(owner_id=owner.user_id) - guarded = Guarded(dep, owner, POLICY_SET) - - result = guarded.check(Action.DEPOSITION_UPDATE) - assert result is dep - - def test_owner_can_submit_own_deposition(self) -> None: - owner = _make_principal() - dep = _make_deposition(owner_id=owner.user_id) - guarded = Guarded(dep, owner, POLICY_SET) - - result = guarded.check(Action.DEPOSITION_SUBMIT) - assert result is dep - - def test_non_owner_depositor_cannot_read_others_deposition(self) -> None: - owner = _make_principal() - other = _make_principal() - dep = _make_deposition(owner_id=owner.user_id) - guarded = Guarded(dep, other, POLICY_SET) - - with pytest.raises(AuthorizationError): - guarded.check(Action.DEPOSITION_READ) - - def test_non_owner_depositor_cannot_update_others_deposition(self) -> None: - owner = _make_principal() - other = _make_principal() - dep = _make_deposition(owner_id=owner.user_id) - guarded = Guarded(dep, other, POLICY_SET) - - with pytest.raises(AuthorizationError): - guarded.check(Action.DEPOSITION_UPDATE) - - def test_non_owner_depositor_cannot_submit_others_deposition(self) -> None: - owner = _make_principal() - other = _make_principal() - dep = _make_deposition(owner_id=owner.user_id) - guarded = Guarded(dep, other, POLICY_SET) - - with pytest.raises(AuthorizationError): - guarded.check(Action.DEPOSITION_SUBMIT) - - def test_curator_can_read_any_deposition(self) -> None: - """Curators can read all depositions without ownership.""" - owner = _make_principal() - curator = _make_principal(roles=frozenset({Role.CURATOR})) - dep = _make_deposition(owner_id=owner.user_id) - guarded = Guarded(dep, curator, POLICY_SET) - - result = guarded.check(Action.DEPOSITION_READ) - assert result is dep - - def test_admin_can_read_any_deposition(self) -> None: - """Admins inherit curator permissions via role hierarchy.""" - owner = _make_principal() - admin = _make_principal(roles=frozenset({Role.ADMIN})) - dep = _make_deposition(owner_id=owner.user_id) - guarded = Guarded(dep, admin, POLICY_SET) - - result = guarded.check(Action.DEPOSITION_READ) - assert result is dep - - -class TestDepositionOwnerIdAssignment: - """Deposition aggregate tracks owner_id.""" - - def test_deposition_has_owner_id_field(self) -> None: - owner_id = UserId.generate() - dep = _make_deposition(owner_id=owner_id) - assert dep.owner_id == owner_id - - def test_deposition_owner_id_defaults_to_none(self) -> None: - dep = Deposition( - srn=DepositionSRN.parse("urn:osa:localhost:dep:00000000-0000-0000-0000-000000000001"), - status=DepositionStatus.DRAFT, - metadata={}, - ) - assert dep.owner_id is None diff --git a/server/tests/unit/domain/shared/authorization/test_auth_gate.py b/server/tests/unit/domain/shared/authorization/test_auth_gate.py index cdd1763..6f33f15 100644 --- a/server/tests/unit/domain/shared/authorization/test_auth_gate.py +++ b/server/tests/unit/domain/shared/authorization/test_auth_gate.py @@ -1,11 +1,11 @@ -"""Tests for handler __auth__ gate: T013 — metaclass wraps run() with auth check.""" +"""Tests for handler __auth__ gate: metaclass wraps run() with auth check.""" import pytest from osa.domain.auth.model.principal import Principal from osa.domain.auth.model.role import Role from osa.domain.auth.model.value import ProviderIdentity, UserId -from osa.domain.shared.authorization.policy import requires_role +from osa.domain.shared.authorization.gate import at_least, public from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.error import AuthorizationError, ConfigurationError from osa.domain.shared.query import Query, QueryHandler @@ -15,7 +15,7 @@ def _make_principal(roles: frozenset[Role]) -> Principal: return Principal( user_id=UserId.generate(), - identity=ProviderIdentity(provider="test", external_id="test-ext"), + provider_identity=ProviderIdentity(provider="test", external_id="test-ext"), roles=roles, ) @@ -32,7 +32,6 @@ class AdminOnlyResult(Result): class PublicCommand(Command): - __public__: bool = True # ClassVar-like, signals public access value: str = "test" @@ -44,15 +43,15 @@ class PublicResult(Result): class AdminOnlyHandler(CommandHandler[AdminOnlyCommand, AdminOnlyResult]): - __auth__ = requires_role(Role.ADMIN) - _principal: Principal | None = None + __auth__ = at_least(Role.ADMIN) + principal: Principal async def run(self, cmd: AdminOnlyCommand) -> AdminOnlyResult: return AdminOnlyResult(value=cmd.value) class PublicHandler(CommandHandler[PublicCommand, PublicResult]): - _principal: Principal | None = None + __auth__ = public() async def run(self, cmd: PublicCommand) -> PublicResult: return PublicResult(value=cmd.value) @@ -91,7 +90,7 @@ class TestAuthGateOnCommandHandler: @pytest.mark.asyncio async def test_admin_handler_rejects_depositor(self) -> None: depositor = _make_principal(frozenset({Role.DEPOSITOR})) - handler = AdminOnlyHandler(_principal=depositor) + handler = AdminOnlyHandler(principal=depositor) with pytest.raises(AuthorizationError): await handler.run(AdminOnlyCommand(value="test")) @@ -99,29 +98,38 @@ async def test_admin_handler_rejects_depositor(self) -> None: @pytest.mark.asyncio async def test_admin_handler_allows_admin(self) -> None: admin = _make_principal(frozenset({Role.ADMIN})) - handler = AdminOnlyHandler(_principal=admin) + handler = AdminOnlyHandler(principal=admin) result = await handler.run(AdminOnlyCommand(value="hello")) assert result.value == "hello" @pytest.mark.asyncio - async def test_admin_handler_rejects_none_principal(self) -> None: - handler = AdminOnlyHandler(_principal=None) + async def test_admin_handler_allows_superadmin(self) -> None: + superadmin = _make_principal(frozenset({Role.SUPERADMIN})) + handler = AdminOnlyHandler(principal=superadmin) + + result = await handler.run(AdminOnlyCommand(value="hello")) + assert result.value == "hello" + + @pytest.mark.asyncio + async def test_admin_handler_rejects_missing_principal(self) -> None: + # Principal field not provided — should raise AuthorizationError + handler = AdminOnlyHandler.__new__(AdminOnlyHandler) with pytest.raises(AuthorizationError): await handler.run(AdminOnlyCommand(value="test")) @pytest.mark.asyncio async def test_public_handler_skips_check(self) -> None: - handler = PublicHandler(_principal=None) + handler = PublicHandler() result = await handler.run(PublicCommand(value="public")) assert result.value == "public" @pytest.mark.asyncio async def test_public_handler_works_with_principal(self) -> None: - depositor = _make_principal(frozenset({Role.DEPOSITOR})) - handler = PublicHandler(_principal=depositor) + # Public handlers work regardless of principal presence + handler = PublicHandler() result = await handler.run(PublicCommand(value="public")) assert result.value == "public" diff --git a/server/tests/unit/domain/shared/authorization/test_authorization_audit.py b/server/tests/unit/domain/shared/authorization/test_authorization_audit.py deleted file mode 100644 index afb00cd..0000000 --- a/server/tests/unit/domain/shared/authorization/test_authorization_audit.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Tests for authorization audit logging — T058. - -Tests that PolicySet.guard() emits structured log entries for allow and deny decisions. -""" - -import logging - -import pytest - -from osa.domain.auth.model.principal import Principal -from osa.domain.auth.model.role import Role -from osa.domain.auth.model.value import ProviderIdentity, UserId -from osa.domain.shared.authorization.action import Action -from osa.domain.shared.authorization.policy_set import POLICY_SET -from osa.domain.shared.error import AuthorizationError - - -def _make_principal(roles: frozenset[Role], user_id: UserId | None = None) -> Principal: - return Principal( - user_id=user_id or UserId.generate(), - identity=ProviderIdentity(provider="test", external_id="test-ext"), - roles=roles, - ) - - -class TestAuthorizationAuditLogging: - def test_guard_logs_allow(self, caplog: pytest.LogCaptureFixture) -> None: - """Successful authorization should emit an info-level log.""" - principal = _make_principal(frozenset({Role.ADMIN})) - - with caplog.at_level(logging.DEBUG, logger="osa.domain.shared.authorization.policy_set"): - POLICY_SET.guard(principal, Action.SCHEMA_CREATE) - - # Should have an allow log entry - allow_messages = [r for r in caplog.records if "allowed" in r.message.lower()] - assert len(allow_messages) >= 1 - record = allow_messages[0] - assert str(principal.user_id) in record.message - assert Action.SCHEMA_CREATE in record.message - - def test_guard_logs_deny(self, caplog: pytest.LogCaptureFixture) -> None: - """Failed authorization should emit a warning-level log.""" - depositor = _make_principal(frozenset({Role.DEPOSITOR})) - - with caplog.at_level(logging.DEBUG, logger="osa.domain.shared.authorization.policy_set"): - with pytest.raises(AuthorizationError): - POLICY_SET.guard(depositor, Action.SCHEMA_CREATE) - - # Should have a deny log entry - deny_messages = [r for r in caplog.records if "denied" in r.message.lower()] - assert len(deny_messages) >= 1 - record = deny_messages[0] - assert str(depositor.user_id) in record.message - assert Action.SCHEMA_CREATE in record.message - - def test_guard_logs_deny_for_anonymous(self, caplog: pytest.LogCaptureFixture) -> None: - """Authorization denial for anonymous users should also be logged.""" - with caplog.at_level(logging.DEBUG, logger="osa.domain.shared.authorization.policy_set"): - with pytest.raises(AuthorizationError): - POLICY_SET.guard(None, Action.DEPOSITION_CREATE) - - deny_messages = [r for r in caplog.records if "denied" in r.message.lower()] - assert len(deny_messages) >= 1 diff --git a/server/tests/unit/domain/shared/authorization/test_decorators.py b/server/tests/unit/domain/shared/authorization/test_decorators.py new file mode 100644 index 0000000..dd4e9a1 --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_decorators.py @@ -0,0 +1,147 @@ +"""Tests for @reads/@writes repo decorators.""" + +import pytest + +from osa.domain.auth.model.identity import Anonymous, System +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.authorization.decorators import reads, writes +from osa.domain.shared.authorization.resource import has_role, owner +from osa.domain.shared.error import AuthorizationError + + +def _make_principal( + roles: frozenset[Role], + user_id: UserId | None = None, +) -> Principal: + return Principal( + user_id=user_id or UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=roles, + ) + + +class _FakeResource: + def __init__(self, owner_id: UserId) -> None: + self.owner_id = owner_id + + +class _FakeRepo: + """Minimal repo with _identity and decorated methods.""" + + def __init__(self, identity, resource=None): + self._identity = identity + self._resource = resource + + @reads(owner() | has_role(Role.CURATOR)) + async def get(self, key: str): + return self._resource + + @writes(owner()) + async def save(self, resource) -> None: + self._resource = resource + + +class TestReadsDecorator: + @pytest.mark.asyncio + async def test_reads_allows_owner(self) -> None: + user_id = UserId.generate() + principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) + resource = _FakeResource(owner_id=user_id) + repo = _FakeRepo(identity=principal, resource=resource) + + result = await repo.get("key") + assert result is resource + + @pytest.mark.asyncio + async def test_reads_allows_curator(self) -> None: + curator = _make_principal(frozenset({Role.CURATOR})) + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=curator, resource=resource) + + result = await repo.get("key") + assert result is resource + + @pytest.mark.asyncio + async def test_reads_denies_non_owner_depositor(self) -> None: + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=depositor, resource=resource) + + with pytest.raises(AuthorizationError): + await repo.get("key") + + @pytest.mark.asyncio + async def test_reads_skips_check_when_none(self) -> None: + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + repo = _FakeRepo(identity=depositor, resource=None) + + result = await repo.get("key") + assert result is None + + @pytest.mark.asyncio + async def test_reads_allows_system(self) -> None: + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=System(), resource=resource) + + result = await repo.get("key") + assert result is resource + + @pytest.mark.asyncio + async def test_reads_denies_anonymous(self) -> None: + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=Anonymous(), resource=resource) + + with pytest.raises(AuthorizationError, match="Authentication required"): + await repo.get("key") + + +class TestWritesDecorator: + @pytest.mark.asyncio + async def test_writes_allows_owner(self) -> None: + user_id = UserId.generate() + principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) + resource = _FakeResource(owner_id=user_id) + repo = _FakeRepo(identity=principal) + + await repo.save(resource) + assert repo._resource is resource + + @pytest.mark.asyncio + async def test_writes_denies_non_owner(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=principal) + + with pytest.raises(AuthorizationError): + await repo.save(resource) + + @pytest.mark.asyncio + async def test_writes_checks_before_execution(self) -> None: + """Write check happens before the method body runs.""" + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=principal) + + with pytest.raises(AuthorizationError): + await repo.save(resource) + + # Method body never executed — _resource still None + assert repo._resource is None + + @pytest.mark.asyncio + async def test_writes_allows_system(self) -> None: + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=System()) + + await repo.save(resource) + assert repo._resource is resource + + @pytest.mark.asyncio + async def test_writes_denies_anonymous(self) -> None: + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=Anonymous()) + + with pytest.raises(AuthorizationError, match="Authentication required"): + await repo.save(resource) diff --git a/server/tests/unit/domain/shared/authorization/test_gate.py b/server/tests/unit/domain/shared/authorization/test_gate.py new file mode 100644 index 0000000..53db48a --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_gate.py @@ -0,0 +1,47 @@ +"""Tests for gate module: Gate base class, Public, AtLeast, factory functions.""" + +from osa.domain.auth.model.role import Role +from osa.domain.shared.authorization.gate import AtLeast, Gate, Public, at_least, public + + +class TestGateHierarchy: + def test_public_is_gate(self) -> None: + assert isinstance(Public(), Gate) + + def test_at_least_is_gate(self) -> None: + assert isinstance(AtLeast(role=Role.ADMIN), Gate) + + +class TestPublic: + def test_public_returns_public_instance(self) -> None: + assert isinstance(public(), Public) + + def test_public_always_returns_same_object(self) -> None: + assert public() is public() + + +class TestAtLeast: + def test_at_least_creates_dataclass(self) -> None: + gate = at_least(Role.ADMIN) + assert isinstance(gate, AtLeast) + assert gate.role is Role.ADMIN + + def test_at_least_different_roles(self) -> None: + depositor_gate = at_least(Role.DEPOSITOR) + admin_gate = at_least(Role.ADMIN) + assert depositor_gate.role is Role.DEPOSITOR + assert admin_gate.role is Role.ADMIN + + def test_at_least_is_frozen(self) -> None: + gate = at_least(Role.ADMIN) + assert hash(gate) is not None # frozen dataclass is hashable + + def test_at_least_equality(self) -> None: + gate1 = at_least(Role.ADMIN) + gate2 = at_least(Role.ADMIN) + assert gate1 == gate2 + + def test_at_least_inequality(self) -> None: + gate1 = at_least(Role.ADMIN) + gate2 = at_least(Role.DEPOSITOR) + assert gate1 != gate2 diff --git a/server/tests/unit/domain/shared/authorization/test_guarded.py b/server/tests/unit/domain/shared/authorization/test_guarded.py deleted file mode 100644 index 0e3cac1..0000000 --- a/server/tests/unit/domain/shared/authorization/test_guarded.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Tests for Guarded[T]: T009 — generic authorization wrapper.""" - -import pytest - -from osa.domain.auth.model.principal import Principal -from osa.domain.auth.model.role import Role -from osa.domain.auth.model.value import ProviderIdentity, UserId -from osa.domain.shared.authorization.action import Action -from osa.domain.shared.authorization.guarded import Guarded -from osa.domain.shared.authorization.policy_set import POLICY_SET -from osa.domain.shared.error import AuthorizationError - - -def _make_principal( - roles: frozenset[Role], - user_id: UserId | None = None, -) -> Principal: - uid = user_id or UserId.generate() - return Principal( - user_id=uid, - identity=ProviderIdentity(provider="test", external_id="test-ext"), - roles=roles, - ) - - -class _FakeDeposition: - def __init__(self, owner_id: UserId, status: str = "draft") -> None: - self.owner_id = owner_id - self.status = status - - -class TestGuardedCheck: - def test_check_returns_unwrapped_resource_on_success(self) -> None: - user_id = UserId.generate() - principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) - dep = _FakeDeposition(owner_id=user_id) - - guarded = Guarded(dep, principal, POLICY_SET) - result = guarded.check(Action.DEPOSITION_READ) - - assert result is dep - assert result.status == "draft" - - def test_check_raises_on_failure(self) -> None: - principal = _make_principal(frozenset({Role.DEPOSITOR})) - dep = _FakeDeposition(owner_id=UserId.generate()) # different owner - - guarded = Guarded(dep, principal, POLICY_SET) - - with pytest.raises(AuthorizationError): - guarded.check(Action.DEPOSITION_SUBMIT) - - -class TestGuardedNoProxy: - def test_no_attribute_access_proxy(self) -> None: - user_id = UserId.generate() - principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) - dep = _FakeDeposition(owner_id=user_id) - - guarded = Guarded(dep, principal, POLICY_SET) - - # Accessing .status on Guarded[Deposition] should raise AttributeError - with pytest.raises(AttributeError): - _ = guarded.status # type: ignore[attr-defined] diff --git a/server/tests/unit/domain/shared/authorization/test_identity.py b/server/tests/unit/domain/shared/authorization/test_identity.py new file mode 100644 index 0000000..1e8c326 --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_identity.py @@ -0,0 +1,40 @@ +"""Tests for Identity hierarchy: Anonymous, System, Principal subclassing.""" + +from osa.domain.auth.model.identity import Anonymous, Identity, System +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId + + +class TestIdentityHierarchy: + def test_anonymous_is_identity(self) -> None: + anon = Anonymous() + assert isinstance(anon, Identity) + + def test_system_is_identity(self) -> None: + system = System() + assert isinstance(system, Identity) + + def test_principal_is_identity(self) -> None: + principal = Principal( + user_id=UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=frozenset({Role.DEPOSITOR}), + ) + assert isinstance(principal, Identity) + + def test_anonymous_is_not_principal(self) -> None: + anon = Anonymous() + assert not isinstance(anon, Principal) + + def test_system_is_not_principal(self) -> None: + system = System() + assert not isinstance(system, Principal) + + def test_anonymous_is_frozen(self) -> None: + anon = Anonymous() + assert hash(anon) is not None # frozen dataclass is hashable + + def test_system_is_frozen(self) -> None: + system = System() + assert hash(system) is not None diff --git a/server/tests/unit/domain/shared/authorization/test_policy.py b/server/tests/unit/domain/shared/authorization/test_policy.py deleted file mode 100644 index f1f22ca..0000000 --- a/server/tests/unit/domain/shared/authorization/test_policy.py +++ /dev/null @@ -1,90 +0,0 @@ -"""Tests for Policy composition: T011 — composable handler-level policies.""" - -from osa.domain.auth.model.principal import Principal -from osa.domain.auth.model.role import Role -from osa.domain.auth.model.value import ProviderIdentity, UserId -from osa.domain.shared.authorization.policy import ( - AllOf, - RequiresRole, - requires_any_role, - requires_role, -) - - -def _make_principal(roles: frozenset[Role]) -> Principal: - return Principal( - user_id=UserId.generate(), - identity=ProviderIdentity(provider="test", external_id="test-ext"), - roles=roles, - ) - - -class TestRequiresRole: - def test_admin_policy_denies_depositor(self) -> None: - policy = requires_role(Role.ADMIN) - depositor = _make_principal(frozenset({Role.DEPOSITOR})) - - assert policy.evaluate(depositor) is False - - def test_admin_policy_allows_admin(self) -> None: - policy = requires_role(Role.ADMIN) - admin = _make_principal(frozenset({Role.ADMIN})) - - assert policy.evaluate(admin) is True - - def test_admin_policy_allows_superadmin(self) -> None: - policy = requires_role(Role.ADMIN) - superadmin = _make_principal(frozenset({Role.SUPERADMIN})) - - assert policy.evaluate(superadmin) is True - - -class TestPolicyComposition: - def test_or_operator(self) -> None: - policy = requires_role(Role.ADMIN) | requires_role(Role.CURATOR) - curator = _make_principal(frozenset({Role.CURATOR})) - - assert policy.evaluate(curator) is True - - def test_and_operator(self) -> None: - # Both must pass — depositor + curator would fail an AllOf(admin, curator) - policy = requires_role(Role.ADMIN) & requires_role(Role.CURATOR) - admin = _make_principal(frozenset({Role.ADMIN})) - - # Admin >= Curator, so both should pass via hierarchy - assert policy.evaluate(admin) is True - - def test_not_operator(self) -> None: - policy = ~requires_role(Role.ADMIN) - depositor = _make_principal(frozenset({Role.DEPOSITOR})) - - assert policy.evaluate(depositor) is True - - def test_not_inverts(self) -> None: - policy = ~requires_role(Role.ADMIN) - admin = _make_principal(frozenset({Role.ADMIN})) - - assert policy.evaluate(admin) is False - - -class TestAllOf: - def test_all_must_pass(self) -> None: - policy = AllOf(policies=(RequiresRole(Role.DEPOSITOR), RequiresRole(Role.CURATOR))) - depositor = _make_principal(frozenset({Role.DEPOSITOR})) - - # Depositor < Curator, so second check fails - assert policy.evaluate(depositor) is False - - -class TestRequiresAnyRole: - def test_any_role_works(self) -> None: - policy = requires_any_role(Role.ADMIN, Role.CURATOR) - curator = _make_principal(frozenset({Role.CURATOR})) - - assert policy.evaluate(curator) is True - - def test_any_role_fails_when_none_match(self) -> None: - policy = requires_any_role(Role.ADMIN, Role.SUPERADMIN) - depositor = _make_principal(frozenset({Role.DEPOSITOR})) - - assert policy.evaluate(depositor) is False diff --git a/server/tests/unit/domain/shared/authorization/test_policy_set.py b/server/tests/unit/domain/shared/authorization/test_policy_set.py deleted file mode 100644 index 238b937..0000000 --- a/server/tests/unit/domain/shared/authorization/test_policy_set.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Tests for PolicySet: T007 — declarative authorization rules.""" - -import pytest - -from osa.domain.auth.model.principal import Principal -from osa.domain.auth.model.role import Role -from osa.domain.auth.model.value import ProviderIdentity, UserId -from osa.domain.shared.authorization.action import Action -from osa.domain.shared.authorization.policy_set import POLICY_SET -from osa.domain.shared.error import AuthorizationError - - -def _make_principal( - roles: frozenset[Role], - user_id: UserId | None = None, -) -> Principal: - uid = user_id or UserId.generate() - return Principal( - user_id=uid, - identity=ProviderIdentity(provider="test", external_id="test-ext"), - roles=roles, - ) - - -class _FakeResource: - """Fake resource with owner_id for testing ownership checks.""" - - def __init__(self, owner_id: UserId) -> None: - self.owner_id = owner_id - - -class TestPolicySetOwnership: - def test_owner_can_submit_own_deposition(self) -> None: - user_id = UserId.generate() - principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) - resource = _FakeResource(owner_id=user_id) - - # Should not raise - POLICY_SET.guard(principal, Action.DEPOSITION_SUBMIT, resource) - - def test_non_owner_denied_submit(self) -> None: - principal = _make_principal(frozenset({Role.DEPOSITOR})) - resource = _FakeResource(owner_id=UserId.generate()) # different user - - with pytest.raises(AuthorizationError): - POLICY_SET.guard(principal, Action.DEPOSITION_SUBMIT, resource) - - def test_owner_can_read_own_deposition(self) -> None: - user_id = UserId.generate() - principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) - resource = _FakeResource(owner_id=user_id) - - POLICY_SET.guard(principal, Action.DEPOSITION_READ, resource) - - def test_non_owner_depositor_denied_read(self) -> None: - principal = _make_principal(frozenset({Role.DEPOSITOR})) - resource = _FakeResource(owner_id=UserId.generate()) - - with pytest.raises(AuthorizationError): - POLICY_SET.guard(principal, Action.DEPOSITION_READ, resource) - - -class TestPolicySetRoles: - def test_admin_reads_any_deposition(self) -> None: - principal = _make_principal(frozenset({Role.ADMIN})) - resource = _FakeResource(owner_id=UserId.generate()) - - # Admin >= Curator, so curator read rule (no ownership) should match - POLICY_SET.guard(principal, Action.DEPOSITION_READ, resource) - - def test_curator_reads_any_deposition(self) -> None: - principal = _make_principal(frozenset({Role.CURATOR})) - resource = _FakeResource(owner_id=UserId.generate()) - - POLICY_SET.guard(principal, Action.DEPOSITION_READ, resource) - - def test_depositor_cannot_create_schema(self) -> None: - principal = _make_principal(frozenset({Role.DEPOSITOR})) - - with pytest.raises(AuthorizationError): - POLICY_SET.guard(principal, Action.SCHEMA_CREATE) - - def test_admin_can_create_schema(self) -> None: - principal = _make_principal(frozenset({Role.ADMIN})) - - POLICY_SET.guard(principal, Action.SCHEMA_CREATE) - - -class TestPolicySetPublic: - def test_public_user_reads_records(self) -> None: - # No principal (anonymous) - POLICY_SET.guard(None, Action.RECORD_READ) - - def test_public_user_can_search(self) -> None: - POLICY_SET.guard(None, Action.SEARCH_QUERY) - - def test_public_user_reads_schemas(self) -> None: - POLICY_SET.guard(None, Action.SCHEMA_READ) - - -class TestPolicySetCoverage: - def test_validate_coverage_passes(self) -> None: - # Should not raise — all actions should be covered - POLICY_SET.validate_coverage() - - def test_validate_coverage_catches_missing(self) -> None: - from osa.domain.shared.authorization.policy_set import PolicySet, allow - from osa.domain.shared.error import ConfigurationError - - # Incomplete policy set - incomplete = PolicySet([allow(Action.RECORD_READ)]) - with pytest.raises(ConfigurationError): - incomplete.validate_coverage() diff --git a/server/tests/unit/domain/shared/authorization/test_resource_check.py b/server/tests/unit/domain/shared/authorization/test_resource_check.py new file mode 100644 index 0000000..ad1acf4 --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_resource_check.py @@ -0,0 +1,139 @@ +"""Tests for ResourceCheck: evaluate() with System/Anonymous/Principal, OwnerCheck, HasRole, AnyOf.""" + +import pytest + +from osa.domain.auth.model.identity import Anonymous, System +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.authorization.resource import ( + AnyOf, + has_role, + owner, +) +from osa.domain.shared.error import AuthorizationError + + +def _make_principal( + roles: frozenset[Role], + user_id: UserId | None = None, +) -> Principal: + return Principal( + user_id=user_id or UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=roles, + ) + + +class _FakeResource: + def __init__(self, owner_id: UserId) -> None: + self.owner_id = owner_id + + +class TestResourceCheckSystemBypass: + def test_system_bypasses_owner_check(self) -> None: + check = owner() + resource = _FakeResource(owner_id=UserId.generate()) + # Should not raise + check.evaluate(System(), resource) + + def test_system_bypasses_has_role_check(self) -> None: + check = has_role(Role.SUPERADMIN) + resource = _FakeResource(owner_id=UserId.generate()) + check.evaluate(System(), resource) + + def test_system_bypasses_any_of_check(self) -> None: + check = owner() | has_role(Role.ADMIN) + resource = _FakeResource(owner_id=UserId.generate()) + check.evaluate(System(), resource) + + +class TestResourceCheckAnonymousRejection: + def test_anonymous_rejected_by_owner_check(self) -> None: + check = owner() + resource = _FakeResource(owner_id=UserId.generate()) + with pytest.raises(AuthorizationError, match="Authentication required"): + check.evaluate(Anonymous(), resource) + + def test_anonymous_rejected_by_has_role_check(self) -> None: + check = has_role(Role.DEPOSITOR) + resource = _FakeResource(owner_id=UserId.generate()) + with pytest.raises(AuthorizationError, match="Authentication required"): + check.evaluate(Anonymous(), resource) + + +class TestOwnerCheck: + def test_owner_passes(self) -> None: + user_id = UserId.generate() + principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) + resource = _FakeResource(owner_id=user_id) + owner().evaluate(principal, resource) + + def test_non_owner_denied(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + with pytest.raises(AuthorizationError, match="not resource owner"): + owner().evaluate(principal, resource) + + def test_resource_without_owner_id_denied(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + + class NoOwner: + pass + + with pytest.raises(AuthorizationError, match="not resource owner"): + owner().evaluate(principal, NoOwner()) + + +class TestHasRole: + def test_principal_with_sufficient_role_passes(self) -> None: + principal = _make_principal(frozenset({Role.ADMIN})) + resource = _FakeResource(owner_id=UserId.generate()) + has_role(Role.CURATOR).evaluate(principal, resource) + + def test_principal_with_exact_role_passes(self) -> None: + principal = _make_principal(frozenset({Role.CURATOR})) + resource = _FakeResource(owner_id=UserId.generate()) + has_role(Role.CURATOR).evaluate(principal, resource) + + def test_principal_with_insufficient_role_denied(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + with pytest.raises(AuthorizationError, match="requires role CURATOR"): + has_role(Role.CURATOR).evaluate(principal, resource) + + +class TestAnyOf: + def test_passes_when_first_check_passes(self) -> None: + user_id = UserId.generate() + principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) + resource = _FakeResource(owner_id=user_id) + + check = owner() | has_role(Role.CURATOR) + check.evaluate(principal, resource) + + def test_passes_when_second_check_passes(self) -> None: + principal = _make_principal(frozenset({Role.CURATOR})) + resource = _FakeResource(owner_id=UserId.generate()) # not owner + + check = owner() | has_role(Role.CURATOR) + check.evaluate(principal, resource) + + def test_fails_when_no_check_passes(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) # not owner, not curator + + check = owner() | has_role(Role.CURATOR) + with pytest.raises(AuthorizationError, match="Access denied"): + check.evaluate(principal, resource) + + +class TestOrOperator: + def test_pipe_creates_any_of(self) -> None: + check = owner() | has_role(Role.CURATOR) + assert isinstance(check, AnyOf) + + def test_chained_pipe(self) -> None: + check = owner() | has_role(Role.CURATOR) | has_role(Role.ADMIN) + assert isinstance(check, AnyOf) + assert len(check.checks) == 3 diff --git a/server/tests/unit/domain/shared/authorization/test_role_hierarchy.py b/server/tests/unit/domain/shared/authorization/test_role_hierarchy.py index 117b695..4642f8b 100644 --- a/server/tests/unit/domain/shared/authorization/test_role_hierarchy.py +++ b/server/tests/unit/domain/shared/authorization/test_role_hierarchy.py @@ -29,7 +29,7 @@ class TestPrincipalHasRole: def test_has_role_uses_hierarchy(self) -> None: principal = Principal( user_id=UserId.generate(), - identity=ProviderIdentity(provider="test", external_id="ext"), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), roles=frozenset({Role.ADMIN}), ) @@ -41,7 +41,7 @@ def test_has_role_uses_hierarchy(self) -> None: def test_has_role_depositor(self) -> None: principal = Principal( user_id=UserId.generate(), - identity=ProviderIdentity(provider="test", external_id="ext"), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), roles=frozenset({Role.DEPOSITOR}), ) @@ -52,7 +52,7 @@ def test_has_role_depositor(self) -> None: def test_has_any_role(self) -> None: principal = Principal( user_id=UserId.generate(), - identity=ProviderIdentity(provider="test", external_id="ext"), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), roles=frozenset({Role.CURATOR}), ) diff --git a/server/tests/unit/domain/shared/authorization/test_startup_validation.py b/server/tests/unit/domain/shared/authorization/test_startup_validation.py index 1bd738f..3d3f8c9 100644 --- a/server/tests/unit/domain/shared/authorization/test_startup_validation.py +++ b/server/tests/unit/domain/shared/authorization/test_startup_validation.py @@ -1,33 +1,22 @@ -"""Tests for startup validation of handler __auth__ declarations — T036. +"""Tests for startup validation of handler __auth__ declarations. -Tests that all handlers either declare __auth__ or their command/query is __public__. +Tests that all handlers must declare __auth__ as public() or at_least(Role). """ import pytest from osa.domain.auth.model.principal import Principal -from osa.domain.shared.authorization.policy import requires_role from osa.domain.auth.model.role import Role +from osa.domain.shared.authorization.gate import at_least, public from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.error import ConfigurationError from osa.domain.shared.query import Query, QueryHandler from osa.domain.shared.query import Result as QueryResult -def validate_handlers() -> None: - """Scan all registered handler subclasses for __auth__ declarations. - - Raises ConfigurationError if any handler is missing __auth__ - and its command/query is not __public__. - """ - from osa.domain.shared.authorization.startup import validate_all_handlers - - validate_all_handlers() - - class TestStartupValidation: def test_validation_catches_missing_auth_on_command_handler(self) -> None: - """A CommandHandler without __auth__ on a non-public command should fail startup.""" + """A CommandHandler without __auth__ should fail startup.""" class UnprotectedCommand(Command): pass @@ -42,10 +31,10 @@ async def run(self, cmd: UnprotectedCommand) -> UnprotectedResult: from osa.domain.shared.authorization.startup import _check_handler_class with pytest.raises(ConfigurationError, match="UnprotectedHandler"): - _check_handler_class(UnprotectedHandler, UnprotectedCommand) + _check_handler_class(UnprotectedHandler) def test_validation_passes_for_protected_handler(self) -> None: - """A handler with __auth__ should pass validation.""" + """A handler with __auth__ = at_least(...) should pass validation.""" class ProtectedCommand(Command): pass @@ -54,8 +43,8 @@ class ProtectedResult(Result): pass class ProtectedHandler(CommandHandler[ProtectedCommand, ProtectedResult]): - __auth__ = requires_role(Role.ADMIN) - _principal: Principal | None = None + __auth__ = at_least(Role.ADMIN) + principal: Principal async def run(self, cmd: ProtectedCommand) -> ProtectedResult: return ProtectedResult() @@ -63,29 +52,30 @@ async def run(self, cmd: ProtectedCommand) -> ProtectedResult: from osa.domain.shared.authorization.startup import _check_handler_class # Should not raise - _check_handler_class(ProtectedHandler, ProtectedCommand) + _check_handler_class(ProtectedHandler) - def test_validation_passes_for_public_command(self) -> None: - """A handler for a __public__ command should pass even without __auth__.""" - from typing import ClassVar + def test_validation_passes_for_public_handler(self) -> None: + """A handler with __auth__ = public() should pass validation.""" class PublicCommand(Command): - __public__: ClassVar[bool] = True + pass class PublicResult(Result): pass class PublicHandler(CommandHandler[PublicCommand, PublicResult]): + __auth__ = public() + async def run(self, cmd: PublicCommand) -> PublicResult: return PublicResult() from osa.domain.shared.authorization.startup import _check_handler_class # Should not raise - _check_handler_class(PublicHandler, PublicCommand) + _check_handler_class(PublicHandler) def test_validation_catches_missing_auth_on_query_handler(self) -> None: - """A QueryHandler without __auth__ on a non-public query should fail.""" + """A QueryHandler without __auth__ should fail.""" class UnprotectedQuery(Query): pass @@ -100,4 +90,24 @@ async def run(self, cmd: UnprotectedQuery) -> UnprotectedQueryResult: from osa.domain.shared.authorization.startup import _check_handler_class with pytest.raises(ConfigurationError, match="UnprotectedQueryHandler"): - _check_handler_class(UnprotectedQueryHandler, UnprotectedQuery) + _check_handler_class(UnprotectedQueryHandler) + + def test_validation_catches_invalid_auth_type(self) -> None: + """A handler with a non-Gate __auth__ should fail validation.""" + + class BadCommand(Command): + pass + + class BadResult(Result): + pass + + class BadHandler(CommandHandler[BadCommand, BadResult]): + __auth__ = "not_a_valid_gate" + + async def run(self, cmd: BadCommand) -> BadResult: + return BadResult() + + from osa.domain.shared.authorization.startup import _check_handler_class + + with pytest.raises(ConfigurationError, match="no __auth__ declaration"): + _check_handler_class(BadHandler) From 150561ebc436d2dcd8bb6a91dcd808982709655c Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Sat, 7 Feb 2026 16:52:17 +0000 Subject: [PATCH 3/3] test: close authorization test coverage gaps (#55) Add 34 tests covering QueryHandler gates, error code pinning (401 vs 403), AuthProvider identity resolution, concrete handler auth configs, repo decorator edge cases, and multi-role principal behavior. --- .../unit/domain/auth/test_auth_provider.py | 142 ++++++++++++++++++ .../unit/domain/auth/test_handler_configs.py | 135 +++++++++++++++++ .../shared/authorization/test_auth_gate.py | 79 ++++++++++ .../shared/authorization/test_decorators.py | 37 +++++ .../shared/authorization/test_error_codes.py | 125 +++++++++++++++ .../authorization/test_role_hierarchy.py | 69 +++++++++ 6 files changed, 587 insertions(+) create mode 100644 server/tests/unit/domain/auth/test_auth_provider.py create mode 100644 server/tests/unit/domain/auth/test_handler_configs.py create mode 100644 server/tests/unit/domain/shared/authorization/test_error_codes.py diff --git a/server/tests/unit/domain/auth/test_auth_provider.py b/server/tests/unit/domain/auth/test_auth_provider.py new file mode 100644 index 0000000..53c1068 --- /dev/null +++ b/server/tests/unit/domain/auth/test_auth_provider.py @@ -0,0 +1,142 @@ +"""Tests for AuthProvider identity resolution (get_identity / get_principal).""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock + +import jwt as pyjwt +import pytest + +from osa.config import JwtConfig +from osa.domain.auth.model.identity import Anonymous +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.role_assignment import RoleAssignment, RoleAssignmentId +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.auth.service.token import TokenService +from osa.domain.auth.util.di.provider import AuthProvider +from osa.domain.shared.error import AuthorizationError + + +def _make_jwt_config() -> JwtConfig: + return JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=7, + ) + + +def _make_token_service(config: JwtConfig | None = None) -> TokenService: + return TokenService(_config=config or _make_jwt_config()) + + +def _make_request(auth_header: str | None = None) -> MagicMock: + request = MagicMock() + headers: dict[str, str] = {} + if auth_header is not None: + headers["Authorization"] = auth_header + request.headers = headers + return request + + +def _make_valid_token(token_service: TokenService, user_id: UserId) -> str: + return token_service.create_access_token( + user_id=user_id, + identity=ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789"), + ) + + +def _make_role_repo(assignments: list[RoleAssignment] | None = None) -> AsyncMock: + repo = AsyncMock() + repo.get_by_user_id.return_value = assignments or [] + return repo + + +class TestGetIdentity: + @pytest.mark.asyncio + async def test_valid_jwt_returns_principal_with_roles(self) -> None: + token_service = _make_token_service() + user_id = UserId.generate() + token = _make_valid_token(token_service, user_id) + request = _make_request(f"Bearer {token}") + + assignment = RoleAssignment( + id=RoleAssignmentId.generate(), + user_id=user_id, + role=Role.CURATOR, + assigned_by=UserId.generate(), + assigned_at=datetime.now(UTC), + ) + role_repo = _make_role_repo([assignment]) + + provider = AuthProvider() + identity = await provider.get_identity(request, token_service, role_repo) + + assert isinstance(identity, Principal) + assert identity.user_id == user_id + assert identity.roles == frozenset({Role.CURATOR}) + + @pytest.mark.asyncio + async def test_expired_jwt_returns_anonymous(self) -> None: + config = _make_jwt_config() + token_service = _make_token_service(config) + user_id = UserId.generate() + + # Create an expired token manually + payload = { + "sub": str(user_id), + "provider": "orcid", + "external_id": "0000-0001-2345-6789", + "exp": datetime(2020, 1, 1, tzinfo=UTC), + } + token = pyjwt.encode(payload, config.secret, algorithm=config.algorithm) + request = _make_request(f"Bearer {token}") + role_repo = _make_role_repo() + + provider = AuthProvider() + identity = await provider.get_identity(request, token_service, role_repo) + + assert isinstance(identity, Anonymous) + + @pytest.mark.asyncio + async def test_invalid_jwt_returns_anonymous(self) -> None: + token_service = _make_token_service() + request = _make_request("Bearer not-a-valid-jwt") + role_repo = _make_role_repo() + + provider = AuthProvider() + identity = await provider.get_identity(request, token_service, role_repo) + + assert isinstance(identity, Anonymous) + + @pytest.mark.asyncio + async def test_no_auth_header_returns_anonymous(self) -> None: + token_service = _make_token_service() + request = _make_request() + role_repo = _make_role_repo() + + provider = AuthProvider() + identity = await provider.get_identity(request, token_service, role_repo) + + assert isinstance(identity, Anonymous) + + +class TestGetPrincipal: + def test_get_principal_with_principal_returns_it(self) -> None: + principal = Principal( + user_id=UserId.generate(), + provider_identity=ProviderIdentity(provider="orcid", external_id="ext"), + roles=frozenset({Role.DEPOSITOR}), + ) + + provider = AuthProvider() + result = provider.get_principal(principal) + + assert result is principal + + def test_get_principal_with_anonymous_raises_missing_token(self) -> None: + provider = AuthProvider() + + with pytest.raises(AuthorizationError) as exc_info: + provider.get_principal(Anonymous()) + assert exc_info.value.code == "missing_token" diff --git a/server/tests/unit/domain/auth/test_handler_configs.py b/server/tests/unit/domain/auth/test_handler_configs.py new file mode 100644 index 0000000..57ce28d --- /dev/null +++ b/server/tests/unit/domain/auth/test_handler_configs.py @@ -0,0 +1,135 @@ +"""Tests for concrete handler auth configurations. + +Verifies that production handlers enforce their declared __auth__ gates +end-to-end (real handler classes, mocked services). +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from osa.domain.auth.command.assign_role import ( + AssignRole, + AssignRoleHandler, +) +from osa.domain.auth.command.login import ( + InitiateLogin, + InitiateLoginHandler, +) +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.auth.service.token import TokenService +from osa.domain.deposition.command.create import ( + CreateDeposition, + CreateDepositionHandler, +) +from osa.domain.shared.error import AuthorizationError + + +def _make_principal( + roles: frozenset[Role], + user_id: UserId | None = None, +) -> Principal: + return Principal( + user_id=user_id or UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=roles, + ) + + +class TestCreateDepositionHandlerAuth: + @pytest.mark.asyncio + async def test_create_deposition_allows_depositor(self) -> None: + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + service = AsyncMock() + handler = CreateDepositionHandler( + principal=depositor, + deposition_service=service, + ) + + result = await handler.run(CreateDeposition()) + assert result.srn is not None + + @pytest.mark.asyncio + async def test_create_deposition_rejects_unauthenticated(self) -> None: + handler = CreateDepositionHandler.__new__(CreateDepositionHandler) + + with pytest.raises(AuthorizationError) as exc_info: + await handler.run(CreateDeposition()) + assert exc_info.value.code == "missing_token" + + +class TestAssignRoleHandlerAuth: + @pytest.mark.asyncio + async def test_assign_role_allows_superadmin(self) -> None: + superadmin = _make_principal(frozenset({Role.SUPERADMIN})) + service = AsyncMock() + # Mock the return value to match what the handler expects + from datetime import UTC, datetime + + from osa.domain.auth.model.role_assignment import RoleAssignment, RoleAssignmentId + + target_user_id = UserId.generate() + service.assign_role.return_value = RoleAssignment( + id=RoleAssignmentId.generate(), + user_id=target_user_id, + role=Role.CURATOR, + assigned_by=superadmin.user_id, + assigned_at=datetime.now(UTC), + ) + + handler = AssignRoleHandler( + principal=superadmin, + authorization_service=service, + ) + + result = await handler.run(AssignRole(user_id=str(target_user_id), role="curator")) + assert result.role == "curator" + + @pytest.mark.asyncio + async def test_assign_role_rejects_admin(self) -> None: + admin = _make_principal(frozenset({Role.ADMIN})) + service = AsyncMock() + handler = AssignRoleHandler( + principal=admin, + authorization_service=service, + ) + + with pytest.raises(AuthorizationError) as exc_info: + await handler.run(AssignRole(user_id=str(UserId.generate()), role="curator")) + assert exc_info.value.code == "access_denied" + + +class TestInitiateLoginHandlerAuth: + @pytest.mark.asyncio + async def test_public_login_handler_works_without_principal(self) -> None: + provider_registry = MagicMock() + identity_provider = MagicMock() + identity_provider.get_authorization_url.return_value = "https://example.com/auth" + provider_registry.get.return_value = identity_provider + + from osa.config import JwtConfig + + token_service = TokenService( + _config=JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=7, + ) + ) + + handler = InitiateLoginHandler( + provider_registry=provider_registry, + token_service=token_service, + ) + + result = await handler.run( + InitiateLogin( + callback_url="http://localhost/callback", + final_redirect_uri="http://localhost/dashboard", + provider="orcid", + ) + ) + assert result.authorization_url == "https://example.com/auth" diff --git a/server/tests/unit/domain/shared/authorization/test_auth_gate.py b/server/tests/unit/domain/shared/authorization/test_auth_gate.py index 6f33f15..4f20d2a 100644 --- a/server/tests/unit/domain/shared/authorization/test_auth_gate.py +++ b/server/tests/unit/domain/shared/authorization/test_auth_gate.py @@ -147,3 +147,82 @@ async def test_unprotected_query_handler_raises_configuration_error(self) -> Non with pytest.raises(ConfigurationError, match="UnprotectedQueryHandler"): await handler.run(UnprotectedQuery(value="test")) + + +# --- Test query DTOs --- + + +class AdminOnlyQuery(Query): + value: str = "test" + + +class AdminOnlyQueryResult(QueryResult): + value: str + + +class PublicQuery(Query): + value: str = "test" + + +class PublicQueryResult(QueryResult): + value: str + + +# --- Test query handlers --- + + +class AdminOnlyQueryHandler(QueryHandler[AdminOnlyQuery, AdminOnlyQueryResult]): + __auth__ = at_least(Role.ADMIN) + principal: Principal + + async def run(self, cmd: AdminOnlyQuery) -> AdminOnlyQueryResult: + return AdminOnlyQueryResult(value=cmd.value) + + +class PublicQueryHandler(QueryHandler[PublicQuery, PublicQueryResult]): + __auth__ = public() + + async def run(self, cmd: PublicQuery) -> PublicQueryResult: + return PublicQueryResult(value=cmd.value) + + +class TestAuthGateOnQueryHandler: + @pytest.mark.asyncio + async def test_query_handler_rejects_insufficient_role(self) -> None: + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + handler = AdminOnlyQueryHandler(principal=depositor) + + with pytest.raises(AuthorizationError) as exc_info: + await handler.run(AdminOnlyQuery(value="test")) + assert exc_info.value.code == "access_denied" + + @pytest.mark.asyncio + async def test_query_handler_allows_matching_role(self) -> None: + admin = _make_principal(frozenset({Role.ADMIN})) + handler = AdminOnlyQueryHandler(principal=admin) + + result = await handler.run(AdminOnlyQuery(value="hello")) + assert result.value == "hello" + + @pytest.mark.asyncio + async def test_query_handler_allows_higher_role(self) -> None: + superadmin = _make_principal(frozenset({Role.SUPERADMIN})) + handler = AdminOnlyQueryHandler(principal=superadmin) + + result = await handler.run(AdminOnlyQuery(value="hello")) + assert result.value == "hello" + + @pytest.mark.asyncio + async def test_query_handler_rejects_missing_principal(self) -> None: + handler = AdminOnlyQueryHandler.__new__(AdminOnlyQueryHandler) + + with pytest.raises(AuthorizationError) as exc_info: + await handler.run(AdminOnlyQuery(value="test")) + assert exc_info.value.code == "missing_token" + + @pytest.mark.asyncio + async def test_public_query_handler_skips_check(self) -> None: + handler = PublicQueryHandler() + + result = await handler.run(PublicQuery(value="public")) + assert result.value == "public" diff --git a/server/tests/unit/domain/shared/authorization/test_decorators.py b/server/tests/unit/domain/shared/authorization/test_decorators.py index dd4e9a1..c39e220 100644 --- a/server/tests/unit/domain/shared/authorization/test_decorators.py +++ b/server/tests/unit/domain/shared/authorization/test_decorators.py @@ -96,6 +96,24 @@ async def test_reads_denies_anonymous(self) -> None: with pytest.raises(AuthorizationError, match="Authentication required"): await repo.get("key") + @pytest.mark.asyncio + async def test_reads_denies_anonymous_with_missing_token_code(self) -> None: + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=Anonymous(), resource=resource) + + with pytest.raises(AuthorizationError) as exc_info: + await repo.get("key") + assert exc_info.value.code == "missing_token" + + @pytest.mark.asyncio + async def test_reads_allows_admin_via_role_hierarchy(self) -> None: + admin = _make_principal(frozenset({Role.ADMIN})) + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=admin, resource=resource) + + result = await repo.get("key") + assert result is resource + class TestWritesDecorator: @pytest.mark.asyncio @@ -145,3 +163,22 @@ async def test_writes_denies_anonymous(self) -> None: with pytest.raises(AuthorizationError, match="Authentication required"): await repo.save(resource) + + @pytest.mark.asyncio + async def test_writes_denies_anonymous_with_missing_token_code(self) -> None: + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=Anonymous()) + + with pytest.raises(AuthorizationError) as exc_info: + await repo.save(resource) + assert exc_info.value.code == "missing_token" + + @pytest.mark.asyncio + async def test_writes_denies_non_owner_with_access_denied_code(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=principal) + + with pytest.raises(AuthorizationError) as exc_info: + await repo.save(resource) + assert exc_info.value.code == "access_denied" diff --git a/server/tests/unit/domain/shared/authorization/test_error_codes.py b/server/tests/unit/domain/shared/authorization/test_error_codes.py new file mode 100644 index 0000000..243bf89 --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_error_codes.py @@ -0,0 +1,125 @@ +"""Tests for authorization error codes: pin 401 vs 403 mapping.""" + +import pytest + +from osa.application.api.v1.errors import map_osa_error +from osa.domain.auth.model.identity import Anonymous +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.authorization.gate import at_least +from osa.domain.shared.authorization.resource import has_role, owner +from osa.domain.shared.command import Command, CommandHandler, Result +from osa.domain.shared.error import AuthorizationError + + +def _make_principal( + roles: frozenset[Role], + user_id: UserId | None = None, +) -> Principal: + return Principal( + user_id=user_id or UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=roles, + ) + + +# --- Inline handler for gate-level tests --- + + +class _GatedCommand(Command): + value: str = "test" + + +class _GatedResult(Result): + value: str + + +class _AdminGatedHandler(CommandHandler[_GatedCommand, _GatedResult]): + __auth__ = at_least(Role.ADMIN) + principal: Principal + + async def run(self, cmd: _GatedCommand) -> _GatedResult: + return _GatedResult(value=cmd.value) + + +# --- Handler gate error codes --- + + +class TestHandlerGateErrorCodes: + @pytest.mark.asyncio + async def test_missing_principal_has_missing_token_code(self) -> None: + handler = _AdminGatedHandler.__new__(_AdminGatedHandler) + + with pytest.raises(AuthorizationError) as exc_info: + await handler.run(_GatedCommand(value="test")) + assert exc_info.value.code == "missing_token" + + @pytest.mark.asyncio + async def test_insufficient_role_has_access_denied_code(self) -> None: + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + handler = _AdminGatedHandler(principal=depositor) + + with pytest.raises(AuthorizationError) as exc_info: + await handler.run(_GatedCommand(value="test")) + assert exc_info.value.code == "access_denied" + + +# --- Resource check error codes --- + + +class _FakeResource: + def __init__(self, owner_id: UserId) -> None: + self.owner_id = owner_id + + +class TestResourceCheckErrorCodes: + def test_anonymous_resource_check_has_missing_token_code(self) -> None: + check = has_role(Role.CURATOR) + resource = _FakeResource(owner_id=UserId.generate()) + + with pytest.raises(AuthorizationError) as exc_info: + check.evaluate(Anonymous(), resource) + assert exc_info.value.code == "missing_token" + + def test_owner_check_failure_has_access_denied_code(self) -> None: + check = owner() + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + + with pytest.raises(AuthorizationError) as exc_info: + check.evaluate(principal, resource) + assert exc_info.value.code == "access_denied" + + def test_has_role_check_failure_has_access_denied_code(self) -> None: + check = has_role(Role.ADMIN) + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + + with pytest.raises(AuthorizationError) as exc_info: + check.evaluate(principal, resource) + assert exc_info.value.code == "access_denied" + + def test_any_of_failure_has_access_denied_code(self) -> None: + check = owner() | has_role(Role.ADMIN) + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + + with pytest.raises(AuthorizationError) as exc_info: + check.evaluate(principal, resource) + assert exc_info.value.code == "access_denied" + + +# --- HTTP status code mapping --- + + +class TestErrorCodeToHttpMapping: + def test_map_osa_error_missing_token_returns_401(self) -> None: + error = AuthorizationError("Authentication required", code="missing_token") + http_exc = map_osa_error(error) + assert http_exc.status_code == 401 + + def test_map_osa_error_access_denied_returns_403(self) -> None: + error = AuthorizationError("Access denied", code="access_denied") + http_exc = map_osa_error(error) + assert http_exc.status_code == 403 diff --git a/server/tests/unit/domain/shared/authorization/test_role_hierarchy.py b/server/tests/unit/domain/shared/authorization/test_role_hierarchy.py index 4642f8b..f0e747c 100644 --- a/server/tests/unit/domain/shared/authorization/test_role_hierarchy.py +++ b/server/tests/unit/domain/shared/authorization/test_role_hierarchy.py @@ -1,8 +1,13 @@ """Tests for Role hierarchy: T012 — numeric hierarchy comparison.""" +import pytest + from osa.domain.auth.model.principal import Principal from osa.domain.auth.model.role import Role from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.authorization.gate import at_least +from osa.domain.shared.authorization.resource import has_role +from osa.domain.shared.command import Command, CommandHandler, Result class TestRoleHierarchy: @@ -58,3 +63,67 @@ def test_has_any_role(self) -> None: assert principal.has_any_role(Role.ADMIN, Role.CURATOR) is True assert principal.has_any_role(Role.SUPERADMIN) is False + + +def _make_principal( + roles: frozenset[Role], + user_id: UserId | None = None, +) -> Principal: + return Principal( + user_id=user_id or UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=roles, + ) + + +# --- Inline handler for gate test --- + + +class _MultiRoleCommand(Command): + value: str = "test" + + +class _MultiRoleResult(Result): + value: str + + +class _CuratorGatedHandler(CommandHandler[_MultiRoleCommand, _MultiRoleResult]): + __auth__ = at_least(Role.CURATOR) + principal: Principal + + async def run(self, cmd: _MultiRoleCommand) -> _MultiRoleResult: + return _MultiRoleResult(value=cmd.value) + + +class _FakeResource: + def __init__(self, owner_id: UserId) -> None: + self.owner_id = owner_id + + +class TestMultiRolePrincipal: + def test_multi_role_principal_has_role_uses_highest(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR, Role.CURATOR})) + + assert principal.has_role(Role.CURATOR) is True + assert principal.has_role(Role.DEPOSITOR) is True + + def test_multi_role_principal_fails_above_highest(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR, Role.CURATOR})) + + assert principal.has_role(Role.ADMIN) is False + + @pytest.mark.asyncio + async def test_multi_role_principal_at_handler_gate(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR, Role.CURATOR})) + handler = _CuratorGatedHandler(principal=principal) + + result = await handler.run(_MultiRoleCommand(value="ok")) + assert result.value == "ok" + + def test_multi_role_principal_at_resource_check(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR, Role.CURATOR})) + check = has_role(Role.CURATOR) + resource = _FakeResource(owner_id=UserId.generate()) + + # Should not raise — CURATOR satisfies has_role(CURATOR) + check.evaluate(principal, resource)