diff --git a/server/migrations/versions/add_authorization.py b/server/migrations/versions/add_authorization.py new file mode 100644 index 0000000..0a9977b --- /dev/null +++ b/server/migrations/versions/add_authorization.py @@ -0,0 +1,71 @@ +"""add_authorization + +Add role_assignments table and owner_id column to depositions. + +Revision ID: add_authorization +Revises: add_auth_tables +Create Date: 2026-02-06 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "add_authorization" +down_revision: Union[str, Sequence[str], None] = "add_auth_tables" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add authorization tables and columns.""" + # ROLE ASSIGNMENTS TABLE + op.create_table( + "role_assignments", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("role", sa.String(32), nullable=False), + sa.Column("assigned_by", sa.String(), nullable=False), + sa.Column("assigned_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["assigned_by"], + ["users.id"], + ), + sa.UniqueConstraint("user_id", "role", name="uq_role_assignments_user_role"), + ) + op.create_index("ix_role_assignments_user_id", "role_assignments", ["user_id"]) + + # ADD owner_id TO DEPOSITIONS (nullable initially for existing data) + op.add_column( + "depositions", + sa.Column("owner_id", sa.String(), nullable=True), + ) + op.create_foreign_key( + "fk_depositions_owner_id", + "depositions", + "users", + ["owner_id"], + ["id"], + ) + op.create_index("idx_depositions_owner_id", "depositions", ["owner_id"]) + + +def downgrade() -> None: + """Remove authorization tables and columns.""" + # DEPOSITIONS owner_id + op.drop_index("idx_depositions_owner_id", table_name="depositions") + op.drop_constraint("fk_depositions_owner_id", "depositions", type_="foreignkey") + op.drop_column("depositions", "owner_id") + + # ROLE ASSIGNMENTS + op.drop_index("ix_role_assignments_user_id", table_name="role_assignments") + op.drop_table("role_assignments") diff --git a/server/osa/application/api/rest/app.py b/server/osa/application/api/rest/app.py index 735e836..e9d6fef 100644 --- a/server/osa/application/api/rest/app.py +++ b/server/osa/application/api/rest/app.py @@ -6,9 +6,19 @@ from fastapi.responses import JSONResponse from osa.application.api.v1.errors import map_osa_error -from osa.application.api.v1.routes import auth, events, health, records, search, stats, validation +from osa.application.api.v1.routes import ( + admin, + auth, + events, + health, + records, + search, + stats, + validation, +) from osa.application.di import create_container from osa.config import Config, configure_logging +from osa.domain.shared.authorization.startup import validate_all_handlers from osa.domain.shared.error import OSAError from osa.infrastructure.event.worker import WorkerPool from osa.infrastructure.source.discovery import validate_sources_at_startup @@ -42,6 +52,9 @@ def create_app() -> FastAPI: # Validate source configs at startup (fail fast with clear errors) validate_sources_at_startup(config.sources) + # Validate all handlers have authorization declarations (fail fast) + validate_all_handlers() + app_instance = FastAPI( title=config.server.name, description=config.server.description, @@ -59,6 +72,7 @@ def create_app() -> FastAPI: # Register v1 routes with /api/v1 prefix app_instance.include_router(health.router, prefix="/api/v1") + app_instance.include_router(admin.router, prefix="/api/v1") app_instance.include_router(auth.router, prefix="/api/v1") app_instance.include_router(events.router, prefix="/api/v1") app_instance.include_router(records.router, prefix="/api/v1") diff --git a/server/osa/application/api/v1/errors.py b/server/osa/application/api/v1/errors.py index 00a31ec..8a9d4cf 100644 --- a/server/osa/application/api/v1/errors.py +++ b/server/osa/application/api/v1/errors.py @@ -49,6 +49,13 @@ def map_osa_error(error: OSAError) -> HTTPException: status_code = DOMAIN_ERROR_STATUS_MAP.get(type(error), 400) if isinstance(error, ValidationError) and error.field is not None: detail["field"] = error.field + # Distinguish 401 (unauthenticated) from 403 (unauthorized) + if isinstance(error, AuthorizationError) and error.code == "missing_token": + return HTTPException( + status_code=401, + detail=detail, + headers={"WWW-Authenticate": "Bearer"}, + ) return HTTPException(status_code=status_code, detail=detail) # Fallback for unknown OSAError subclasses diff --git a/server/osa/application/api/v1/routes/admin.py b/server/osa/application/api/v1/routes/admin.py new file mode 100644 index 0000000..66c112a --- /dev/null +++ b/server/osa/application/api/v1/routes/admin.py @@ -0,0 +1,95 @@ +"""Admin routes for role management.""" + +from dishka.integrations.fastapi import DishkaRoute, FromDishka +from fastapi import APIRouter, Response +from pydantic import BaseModel + +from osa.domain.auth.command.assign_role import ( + AssignRole, + AssignRoleHandler, +) +from osa.domain.auth.command.revoke_role import ( + RevokeRole, + RevokeRoleHandler, +) +from osa.domain.auth.query.get_user_roles import ( + GetUserRoles, + GetUserRolesHandler, +) + +router = APIRouter(prefix="/admin", tags=["Admin"], route_class=DishkaRoute) + + +class AssignRoleRequest(BaseModel): + """Request body for assigning a role.""" + + role: str + + +class RoleAssignmentResponse(BaseModel): + """Response for a single role assignment.""" + + id: str + user_id: str + role: str + assigned_by: str + assigned_at: str + + +class RoleAssignmentListResponse(BaseModel): + """Response listing role assignments.""" + + roles: list[RoleAssignmentResponse] + + +@router.get("/users/{user_id}/roles", response_model=RoleAssignmentListResponse) +async def list_user_roles( + user_id: str, + handler: FromDishka[GetUserRolesHandler], +) -> RoleAssignmentListResponse: + """List all roles assigned to a user. Requires SuperAdmin role.""" + result = await handler.run(GetUserRoles(user_id=user_id)) + return RoleAssignmentListResponse( + roles=[ + RoleAssignmentResponse( + id=r.id, + user_id=r.user_id, + role=r.role, + assigned_by=r.assigned_by, + assigned_at=r.assigned_at.isoformat(), + ) + for r in result.roles + ] + ) + + +@router.post( + "/users/{user_id}/roles", + response_model=RoleAssignmentResponse, + status_code=201, +) +async def assign_role( + user_id: str, + body: AssignRoleRequest, + handler: FromDishka[AssignRoleHandler], +) -> RoleAssignmentResponse: + """Assign a role to a user. Requires SuperAdmin role.""" + result = await handler.run(AssignRole(user_id=user_id, role=body.role)) + return RoleAssignmentResponse( + id=result.id, + user_id=result.user_id, + role=result.role, + assigned_by=result.assigned_by, + assigned_at=result.assigned_at.isoformat(), + ) + + +@router.delete("/users/{user_id}/roles/{role}", status_code=204) +async def revoke_role( + user_id: str, + role: str, + handler: FromDishka[RevokeRoleHandler], +) -> Response: + """Revoke a role from a user. Requires SuperAdmin role.""" + await handler.run(RevokeRole(user_id=user_id, role=role)) + return Response(status_code=204) diff --git a/server/osa/application/api/v1/routes/auth.py b/server/osa/application/api/v1/routes/auth.py index de22b0d..d7de536 100644 --- a/server/osa/application/api/v1/routes/auth.py +++ b/server/osa/application/api/v1/routes/auth.py @@ -25,6 +25,7 @@ ) from osa.domain.auth.model.value import CurrentUser from osa.domain.auth.port.provider_registry import ProviderRegistry +from osa.domain.auth.port.role_repository import RoleAssignmentRepository from osa.domain.auth.service.auth import AuthService from osa.domain.auth.service.token import TokenService from osa.domain.shared.error import InvalidStateError @@ -62,12 +63,13 @@ class LogoutResponse(BaseModel): class UserResponse(BaseModel): - """Response containing user info.""" + """Response containing user info with roles.""" id: str display_name: str | None provider: str external_id: str + roles: list[str] @router.get("/login") @@ -259,8 +261,9 @@ async def logout( async def get_me( current_user: FromDishka[CurrentUser], auth_service: FromDishka[AuthService], + role_repo: FromDishka[RoleAssignmentRepository], ) -> UserResponse: - """Get current authenticated user information.""" + """Get current authenticated user information with roles.""" user = await auth_service.get_user_by_id(current_user.user_id) if user is None: @@ -269,9 +272,13 @@ async def get_me( detail={"code": "user_not_found", "message": "User not found"}, ) + assignments = await role_repo.get_by_user_id(current_user.user_id) + roles = [a.role.name.lower() for a in assignments] + return UserResponse( id=str(user.id), display_name=user.display_name, provider=current_user.identity.provider, external_id=current_user.identity.external_id, + roles=roles, ) diff --git a/server/osa/domain/auth/command/assign_role.py b/server/osa/domain/auth/command/assign_role.py new file mode 100644 index 0000000..725831a --- /dev/null +++ b/server/osa/domain/auth/command/assign_role.py @@ -0,0 +1,49 @@ +"""AssignRole command and handler.""" + +from datetime import datetime +from uuid import UUID + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import UserId +from osa.domain.auth.service.authorization import AuthorizationService +from osa.domain.shared.authorization.gate import at_least +from osa.domain.shared.command import Command, CommandHandler, Result + + +class AssignRole(Command): + """Command to assign a role to a user.""" + + user_id: str # UUID as string from API + role: str # Role name from API + + +class AssignRoleResult(Result): + """Result containing the created role assignment.""" + + id: str + user_id: str + role: str + assigned_by: str + assigned_at: datetime + + +class AssignRoleHandler(CommandHandler[AssignRole, AssignRoleResult]): + __auth__ = at_least(Role.SUPERADMIN) + principal: Principal + authorization_service: AuthorizationService + + async def run(self, cmd: AssignRole) -> AssignRoleResult: + assignment = await self.authorization_service.assign_role( + user_id=UserId(UUID(cmd.user_id)), + role=Role[cmd.role.upper()], + assigned_by=self.principal.user_id, + ) + + return AssignRoleResult( + id=str(assignment.id), + user_id=str(assignment.user_id), + role=assignment.role.name.lower(), + assigned_by=str(assignment.assigned_by), + assigned_at=assignment.assigned_at, + ) diff --git a/server/osa/domain/auth/command/login.py b/server/osa/domain/auth/command/login.py index a54d423..ebf2a40 100644 --- a/server/osa/domain/auth/command/login.py +++ b/server/osa/domain/auth/command/login.py @@ -7,6 +7,7 @@ from osa.domain.auth.port.provider_registry import ProviderRegistry from osa.domain.auth.service.auth import AuthService from osa.domain.auth.service.token import TokenService +from osa.domain.shared.authorization.gate import public from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.error import NotFoundError from osa.domain.shared.event import EventId @@ -31,6 +32,8 @@ class InitiateLoginResult(Result): class InitiateLoginHandler(CommandHandler[InitiateLogin, InitiateLoginResult]): """Handler for InitiateLogin command.""" + __auth__ = public() + provider_registry: ProviderRegistry token_service: TokenService @@ -80,6 +83,8 @@ class CompleteOAuthResult(Result): class CompleteOAuthHandler(CommandHandler[CompleteOAuth, CompleteOAuthResult]): """Handler for CompleteOAuth command.""" + __auth__ = public() + auth_service: AuthService provider_registry: ProviderRegistry token_service: TokenService @@ -95,7 +100,7 @@ async def run(self, cmd: CompleteOAuth) -> CompleteOAuthResult: code="unknown_provider", ) - user, identity, access_token, refresh_token = await self.auth_service.complete_oauth( + user, linked_account, access_token, refresh_token = await self.auth_service.complete_oauth( provider=identity_provider, code=cmd.code, redirect_uri=cmd.callback_url, @@ -106,16 +111,16 @@ async def run(self, cmd: CompleteOAuth) -> CompleteOAuthResult: UserAuthenticated( id=EventId(uuid4()), user_id=str(user.id), - provider=identity.provider, - external_id=identity.external_id, + provider=linked_account.provider, + external_id=linked_account.external_id, ) ) return CompleteOAuthResult( user_id=str(user.id), display_name=user.display_name, - provider=identity.provider, - external_id=identity.external_id, + provider=linked_account.provider, + external_id=linked_account.external_id, access_token=access_token, refresh_token=refresh_token, expires_in=self.token_service.access_token_expire_seconds, diff --git a/server/osa/domain/auth/command/revoke_role.py b/server/osa/domain/auth/command/revoke_role.py new file mode 100644 index 0000000..b577841 --- /dev/null +++ b/server/osa/domain/auth/command/revoke_role.py @@ -0,0 +1,36 @@ +"""RevokeRole command and handler.""" + +from uuid import UUID + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import UserId +from osa.domain.auth.service.authorization import AuthorizationService +from osa.domain.shared.authorization.gate import at_least +from osa.domain.shared.command import Command, CommandHandler, Result + + +class RevokeRole(Command): + """Command to revoke a role from a user.""" + + user_id: str # UUID as string from API + role: str # Role name from API + + +class RevokeRoleResult(Result): + """Empty result for successful revocation.""" + + pass + + +class RevokeRoleHandler(CommandHandler[RevokeRole, RevokeRoleResult]): + __auth__ = at_least(Role.SUPERADMIN) + principal: Principal + authorization_service: AuthorizationService + + async def run(self, cmd: RevokeRole) -> RevokeRoleResult: + await self.authorization_service.revoke_role( + user_id=UserId(UUID(cmd.user_id)), + role=Role[cmd.role.upper()], + ) + return RevokeRoleResult() diff --git a/server/osa/domain/auth/command/token.py b/server/osa/domain/auth/command/token.py index 876c800..08dd2f0 100644 --- a/server/osa/domain/auth/command/token.py +++ b/server/osa/domain/auth/command/token.py @@ -6,6 +6,7 @@ from osa.domain.auth.event import UserLoggedOut from osa.domain.auth.service.auth import AuthService from osa.domain.auth.service.token import TokenService +from osa.domain.shared.authorization.gate import public from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.event import EventId from osa.domain.shared.outbox import Outbox @@ -29,6 +30,8 @@ class RefreshTokensResult(Result): class RefreshTokensHandler(CommandHandler[RefreshTokens, RefreshTokensResult]): """Handler for RefreshTokens command.""" + __auth__ = public() + auth_service: AuthService token_service: TokenService @@ -61,6 +64,8 @@ class LogoutResult(Result): class LogoutHandler(CommandHandler[Logout, LogoutResult]): """Handler for Logout command.""" + __auth__ = public() + auth_service: AuthService outbox: Outbox diff --git a/server/osa/domain/auth/model/__init__.py b/server/osa/domain/auth/model/__init__.py index e59c5bf..5850c4f 100644 --- a/server/osa/domain/auth/model/__init__.py +++ b/server/osa/domain/auth/model/__init__.py @@ -1,16 +1,22 @@ """Auth domain models.""" -from .identity import Identity +from .identity import Anonymous, Identity, System +from .linked_account import LinkedAccount +from .principal import Principal from .token import RefreshToken from .user import User from .value import IdentityId, OrcidId, RefreshTokenId, TokenFamilyId, UserId __all__ = [ + "Anonymous", "Identity", "IdentityId", + "LinkedAccount", "OrcidId", + "Principal", "RefreshToken", "RefreshTokenId", + "System", "TokenFamilyId", "User", "UserId", diff --git a/server/osa/domain/auth/model/identity.py b/server/osa/domain/auth/model/identity.py index f86090b..66e0dad 100644 --- a/server/osa/domain/auth/model/identity.py +++ b/server/osa/domain/auth/model/identity.py @@ -1,46 +1,24 @@ -"""Identity entity for the auth domain.""" - -from datetime import UTC, datetime -from typing import Any - -from osa.domain.auth.model.value import IdentityId, UserId -from osa.domain.shared.model.entity import Entity - - -class Identity(Entity): - """A link between a User and an external identity provider. - - Examples: - - ORCiD: provider="orcid", external_id="0000-0001-2345-6789" - - SAML: provider="saml:university.edu", external_id="jdoe@university.edu" - - Invariants: - - `(provider, external_id)` is globally unique - - `user_id` is immutable after creation - - `provider` and `external_id` are immutable after creation - """ - - id: IdentityId - user_id: UserId - provider: str - external_id: str - metadata: dict[str, Any] | None = None # Provider-specific data (name, email) - created_at: datetime - - @classmethod - def create( - cls, - user_id: UserId, - provider: str, - external_id: str, - metadata: dict[str, Any] | None = None, - ) -> "Identity": - """Create a new identity link.""" - return cls( - id=IdentityId.generate(), - user_id=user_id, - provider=provider, - external_id=external_id, - metadata=metadata, - created_at=datetime.now(UTC), - ) +"""Identity hierarchy — base types for all request identities.""" + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class Identity: + """Base for all request identities.""" + + pass + + +@dataclass(frozen=True) +class Anonymous(Identity): + """Unauthenticated request.""" + + pass + + +@dataclass(frozen=True) +class System(Identity): + """Internal worker/background process. Bypasses resource checks.""" + + pass diff --git a/server/osa/domain/auth/model/linked_account.py b/server/osa/domain/auth/model/linked_account.py new file mode 100644 index 0000000..a8e6140 --- /dev/null +++ b/server/osa/domain/auth/model/linked_account.py @@ -0,0 +1,49 @@ +"""LinkedAccount entity for the auth domain. + +Links a User to an external identity provider (e.g. ORCiD, SAML). +""" + +from datetime import UTC, datetime +from typing import Any + +from osa.domain.auth.model.value import IdentityId, UserId +from osa.domain.shared.model.entity import Entity + + +class LinkedAccount(Entity): + """A link between a User and an external identity provider. + + Examples: + - ORCiD: provider="orcid", external_id="0000-0001-2345-6789" + - SAML: provider="saml:university.edu", external_id="jdoe@university.edu" + + Invariants: + - `(provider, external_id)` is globally unique + - `user_id` is immutable after creation + - `provider` and `external_id` are immutable after creation + """ + + id: IdentityId + user_id: UserId + provider: str + external_id: str + metadata: dict[str, Any] | None = None # Provider-specific data (name, email) + created_at: datetime + + @classmethod + def create( + cls, + user_id: UserId, + provider: str, + external_id: str, + metadata: dict[str, Any] | None = None, + ) -> "LinkedAccount": + """Create a new identity link.""" + return cls( + id=IdentityId.generate(), + user_id=user_id, + provider=provider, + external_id=external_id, + metadata=metadata, + created_at=datetime.now(UTC), + ) diff --git a/server/osa/domain/auth/model/principal.py b/server/osa/domain/auth/model/principal.py new file mode 100644 index 0000000..8767bd9 --- /dev/null +++ b/server/osa/domain/auth/model/principal.py @@ -0,0 +1,28 @@ +"""Principal — authenticated identity with roles, resolved per-request.""" + +from dataclasses import dataclass + +from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId + + +@dataclass(frozen=True) +class Principal(Identity): + """The authenticated identity of the current requester. + + Resolved per-request from JWT + role lookup. Immutable after creation. + Subclasses Identity so it can be used wherever Identity is expected. + """ + + user_id: UserId + provider_identity: ProviderIdentity + roles: frozenset[Role] + + def has_role(self, role: Role) -> bool: + """Check if any assigned role >= the given role (hierarchy comparison).""" + return any(r >= role for r in self.roles) + + def has_any_role(self, *roles: Role) -> bool: + """Check if any assigned role satisfies any of the given roles.""" + return any(self.has_role(r) for r in roles) diff --git a/server/osa/domain/auth/model/role.py b/server/osa/domain/auth/model/role.py new file mode 100644 index 0000000..8c1e105 --- /dev/null +++ b/server/osa/domain/auth/model/role.py @@ -0,0 +1,17 @@ +"""Role hierarchy for authorization.""" + +from enum import IntEnum + + +class Role(IntEnum): + """Hierarchical roles with numeric ordering. + + Higher values inherit all permissions of lower values. + Gaps allow future role insertion without renumbering. + """ + + PUBLIC = 0 + DEPOSITOR = 10 + CURATOR = 20 + ADMIN = 30 + SUPERADMIN = 40 diff --git a/server/osa/domain/auth/model/role_assignment.py b/server/osa/domain/auth/model/role_assignment.py new file mode 100644 index 0000000..72f6971 --- /dev/null +++ b/server/osa/domain/auth/model/role_assignment.py @@ -0,0 +1,49 @@ +"""RoleAssignment entity — tracks user-role associations.""" + +from datetime import UTC, datetime +from uuid import UUID, uuid4 + +from pydantic import RootModel + +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import UserId +from osa.domain.shared.model.entity import Entity + + +class RoleAssignmentId(RootModel[UUID]): + """Unique identifier for a RoleAssignment.""" + + @classmethod + def generate(cls) -> "RoleAssignmentId": + return cls(uuid4()) + + def __str__(self) -> str: + return str(self.root) + + def __hash__(self) -> int: + return hash(self.root) + + +class RoleAssignment(Entity): + """Association between a user and a role, managed by superadmins.""" + + id: RoleAssignmentId + user_id: UserId + role: Role + assigned_by: UserId + assigned_at: datetime + + @classmethod + def create( + cls, + user_id: UserId, + role: Role, + assigned_by: UserId, + ) -> "RoleAssignment": + return cls( + id=RoleAssignmentId.generate(), + user_id=user_id, + role=role, + assigned_by=assigned_by, + assigned_at=datetime.now(UTC), + ) diff --git a/server/osa/domain/auth/port/__init__.py b/server/osa/domain/auth/port/__init__.py index 9f67041..e51c51c 100644 --- a/server/osa/domain/auth/port/__init__.py +++ b/server/osa/domain/auth/port/__init__.py @@ -1,12 +1,12 @@ """Auth domain ports.""" from .identity_provider import IdentityInfo, IdentityProvider -from .repository import IdentityRepository, RefreshTokenRepository, UserRepository +from .repository import LinkedAccountRepository, RefreshTokenRepository, UserRepository __all__ = [ "IdentityInfo", "IdentityProvider", - "IdentityRepository", + "LinkedAccountRepository", "RefreshTokenRepository", "UserRepository", ] diff --git a/server/osa/domain/auth/port/repository.py b/server/osa/domain/auth/port/repository.py index a2ca4a0..2c1667e 100644 --- a/server/osa/domain/auth/port/repository.py +++ b/server/osa/domain/auth/port/repository.py @@ -3,7 +3,7 @@ from abc import abstractmethod from typing import Protocol -from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.linked_account import LinkedAccount from osa.domain.auth.model.token import RefreshToken from osa.domain.auth.model.user import User from osa.domain.auth.model.value import ( @@ -29,29 +29,29 @@ async def save(self, user: User) -> None: ... -class IdentityRepository(Port, Protocol): - """Repository for Identity entity persistence.""" +class LinkedAccountRepository(Port, Protocol): + """Repository for LinkedAccount entity persistence.""" @abstractmethod - async def get(self, identity_id: IdentityId) -> Identity | None: - """Get an identity by ID.""" + async def get(self, identity_id: IdentityId) -> LinkedAccount | None: + """Get a linked account by ID.""" ... @abstractmethod async def get_by_provider_and_external_id( self, provider: str, external_id: str - ) -> Identity | None: - """Get an identity by provider and external ID.""" + ) -> LinkedAccount | None: + """Get a linked account by provider and external ID.""" ... @abstractmethod - async def get_by_user_id(self, user_id: UserId) -> list[Identity]: - """Get all identities for a user.""" + async def get_by_user_id(self, user_id: UserId) -> list[LinkedAccount]: + """Get all linked accounts for a user.""" ... @abstractmethod - async def save(self, identity: Identity) -> None: - """Save an identity.""" + async def save(self, linked_account: LinkedAccount) -> None: + """Save a linked account.""" ... diff --git a/server/osa/domain/auth/port/role_repository.py b/server/osa/domain/auth/port/role_repository.py new file mode 100644 index 0000000..88b7eff --- /dev/null +++ b/server/osa/domain/auth/port/role_repository.py @@ -0,0 +1,33 @@ +"""Repository port for RoleAssignment persistence.""" + +from abc import abstractmethod +from typing import Protocol + +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.role_assignment import RoleAssignment +from osa.domain.auth.model.value import UserId +from osa.domain.shared.port import Port + + +class RoleAssignmentRepository(Port, Protocol): + """Repository for RoleAssignment entity persistence.""" + + @abstractmethod + async def get_by_user_id(self, user_id: UserId) -> list[RoleAssignment]: + """Get all role assignments for a user.""" + ... + + @abstractmethod + async def save(self, assignment: RoleAssignment) -> None: + """Save a role assignment.""" + ... + + @abstractmethod + async def delete(self, user_id: UserId, role: Role) -> bool: + """Delete a role assignment. Returns True if deleted, False if not found.""" + ... + + @abstractmethod + async def get(self, user_id: UserId, role: Role) -> RoleAssignment | None: + """Get a specific role assignment.""" + ... diff --git a/server/osa/domain/auth/query/get_user_roles.py b/server/osa/domain/auth/query/get_user_roles.py new file mode 100644 index 0000000..3b6c85d --- /dev/null +++ b/server/osa/domain/auth/query/get_user_roles.py @@ -0,0 +1,56 @@ +"""GetUserRoles query and handler.""" + +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import UserId +from osa.domain.auth.service.authorization import AuthorizationService +from osa.domain.shared.authorization.gate import at_least +from osa.domain.shared.query import Query, QueryHandler +from osa.domain.shared.query import Result as QueryResult + + +class GetUserRoles(Query): + """Query to get all roles assigned to a user.""" + + user_id: str # UUID as string from API + + +class RoleAssignmentDTO(BaseModel): + id: str + user_id: str + role: str + assigned_by: str + assigned_at: datetime + + +class GetUserRolesResult(QueryResult): + roles: list[RoleAssignmentDTO] + + +class GetUserRolesHandler(QueryHandler[GetUserRoles, GetUserRolesResult]): + __auth__ = at_least(Role.SUPERADMIN) + principal: Principal + authorization_service: AuthorizationService + + async def run(self, cmd: GetUserRoles) -> GetUserRolesResult: + assignments = await self.authorization_service.list_roles( + user_id=UserId(UUID(cmd.user_id)), + ) + + return GetUserRolesResult( + roles=[ + RoleAssignmentDTO( + id=str(a.id), + user_id=str(a.user_id), + role=a.role.name.lower(), + assigned_by=str(a.assigned_by), + assigned_at=a.assigned_at, + ) + for a in assignments + ] + ) diff --git a/server/osa/domain/auth/service/auth.py b/server/osa/domain/auth/service/auth.py index cf1c25b..75e6d9d 100644 --- a/server/osa/domain/auth/service/auth.py +++ b/server/osa/domain/auth/service/auth.py @@ -2,13 +2,13 @@ import logging -from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.linked_account import LinkedAccount from osa.domain.auth.model.token import RefreshToken from osa.domain.auth.model.user import User from osa.domain.auth.model.value import ProviderIdentity, TokenFamilyId, UserId from osa.domain.auth.port.identity_provider import IdentityInfo, IdentityProvider from osa.domain.auth.port.repository import ( - IdentityRepository, + LinkedAccountRepository, RefreshTokenRepository, UserRepository, ) @@ -29,7 +29,7 @@ class AuthService(Service): """ _user_repo: UserRepository - _identity_repo: IdentityRepository + _linked_account_repo: LinkedAccountRepository _refresh_token_repo: RefreshTokenRepository _token_service: TokenService _outbox: Outbox @@ -57,7 +57,7 @@ async def complete_oauth( provider: IdentityProvider, code: str, redirect_uri: str, - ) -> tuple[User, Identity, str, str]: + ) -> tuple[User, LinkedAccount, str, str]: """Complete OAuth flow and issue tokens. Args: @@ -66,25 +66,25 @@ async def complete_oauth( redirect_uri: Must match the one used in authorization Returns: - Tuple of (user, identity, access_token, refresh_token) + Tuple of (user, linked_account, access_token, refresh_token) """ # Exchange code for identity info identity_info = await provider.exchange_code(code, redirect_uri) - # Find or create user and identity - user, identity = await self._find_or_create_user(identity_info) + # Find or create user and linked account + user, linked_account = await self._find_or_create_user(identity_info) # Create tokens - access_token, refresh_token = await self._create_tokens(user, identity) + access_token, refresh_token = await self._create_tokens(user, linked_account) logger.info( "User authenticated: user_id=%s, provider=%s, external_id=%s", user.id, - identity.provider, - identity.external_id, + linked_account.provider, + linked_account.external_id, ) - return user, identity, access_token, refresh_token + return user, linked_account, access_token, refresh_token async def refresh_tokens( self, @@ -197,10 +197,10 @@ async def get_primary_identity(self, user_id: UserId) -> ProviderIdentity | None this could be extended to support multiple identities with a designated primary. """ - identities = await self._identity_repo.get_by_user_id(user_id) - if not identities: + accounts = await self._linked_account_repo.get_by_user_id(user_id) + if not accounts: return None - first = identities[0] + first = accounts[0] return ProviderIdentity(provider=first.provider, external_id=first.external_id) async def get_user_id_from_refresh_token(self, raw_token: str) -> UserId | None: @@ -216,32 +216,32 @@ async def get_user_id_from_refresh_token(self, raw_token: str) -> UserId | None: stored = await self._refresh_token_repo.get_by_token_hash(token_hash) return stored.user_id if stored else None - async def _find_or_create_user(self, identity_info: IdentityInfo) -> tuple[User, Identity]: + async def _find_or_create_user(self, identity_info: IdentityInfo) -> tuple[User, LinkedAccount]: """Find existing user by identity or create new one.""" - # Check if identity already exists - existing_identity = await self._identity_repo.get_by_provider_and_external_id( + # Check if linked account already exists + existing = await self._linked_account_repo.get_by_provider_and_external_id( identity_info.provider, identity_info.external_id ) - if existing_identity: + if existing: # User exists, return them - user = await self._user_repo.get(existing_identity.user_id) + user = await self._user_repo.get(existing.user_id) if user is None: - # Orphaned identity - shouldn't happen with CASCADE - raise RuntimeError(f"Identity exists without user: {existing_identity.id}") - return user, existing_identity + # Orphaned linked account - shouldn't happen with CASCADE + raise RuntimeError(f"LinkedAccount exists without user: {existing.id}") + return user, existing - # Create new user and identity + # Create new user and linked account user = User.create(display_name=identity_info.display_name) await self._user_repo.save(user) - identity = Identity.create( + linked_account = LinkedAccount.create( user_id=user.id, provider=identity_info.provider, external_id=identity_info.external_id, metadata=identity_info.raw_data, ) - await self._identity_repo.save(identity) + await self._linked_account_repo.save(linked_account) logger.info( "New user created: user_id=%s, provider=%s", @@ -249,9 +249,9 @@ async def _find_or_create_user(self, identity_info: IdentityInfo) -> tuple[User, identity_info.provider, ) - return user, identity + return user, linked_account - async def _create_tokens(self, user: User, identity: Identity) -> tuple[str, str]: + async def _create_tokens(self, user: User, linked_account: LinkedAccount) -> tuple[str, str]: """Create access and refresh tokens for a user.""" # Create refresh token raw_token, token_hash = self._token_service.create_refresh_token() @@ -265,8 +265,8 @@ async def _create_tokens(self, user: User, identity: Identity) -> tuple[str, str # Create access token provider_identity = ProviderIdentity( - provider=identity.provider, - external_id=identity.external_id, + provider=linked_account.provider, + external_id=linked_account.external_id, ) access_token = self._token_service.create_access_token( user_id=user.id, diff --git a/server/osa/domain/auth/service/authorization.py b/server/osa/domain/auth/service/authorization.py new file mode 100644 index 0000000..4bde1a8 --- /dev/null +++ b/server/osa/domain/auth/service/authorization.py @@ -0,0 +1,49 @@ +"""Authorization service — role assignment management.""" + +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.role_assignment import RoleAssignment +from osa.domain.auth.model.value import UserId +from osa.domain.auth.port.role_repository import RoleAssignmentRepository +from osa.domain.shared.error import ConflictError, NotFoundError +from osa.domain.shared.service import Service + + +class AuthorizationService(Service): + """Manages role assignments for users.""" + + _role_repo: RoleAssignmentRepository + + async def assign_role( + self, + user_id: UserId, + role: Role, + assigned_by: UserId, + ) -> RoleAssignment: + """Assign a role to a user. Raises ConflictError if already assigned.""" + existing = await self._role_repo.get(user_id, role) + if existing is not None: + raise ConflictError( + f"Role {role.name} already assigned to user {user_id}", + code="role_already_assigned", + ) + + assignment = RoleAssignment.create( + user_id=user_id, + role=role, + assigned_by=assigned_by, + ) + await self._role_repo.save(assignment) + return assignment + + async def revoke_role(self, user_id: UserId, role: Role) -> None: + """Revoke a role from a user. Raises NotFoundError if not assigned.""" + deleted = await self._role_repo.delete(user_id, role) + if not deleted: + raise NotFoundError( + f"Role {role.name} not assigned to user {user_id}", + code="role_not_found", + ) + + async def list_roles(self, user_id: UserId) -> list[RoleAssignment]: + """List all role assignments for a user.""" + return await self._role_repo.get_by_user_id(user_id) diff --git a/server/osa/domain/auth/util/di/provider.py b/server/osa/domain/auth/util/di/provider.py index 8cce9b2..efa7bf1 100644 --- a/server/osa/domain/auth/util/di/provider.py +++ b/server/osa/domain/auth/util/di/provider.py @@ -1,5 +1,6 @@ """DI provider for auth domain.""" +import logging from uuid import UUID import jwt @@ -8,23 +9,32 @@ from starlette.requests import Request from osa.config import Config +from osa.domain.auth.command.assign_role import AssignRoleHandler from osa.domain.auth.command.login import ( CompleteOAuthHandler, InitiateLoginHandler, ) +from osa.domain.auth.command.revoke_role import RevokeRoleHandler from osa.domain.auth.command.token import LogoutHandler, RefreshTokensHandler +from osa.domain.auth.model.identity import Anonymous, Identity +from osa.domain.auth.model.principal import Principal from osa.domain.auth.model.value import CurrentUser, ProviderIdentity, UserId from osa.domain.auth.port.repository import ( - IdentityRepository, + LinkedAccountRepository, RefreshTokenRepository, UserRepository, ) +from osa.domain.auth.port.role_repository import RoleAssignmentRepository +from osa.domain.auth.query.get_user_roles import GetUserRolesHandler from osa.domain.auth.service.auth import AuthService +from osa.domain.auth.service.authorization import AuthorizationService from osa.domain.auth.service.token import TokenService from osa.domain.shared.outbox import Outbox from osa.util.di.base import Provider from osa.util.di.scope import Scope +logger = logging.getLogger(__name__) + class AuthProvider(Provider): """DI provider for auth domain services and handlers.""" @@ -36,6 +46,14 @@ class AuthProvider(Provider): complete_oauth_handler = provide(CompleteOAuthHandler, scope=Scope.UOW) refresh_tokens_handler = provide(RefreshTokensHandler, scope=Scope.UOW) logout_handler = provide(LogoutHandler, scope=Scope.UOW) + assign_role_handler = provide(AssignRoleHandler, scope=Scope.UOW) + revoke_role_handler = provide(RevokeRoleHandler, scope=Scope.UOW) + + # Query Handlers + get_user_roles_handler = provide(GetUserRolesHandler, scope=Scope.UOW) + + # Services + authorization_service = provide(AuthorizationService, scope=Scope.UOW) @provide(scope=Scope.UOW) def get_token_service(self, config: Config) -> TokenService: @@ -46,7 +64,7 @@ def get_token_service(self, config: Config) -> TokenService: def get_auth_service( self, user_repo: UserRepository, - identity_repo: IdentityRepository, + linked_account_repo: LinkedAccountRepository, refresh_token_repo: RefreshTokenRepository, token_service: TokenService, outbox: Outbox, @@ -54,7 +72,7 @@ def get_auth_service( """Provide AuthService.""" return AuthService( _user_repo=user_repo, - _identity_repo=identity_repo, + _linked_account_repo=linked_account_repo, _refresh_token_repo=refresh_token_repo, _token_service=token_service, _outbox=outbox, @@ -102,3 +120,49 @@ def get_current_user( detail={"code": "invalid_token", "message": "Invalid token"}, headers={"WWW-Authenticate": "Bearer"}, ) from e + + @provide(scope=Scope.UOW) + async def get_identity( + self, + request: Request, + token_service: TokenService, + role_repo: RoleAssignmentRepository, + ) -> Identity: + """Resolve Identity from JWT + role lookup. + + Returns Anonymous for unauthenticated requests, Principal for authenticated. + """ + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + return Anonymous() + + token = auth_header[7:] # Remove "Bearer " prefix + + try: + payload = token_service.validate_access_token(token) + except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): + return Anonymous() + + user_id = UserId(UUID(payload["sub"])) + + # Lookup roles from DB + assignments = await role_repo.get_by_user_id(user_id) + roles = frozenset(a.role for a in assignments) + + return Principal( + user_id=user_id, + provider_identity=ProviderIdentity( + provider=payload["provider"], + external_id=payload["external_id"], + ), + roles=roles, + ) + + @provide(scope=Scope.UOW) + def get_principal(self, identity: Identity) -> Principal: + """Extract Principal from Identity. Raises if not authenticated.""" + from osa.domain.shared.error import AuthorizationError + + if isinstance(identity, Principal): + return identity + raise AuthorizationError("Authentication required", code="missing_token") diff --git a/server/osa/domain/deposition/command/create.py b/server/osa/domain/deposition/command/create.py index 0fd73e1..c344a88 100644 --- a/server/osa/domain/deposition/command/create.py +++ b/server/osa/domain/deposition/command/create.py @@ -2,7 +2,10 @@ import logfire +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role from osa.domain.deposition.service.deposition import DepositionService +from osa.domain.shared.authorization.gate import at_least from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.model.srn import DepositionSRN @@ -16,6 +19,8 @@ class DepositionCreated(Result): class CreateDepositionHandler(CommandHandler[CreateDeposition, DepositionCreated]): + __auth__ = at_least(Role.DEPOSITOR) + principal: Principal deposition_service: DepositionService async def run(self, cmd: CreateDeposition) -> DepositionCreated: diff --git a/server/osa/domain/deposition/command/delete_files.py b/server/osa/domain/deposition/command/delete_files.py index 675c90f..1f4f0e4 100644 --- a/server/osa/domain/deposition/command/delete_files.py +++ b/server/osa/domain/deposition/command/delete_files.py @@ -1,6 +1,9 @@ import logfire +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role from osa.domain.deposition.port import DepositionRepository, StoragePort +from osa.domain.shared.authorization.gate import at_least from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.model.srn import DepositionSRN @@ -14,6 +17,8 @@ class DepositionFilesDeleted(Result): class DeleteDepositionFilesHandler(CommandHandler[DeleteDepositionFiles, DepositionFilesDeleted]): + __auth__ = at_least(Role.DEPOSITOR) + principal: Principal repository: DepositionRepository storage: StoragePort diff --git a/server/osa/domain/deposition/command/submit.py b/server/osa/domain/deposition/command/submit.py index d7412f4..88d32c4 100644 --- a/server/osa/domain/deposition/command/submit.py +++ b/server/osa/domain/deposition/command/submit.py @@ -1,8 +1,11 @@ import logfire from uuid import uuid4 +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role from osa.domain.deposition.event.submitted import DepositionSubmittedEvent from osa.domain.deposition.service.deposition import DepositionService +from osa.domain.shared.authorization.gate import at_least from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.event import EventId from osa.domain.shared.model.srn import DepositionSRN @@ -18,6 +21,8 @@ class DepositionSubmitted(Result): class SubmitDepositionHandler(CommandHandler[SubmitDeposition, DepositionSubmitted]): + __auth__ = at_least(Role.DEPOSITOR) + principal: Principal deposition_service: DepositionService outbox: Outbox diff --git a/server/osa/domain/deposition/command/update.py b/server/osa/domain/deposition/command/update.py index 680afff..d0ee9b8 100644 --- a/server/osa/domain/deposition/command/update.py +++ b/server/osa/domain/deposition/command/update.py @@ -1,6 +1,9 @@ import logfire +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role from osa.domain.deposition.service.deposition import DepositionService +from osa.domain.shared.authorization.gate import at_least from osa.domain.shared.command import Command, CommandHandler, Result @@ -11,6 +14,8 @@ class DepositionUpdated(Result): ... class UpdateDepositionHandler(CommandHandler[UpdateDeposition, DepositionUpdated]): + __auth__ = at_least(Role.DEPOSITOR) + principal: Principal deposition_service: DepositionService async def run(self, cmd: UpdateDeposition) -> DepositionUpdated: diff --git a/server/osa/domain/deposition/command/upload.py b/server/osa/domain/deposition/command/upload.py index f0d0712..5e358f8 100644 --- a/server/osa/domain/deposition/command/upload.py +++ b/server/osa/domain/deposition/command/upload.py @@ -2,6 +2,9 @@ import logfire +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.shared.authorization.gate import at_least from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.model.srn import DepositionSRN @@ -17,6 +20,9 @@ class FileUploaded(Result): class UploadFileHandler(CommandHandler[UploadFile, FileUploaded]): + __auth__ = at_least(Role.DEPOSITOR) + principal: Principal + async def run(self, cmd: UploadFile) -> FileUploaded: with logfire.span("UploadFile"): # TODO: Implement actual file storage logic diff --git a/server/osa/domain/deposition/model/aggregate.py b/server/osa/domain/deposition/model/aggregate.py index 1cae2d9..32834fc 100644 --- a/server/osa/domain/deposition/model/aggregate.py +++ b/server/osa/domain/deposition/model/aggregate.py @@ -1,5 +1,6 @@ from typing import Any, Generic, TypeVar +from osa.domain.auth.model.value import UserId from osa.domain.deposition.model.value import DepositionFile, DepositionStatus from osa.domain.shared.model.aggregate import Aggregate from osa.domain.shared.model.srn import DepositionSRN, RecordSRN @@ -14,6 +15,7 @@ class Deposition(Aggregate, Generic[T]): files: list[DepositionFile] = [] record_srn: RecordSRN | None = None provenance: dict[str, Any] = {} # Source info, provenance tracking + owner_id: UserId | None = None def remove_all_files(self) -> None: self.files = [] diff --git a/server/osa/domain/shared/authorization/__init__.py b/server/osa/domain/shared/authorization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/osa/domain/shared/authorization/decorators.py b/server/osa/domain/shared/authorization/decorators.py new file mode 100644 index 0000000..03f2fa9 --- /dev/null +++ b/server/osa/domain/shared/authorization/decorators.py @@ -0,0 +1,46 @@ +"""Repository method decorators for resource-level authorization. + +@reads(check): After method returns, check the result (skip if None). +@writes(check): Before method runs, check the first resource arg. + +Both decorators access self._identity on the repo instance. +""" + +from collections.abc import Callable +from functools import wraps +from typing import Any + +from osa.domain.shared.authorization.resource import ResourceCheck + + +def reads(check: ResourceCheck) -> Callable: + """After method returns, evaluate the check on the result. + + If the result is None (not found), the check is skipped. + """ + + def decorator(fn: Callable) -> Callable: + @wraps(fn) + async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + result = await fn(self, *args, **kwargs) + if result is not None: + check.evaluate(self._identity, result) + return result + + return wrapper + + return decorator + + +def writes(check: ResourceCheck) -> Callable: + """Before method runs, evaluate the check on the first resource arg.""" + + def decorator(fn: Callable) -> Callable: + @wraps(fn) + async def wrapper(self: Any, resource: Any, *args: Any, **kwargs: Any) -> Any: + check.evaluate(self._identity, resource) + return await fn(self, resource, *args, **kwargs) + + return wrapper + + return decorator diff --git a/server/osa/domain/shared/authorization/gate.py b/server/osa/domain/shared/authorization/gate.py new file mode 100644 index 0000000..3c8c435 --- /dev/null +++ b/server/osa/domain/shared/authorization/gate.py @@ -0,0 +1,42 @@ +"""Handler-level authorization gates: public() and at_least(Role).""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from osa.domain.auth.model.role import Role + + +class Gate: + """Base for handler-level authorization gates. + + Every CommandHandler/QueryHandler must declare ``__auth__: ClassVar[Gate]``. + Subclasses define specific gate behaviors (public access, role checks, etc.). + """ + + +@dataclass(frozen=True) +class Public(Gate): + """No authentication required.""" + + +@dataclass(frozen=True) +class AtLeast(Gate): + """Gate that requires the principal to have at least the given role.""" + + role: "Role" + + +_PUBLIC = Public() + + +def public() -> Public: + """Mark a handler as publicly accessible (no auth required).""" + return _PUBLIC + + +def at_least(role: "Role") -> AtLeast: + """Mark a handler as requiring at least the given role.""" + return AtLeast(role=role) diff --git a/server/osa/domain/shared/authorization/resource.py b/server/osa/domain/shared/authorization/resource.py new file mode 100644 index 0000000..44c867a --- /dev/null +++ b/server/osa/domain/shared/authorization/resource.py @@ -0,0 +1,111 @@ +"""Resource-level authorization checks for repo decorators.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from osa.domain.auth.model.role import Role + + +class ResourceCheck(ABC): + """Base class for resource-level authorization checks. + + System identities bypass all checks. Anonymous identities are rejected. + Principal identities are checked via the abstract _check method. + """ + + def evaluate(self, identity: Any, resource: Any) -> None: + """Evaluate the check against the given identity and resource. + + Raises AuthorizationError if access is denied. + """ + from osa.domain.auth.model.identity import System + from osa.domain.auth.model.principal import Principal + from osa.domain.shared.error import AuthorizationError + + if isinstance(identity, System): + return # Workers bypass all resource checks + + if not isinstance(identity, Principal): + raise AuthorizationError("Authentication required", code="missing_token") + + self._check(identity, resource) + + @abstractmethod + def _check(self, principal: Any, resource: Any) -> None: + """Check authorization for an authenticated principal. + + Args: + principal: The authenticated Principal + resource: The domain resource being accessed + + Raises: + AuthorizationError: If principal is not authorized for this resource. + """ + ... + + def __or__(self, other: ResourceCheck) -> AnyOf: + return AnyOf(checks=(self, other)) + + +@dataclass(frozen=True) +class OwnerCheck(ResourceCheck): + """Check that the principal owns the resource (resource.owner_id == principal.user_id).""" + + def _check(self, principal: Any, resource: Any) -> None: + from osa.domain.shared.error import AuthorizationError + + owner_id = getattr(resource, "owner_id", None) + if owner_id is None or owner_id != principal.user_id: + raise AuthorizationError("Access denied: not resource owner", code="access_denied") + + +@dataclass(frozen=True) +class HasRole(ResourceCheck): + """Check that the principal has at least the given role.""" + + role: "Role" + + def _check(self, principal: Any, resource: Any) -> None: + from osa.domain.shared.error import AuthorizationError + + if not principal.has_role(self.role): + raise AuthorizationError( + f"Access denied: requires role {self.role.name}", + code="access_denied", + ) + + +@dataclass(frozen=True) +class AnyOf(ResourceCheck): + """Check that at least one of the sub-checks passes.""" + + checks: tuple[ResourceCheck, ...] + + def _check(self, principal: Any, resource: Any) -> None: + from osa.domain.shared.error import AuthorizationError + + for check in self.checks: + try: + check._check(principal, resource) + return # At least one passed + except AuthorizationError: + continue + + raise AuthorizationError("Access denied", code="access_denied") + + def __or__(self, other: ResourceCheck) -> AnyOf: + return AnyOf(checks=(*self.checks, other)) + + +def owner() -> OwnerCheck: + """Check that the principal owns the resource.""" + return OwnerCheck() + + +def has_role(role: "Role") -> HasRole: + """Check that the principal has at least the given role.""" + return HasRole(role=role) diff --git a/server/osa/domain/shared/authorization/startup.py b/server/osa/domain/shared/authorization/startup.py new file mode 100644 index 0000000..ab414fc --- /dev/null +++ b/server/osa/domain/shared/authorization/startup.py @@ -0,0 +1,48 @@ +"""Startup validation for handler authorization declarations.""" + +import logging + +from osa.domain.shared.authorization.gate import Gate +from osa.domain.shared.command import CommandHandler +from osa.domain.shared.error import ConfigurationError +from osa.domain.shared.query import QueryHandler + +logger = logging.getLogger(__name__) + + +def _check_handler_class(handler_cls: type) -> None: + """Check a single handler class for __auth__ declaration. + + Every handler must have __auth__ set to a Gate instance. + """ + auth = getattr(handler_cls, "__auth__", None) + if not isinstance(auth, Gate): + raise ConfigurationError(f"Handler {handler_cls.__name__} has no __auth__ declaration") + + +def validate_all_handlers() -> None: + """Scan all registered CommandHandler and QueryHandler subclasses. + + Raises ConfigurationError listing all handlers missing __auth__ declarations. + """ + violations: list[str] = [] + + for handler_cls in CommandHandler.__subclasses__(): + try: + _check_handler_class(handler_cls) + except ConfigurationError as e: + violations.append(str(e)) + + for handler_cls in QueryHandler.__subclasses__(): + try: + _check_handler_class(handler_cls) + except ConfigurationError as e: + violations.append(str(e)) + + if violations: + raise ConfigurationError( + f"Authorization validation failed for {len(violations)} handler(s):\n" + + "\n".join(f" - {v}" for v in violations) + ) + + logger.info("Authorization startup validation passed for all handlers") diff --git a/server/osa/domain/shared/command.py b/server/osa/domain/shared/command.py index 948acaf..f39dee3 100644 --- a/server/osa/domain/shared/command.py +++ b/server/osa/domain/shared/command.py @@ -1,9 +1,18 @@ +"""Command and CommandHandler base classes with authorization gate.""" + +from __future__ import annotations + from abc import ABCMeta, abstractmethod +from collections.abc import Callable, Coroutine from dataclasses import dataclass -from typing import Generic, TypeVar, dataclass_transform +from functools import wraps +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, dataclass_transform from pydantic import BaseModel +if TYPE_CHECKING: + from osa.domain.shared.authorization.gate import Gate + class Command(BaseModel): ... @@ -14,20 +23,79 @@ class Result(BaseModel): ... C = TypeVar("C", bound=Command) R = TypeVar("R", bound=Result) +# Unbound async handler method: (self, cmd) -> Coroutine -> Result +_HandlerMethod = Callable[..., Coroutine[Any, Any, Any]] + + +def _wrap_run_with_auth(cls: type, original_run: _HandlerMethod) -> _HandlerMethod: + """Wrap the run() method with __auth__ gate evaluation.""" + + @wraps(original_run) + async def auth_wrapped_run(self: Any, cmd: Any) -> Any: + from osa.domain.shared.authorization.gate import AtLeast, Gate, Public + from osa.domain.shared.error import AuthorizationError, ConfigurationError + + auth_gate = getattr(type(self), "__auth__", None) + + if not isinstance(auth_gate, Gate): + raise ConfigurationError(f"Handler {type(self).__name__} has no __auth__ declaration") + + if isinstance(auth_gate, Public): + return await original_run(self, cmd) + + if isinstance(auth_gate, AtLeast): + from osa.domain.auth.model.principal import Principal + + principal = getattr(self, "principal", None) + if not isinstance(principal, Principal): + raise AuthorizationError( + "Authentication required", + code="missing_token", + ) + + if not principal.has_role(auth_gate.role): + raise AuthorizationError( + f"Access denied: insufficient role for {type(self).__name__}", + code="access_denied", + ) + + return await original_run(self, cmd) + + raise ConfigurationError( # pragma: no cover — future gate types handled here + f"Handler {type(self).__name__} has unhandled __auth__ type: {type(auth_gate).__name__}" + ) + + return auth_wrapped_run + @dataclass_transform() class _CommandHandlerMeta(ABCMeta): - """Metaclass that combines ABC with auto-dataclass for subclasses.""" + """Metaclass that combines ABC with auto-dataclass and __auth__ gate for subclasses.""" - def __new__(mcs, name: str, bases: tuple, namespace: dict): + def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]): cls = super().__new__(mcs, name, bases, namespace) if any(isinstance(b, mcs) for b in bases): - return dataclass(cls) + cls = dataclass(cls) + + # Wrap run() with auth gate + original_run = cls.__dict__.get("run") + if original_run is not None: + wrapped = _wrap_run_with_auth(cls, original_run) + cls.run = wrapped + return cls class CommandHandler(Generic[C, R], metaclass=_CommandHandlerMeta): - """Base class for command handlers. Subclasses are automatically dataclasses.""" + """Base class for command handlers. Subclasses are automatically dataclasses. + + Declare __auth__ to enforce role-based access: + class MyHandler(CommandHandler[MyCmd, MyResult]): + __auth__ = at_least(Role.ADMIN) + principal: Principal + """ + + __auth__: ClassVar[Gate] @abstractmethod async def run(self, cmd: C) -> R: ... diff --git a/server/osa/domain/shared/query.py b/server/osa/domain/shared/query.py index 4bc9093..fbf8c90 100644 --- a/server/osa/domain/shared/query.py +++ b/server/osa/domain/shared/query.py @@ -1,18 +1,101 @@ -from abc import ABC, abstractmethod -from typing import Generic, TypeVar +"""Query and QueryHandler base classes with authorization gate.""" + +from __future__ import annotations + +from abc import ABCMeta, abstractmethod +from collections.abc import Callable, Coroutine +from dataclasses import dataclass +from functools import wraps +from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, dataclass_transform + from pydantic import BaseModel +if TYPE_CHECKING: + from osa.domain.shared.authorization.gate import Gate + -class Query(BaseModel, ABC): ... +class Query(BaseModel): ... -class Result(BaseModel, ABC): ... +class Result(BaseModel): ... C = TypeVar("C", bound=Query) R = TypeVar("R", bound=Result) +# Unbound async handler method: (self, cmd) -> Coroutine -> Result +_HandlerMethod = Callable[..., Coroutine[Any, Any, Any]] + + +def _wrap_query_run_with_auth(cls: type, original_run: _HandlerMethod) -> _HandlerMethod: + """Wrap the run() method with __auth__ gate evaluation.""" + + @wraps(original_run) + async def auth_wrapped_run(self: Any, cmd: Any) -> Any: + from osa.domain.shared.authorization.gate import AtLeast, Gate, Public + from osa.domain.shared.error import AuthorizationError, ConfigurationError + + auth_gate = getattr(type(self), "__auth__", None) + + if not isinstance(auth_gate, Gate): + raise ConfigurationError(f"Handler {type(self).__name__} has no __auth__ declaration") + + if isinstance(auth_gate, Public): + return await original_run(self, cmd) + + if isinstance(auth_gate, AtLeast): + from osa.domain.auth.model.principal import Principal + + principal = getattr(self, "principal", None) + if not isinstance(principal, Principal): + raise AuthorizationError( + "Authentication required", + code="missing_token", + ) + + if not principal.has_role(auth_gate.role): + raise AuthorizationError( + f"Access denied: insufficient role for {type(self).__name__}", + code="access_denied", + ) + + return await original_run(self, cmd) + + raise ConfigurationError( # pragma: no cover — future gate types handled here + f"Handler {type(self).__name__} has unhandled __auth__ type: {type(auth_gate).__name__}" + ) + + return auth_wrapped_run + + +@dataclass_transform() +class _QueryHandlerMeta(ABCMeta): + """Metaclass that combines ABC with auto-dataclass and __auth__ gate for subclasses.""" + + def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]): + cls = super().__new__(mcs, name, bases, namespace) + if any(isinstance(b, mcs) for b in bases): + cls = dataclass(cls) + + # Wrap run() with auth gate + original_run = cls.__dict__.get("run") + if original_run is not None: + wrapped = _wrap_query_run_with_auth(cls, original_run) + cls.run = wrapped + + return cls + + +class QueryHandler(Generic[C, R], metaclass=_QueryHandlerMeta): + """Base class for query handlers. Subclasses are automatically dataclasses. + + Declare __auth__ to enforce role-based access: + class MyHandler(QueryHandler[MyQuery, MyResult]): + __auth__ = at_least(Role.ADMIN) + principal: Principal + """ + + __auth__: ClassVar[Gate] -class QueryHandler(ABC, Generic[C, R]): @abstractmethod - def run(self, cmd: C) -> R: ... + async def run(self, cmd: C) -> R: ... diff --git a/server/osa/infrastructure/auth/di.py b/server/osa/infrastructure/auth/di.py index 40a354e..1fadacd 100644 --- a/server/osa/infrastructure/auth/di.py +++ b/server/osa/infrastructure/auth/di.py @@ -7,14 +7,16 @@ from osa.domain.auth.port.identity_provider import IdentityProvider from osa.domain.auth.port.provider_registry import ProviderRegistry from osa.domain.auth.port.repository import ( - IdentityRepository, + LinkedAccountRepository, RefreshTokenRepository, UserRepository, ) +from osa.domain.auth.port.role_repository import RoleAssignmentRepository from osa.infrastructure.auth.orcid import OrcidIdentityProvider from osa.infrastructure.auth.provider_registry import InMemoryProviderRegistry +from osa.infrastructure.auth.role_repository import PostgresRoleAssignmentRepository from osa.infrastructure.persistence.repository.auth import ( - PostgresIdentityRepository, + PostgresLinkedAccountRepository, PostgresRefreshTokenRepository, PostgresUserRepository, ) @@ -39,16 +41,21 @@ class AuthInfraProvider(Provider): scope=Scope.UOW, provides=UserRepository, ) - identity_repo = provide( - PostgresIdentityRepository, + linked_account_repo = provide( + PostgresLinkedAccountRepository, scope=Scope.UOW, - provides=IdentityRepository, + provides=LinkedAccountRepository, ) refresh_token_repo = provide( PostgresRefreshTokenRepository, scope=Scope.UOW, provides=RefreshTokenRepository, ) + role_assignment_repo = provide( + PostgresRoleAssignmentRepository, + scope=Scope.UOW, + provides=RoleAssignmentRepository, + ) @provide(scope=Scope.APP) def get_auth_http_client(self) -> httpx.AsyncClient: diff --git a/server/osa/infrastructure/auth/role_repository.py b/server/osa/infrastructure/auth/role_repository.py new file mode 100644 index 0000000..0dad962 --- /dev/null +++ b/server/osa/infrastructure/auth/role_repository.py @@ -0,0 +1,73 @@ +"""PostgreSQL implementation of RoleAssignmentRepository.""" + +from uuid import UUID + +from sqlalchemy import delete, insert, select +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.role_assignment import RoleAssignment, RoleAssignmentId +from osa.domain.auth.model.value import UserId +from osa.domain.auth.port.role_repository import RoleAssignmentRepository +from osa.infrastructure.persistence.tables import role_assignments_table + + +def _row_to_role_assignment(row: dict) -> RoleAssignment: + """Convert a database row to a RoleAssignment model.""" + return RoleAssignment( + id=RoleAssignmentId(UUID(row["id"])), + user_id=UserId(UUID(row["user_id"])), + role=Role[row["role"].upper()], + assigned_by=UserId(UUID(row["assigned_by"])), + assigned_at=row["assigned_at"], + ) + + +def _role_assignment_to_dict(assignment: RoleAssignment) -> dict: + """Convert a RoleAssignment model to a database row dict.""" + return { + "id": str(assignment.id), + "user_id": str(assignment.user_id), + "role": assignment.role.name.lower(), + "assigned_by": str(assignment.assigned_by), + "assigned_at": assignment.assigned_at, + } + + +class PostgresRoleAssignmentRepository(RoleAssignmentRepository): + """PostgreSQL implementation of RoleAssignmentRepository.""" + + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def get_by_user_id(self, user_id: UserId) -> list[RoleAssignment]: + stmt = select(role_assignments_table).where( + role_assignments_table.c.user_id == str(user_id) + ) + result = await self.session.execute(stmt) + rows = result.mappings().all() + return [_row_to_role_assignment(dict(row)) for row in rows] + + async def save(self, assignment: RoleAssignment) -> None: + assignment_dict = _role_assignment_to_dict(assignment) + stmt = insert(role_assignments_table).values(**assignment_dict) + await self.session.execute(stmt) + await self.session.flush() + + async def delete(self, user_id: UserId, role: Role) -> bool: + stmt = delete(role_assignments_table).where( + role_assignments_table.c.user_id == str(user_id), + role_assignments_table.c.role == role.name.lower(), + ) + result = await self.session.execute(stmt) + await self.session.flush() + return result.rowcount > 0 + + async def get(self, user_id: UserId, role: Role) -> RoleAssignment | None: + stmt = select(role_assignments_table).where( + role_assignments_table.c.user_id == str(user_id), + role_assignments_table.c.role == role.name.lower(), + ) + result = await self.session.execute(stmt) + row = result.mappings().first() + return _row_to_role_assignment(dict(row)) if row else None diff --git a/server/osa/infrastructure/event/worker.py b/server/osa/infrastructure/event/worker.py index fb5e758..858f06f 100644 --- a/server/osa/infrastructure/event/worker.py +++ b/server/osa/infrastructure/event/worker.py @@ -13,6 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from osa.application.event import ServerStarted +from osa.domain.auth.model.identity import Identity, System from osa.domain.shared.error import SkippedEvents from osa.domain.shared.event import ( EventHandler, @@ -183,8 +184,8 @@ async def _poll_once(self) -> bool: self._state.status = WorkerStatus.CLAIMING - # Claim and process within a UOW scope - async with self._container(scope=Scope.UOW) as scope: + # Claim and process within a UOW scope (System identity for workers) + async with self._container(scope=Scope.UOW, context={Identity: System()}) as scope: outbox = await scope.get(Outbox) session = await scope.get(AsyncSession) @@ -393,7 +394,7 @@ async def _emit_server_started(self) -> None: if self._container is None: return - async with self._container(scope=Scope.UOW) as scope: + async with self._container(scope=Scope.UOW, context={Identity: System()}) as scope: outbox = await scope.get(Outbox) await outbox.append(ServerStarted(id=EventId(uuid4()))) session = await scope.get(AsyncSession) @@ -440,7 +441,7 @@ async def _run_schedule(self, config: "ScheduleConfig") -> None: return try: - async with self._container(scope=Scope.UOW) as scope: + async with self._container(scope=Scope.UOW, context={Identity: System()}) as scope: schedule = await scope.get(config.schedule_type) await schedule.run(**config.params) session = await scope.get(AsyncSession) @@ -484,7 +485,9 @@ async def _run_stale_claim_cleanup(self) -> None: max_timeout = max(w.config.claim_timeout for w in self._workers) # Use a scoped outbox for cleanup - async with self._container(scope=Scope.UOW) as scope: + async with self._container( + scope=Scope.UOW, context={Identity: System()} + ) as scope: outbox = await scope.get(Outbox) session = await scope.get(AsyncSession) count = await outbox.reset_stale_claims(max_timeout) diff --git a/server/osa/infrastructure/persistence/mappers/deposition.py b/server/osa/infrastructure/persistence/mappers/deposition.py index 5d051de..78eff36 100644 --- a/server/osa/infrastructure/persistence/mappers/deposition.py +++ b/server/osa/infrastructure/persistence/mappers/deposition.py @@ -1,5 +1,7 @@ from typing import Any +from uuid import UUID +from osa.domain.auth.model.value import UserId from osa.domain.deposition.model.aggregate import Deposition from osa.domain.deposition.model.value import DepositionFile, DepositionStatus from osa.domain.shared.model.srn import DepositionSRN, RecordSRN @@ -15,6 +17,7 @@ def row_to_deposition(row: dict[str, Any]) -> Deposition[dict[str, Any]]: files = [DepositionFile(**f) for f in files_data] record_id = row.get("record_id") + owner_id_raw = row.get("owner_id") return Deposition( srn=DepositionSRN.parse(row["srn"]), @@ -23,6 +26,7 @@ def row_to_deposition(row: dict[str, Any]) -> Deposition[dict[str, Any]]: files=files, provenance=row.get("provenance", {}), record_srn=RecordSRN.parse(record_id) if record_id else None, + owner_id=UserId(UUID(owner_id_raw)) if owner_id_raw else None, ) @@ -35,4 +39,5 @@ def deposition_to_dict(dep: Deposition) -> dict[str, Any]: "files": [f.model_dump(mode="json") for f in dep.files], "provenance": dep.provenance, "record_id": str(dep.record_srn) if dep.record_srn else None, + "owner_id": str(dep.owner_id) if dep.owner_id else None, } diff --git a/server/osa/infrastructure/persistence/repository/auth.py b/server/osa/infrastructure/persistence/repository/auth.py index ff46522..515e713 100644 --- a/server/osa/infrastructure/persistence/repository/auth.py +++ b/server/osa/infrastructure/persistence/repository/auth.py @@ -6,7 +6,7 @@ from sqlalchemy import insert, select, update from sqlalchemy.ext.asyncio import AsyncSession -from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.linked_account import LinkedAccount from osa.domain.auth.model.token import RefreshToken from osa.domain.auth.model.user import User from osa.domain.auth.model.value import ( @@ -16,7 +16,7 @@ UserId, ) from osa.domain.auth.port.repository import ( - IdentityRepository, + LinkedAccountRepository, RefreshTokenRepository, UserRepository, ) @@ -47,9 +47,9 @@ def _user_to_dict(user: User) -> dict: } -def _row_to_identity(row: dict) -> Identity: - """Convert a database row to an Identity model.""" - return Identity( +def _row_to_linked_account(row: dict) -> LinkedAccount: + """Convert a database row to a LinkedAccount model.""" + return LinkedAccount( id=IdentityId(UUID(row["id"])), user_id=UserId(UUID(row["user_id"])), provider=row["provider"], @@ -59,15 +59,15 @@ def _row_to_identity(row: dict) -> Identity: ) -def _identity_to_dict(identity: Identity) -> dict: - """Convert an Identity model to a database row dict.""" +def _linked_account_to_dict(account: LinkedAccount) -> dict: + """Convert a LinkedAccount model to a database row dict.""" return { - "id": str(identity.id), - "user_id": str(identity.user_id), - "provider": identity.provider, - "external_id": identity.external_id, - "metadata": identity.metadata, - "created_at": identity.created_at, + "id": str(account.id), + "user_id": str(account.user_id), + "provider": account.provider, + "external_id": account.external_id, + "metadata": account.metadata, + "created_at": account.created_at, } @@ -122,47 +122,47 @@ async def save(self, user: User) -> None: await self.session.flush() -class PostgresIdentityRepository(IdentityRepository): - """PostgreSQL implementation of IdentityRepository.""" +class PostgresLinkedAccountRepository(LinkedAccountRepository): + """PostgreSQL implementation of LinkedAccountRepository.""" def __init__(self, session: AsyncSession) -> None: self.session = session - async def get(self, identity_id: IdentityId) -> Identity | None: + async def get(self, identity_id: IdentityId) -> LinkedAccount | None: stmt = select(identities_table).where(identities_table.c.id == str(identity_id)) result = await self.session.execute(stmt) row = result.mappings().first() - return _row_to_identity(dict(row)) if row else None + return _row_to_linked_account(dict(row)) if row else None async def get_by_provider_and_external_id( self, provider: str, external_id: str - ) -> Identity | None: + ) -> LinkedAccount | None: stmt = select(identities_table).where( identities_table.c.provider == provider, identities_table.c.external_id == external_id, ) result = await self.session.execute(stmt) row = result.mappings().first() - return _row_to_identity(dict(row)) if row else None + return _row_to_linked_account(dict(row)) if row else None - async def get_by_user_id(self, user_id: UserId) -> list[Identity]: + async def get_by_user_id(self, user_id: UserId) -> list[LinkedAccount]: stmt = select(identities_table).where(identities_table.c.user_id == str(user_id)) result = await self.session.execute(stmt) rows = result.mappings().all() - return [_row_to_identity(dict(row)) for row in rows] + return [_row_to_linked_account(dict(row)) for row in rows] - async def save(self, identity: Identity) -> None: - identity_dict = _identity_to_dict(identity) - existing = await self.get(identity.id) + async def save(self, linked_account: LinkedAccount) -> None: + account_dict = _linked_account_to_dict(linked_account) + existing = await self.get(linked_account.id) if existing: stmt = ( update(identities_table) - .where(identities_table.c.id == str(identity.id)) - .values(**identity_dict) + .where(identities_table.c.id == str(linked_account.id)) + .values(**account_dict) ) else: - stmt = insert(identities_table).values(**identity_dict) + stmt = insert(identities_table).values(**account_dict) await self.session.execute(stmt) await self.session.flush() diff --git a/server/osa/infrastructure/persistence/repository/deposition.py b/server/osa/infrastructure/persistence/repository/deposition.py index 40c5ef5..5d585fa 100644 --- a/server/osa/infrastructure/persistence/repository/deposition.py +++ b/server/osa/infrastructure/persistence/repository/deposition.py @@ -1,8 +1,12 @@ from sqlalchemy import insert, select, update from sqlalchemy.ext.asyncio import AsyncSession +from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.role import Role from osa.domain.deposition.model.aggregate import Deposition from osa.domain.deposition.port.repository import DepositionRepository +from osa.domain.shared.authorization.decorators import reads, writes +from osa.domain.shared.authorization.resource import has_role, owner from osa.domain.shared.model.srn import DepositionSRN from osa.infrastructure.persistence.mappers.deposition import ( row_to_deposition, @@ -14,13 +18,25 @@ class PostgresDepositionRepository(DepositionRepository): """PostgreSQL implementation of DepositionRepository.""" - def __init__(self, session: AsyncSession) -> None: + def __init__(self, session: AsyncSession, identity: Identity) -> None: self.session = session + self._identity = identity + @reads(owner() | has_role(Role.CURATOR)) + async def get(self, srn: DepositionSRN) -> Deposition | None: + stmt = select(depositions_table).where(depositions_table.c.srn == str(srn)) + result = await self.session.execute(stmt) + row = result.mappings().first() + return row_to_deposition(dict(row)) if row else None + + @writes(owner()) async def save(self, deposition: Deposition) -> None: dep_dict = deposition_to_dict(deposition) - existing = await self.get(deposition.srn) + # Check if exists (bypass decorator for internal lookup) + stmt = select(depositions_table).where(depositions_table.c.srn == str(deposition.srn)) + result = await self.session.execute(stmt) + existing = result.mappings().first() if existing: stmt = ( @@ -33,9 +49,3 @@ async def save(self, deposition: Deposition) -> None: await self.session.execute(stmt) await self.session.flush() - - async def get(self, srn: DepositionSRN) -> Deposition | None: - stmt = select(depositions_table).where(depositions_table.c.srn == str(srn)) - result = await self.session.execute(stmt) - row = result.mappings().first() - return row_to_deposition(dict(row)) if row else None diff --git a/server/osa/infrastructure/persistence/tables.py b/server/osa/infrastructure/persistence/tables.py index f22ee64..240ca7a 100644 --- a/server/osa/infrastructure/persistence/tables.py +++ b/server/osa/infrastructure/persistence/tables.py @@ -30,11 +30,13 @@ Column("provenance", JSON, nullable=False), Column("files", JSON, nullable=False), Column("record_id", String, nullable=True), + Column("owner_id", String, ForeignKey("users.id"), nullable=True), Column("created_at", DateTime(timezone=True), nullable=False), Column("updated_at", DateTime(timezone=True), nullable=False), ) Index("idx_depositions_record_id", depositions_table.c.record_id) +Index("idx_depositions_owner_id", depositions_table.c.owner_id) # ============================================================================ @@ -175,3 +177,20 @@ Index("ix_refresh_tokens_user_id", refresh_tokens_table.c.user_id) Index("ix_refresh_tokens_token_hash", refresh_tokens_table.c.token_hash) Index("ix_refresh_tokens_family_id", refresh_tokens_table.c.family_id) + + +# ============================================================================ +# ROLE ASSIGNMENTS TABLE (Authorization) +# ============================================================================ +role_assignments_table = Table( + "role_assignments", + metadata, + Column("id", String, primary_key=True), # UUID as string + Column("user_id", String, ForeignKey("users.id", ondelete="CASCADE"), nullable=False), + Column("role", String(32), nullable=False), + Column("assigned_by", String, ForeignKey("users.id"), nullable=False), + Column("assigned_at", DateTime(timezone=True), nullable=False), + UniqueConstraint("user_id", "role", name="uq_role_assignments_user_role"), +) + +Index("ix_role_assignments_user_id", role_assignments_table.c.user_id) diff --git a/server/tests/unit/domain/auth/test_auth_provider.py b/server/tests/unit/domain/auth/test_auth_provider.py new file mode 100644 index 0000000..53c1068 --- /dev/null +++ b/server/tests/unit/domain/auth/test_auth_provider.py @@ -0,0 +1,142 @@ +"""Tests for AuthProvider identity resolution (get_identity / get_principal).""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock + +import jwt as pyjwt +import pytest + +from osa.config import JwtConfig +from osa.domain.auth.model.identity import Anonymous +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.role_assignment import RoleAssignment, RoleAssignmentId +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.auth.service.token import TokenService +from osa.domain.auth.util.di.provider import AuthProvider +from osa.domain.shared.error import AuthorizationError + + +def _make_jwt_config() -> JwtConfig: + return JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=7, + ) + + +def _make_token_service(config: JwtConfig | None = None) -> TokenService: + return TokenService(_config=config or _make_jwt_config()) + + +def _make_request(auth_header: str | None = None) -> MagicMock: + request = MagicMock() + headers: dict[str, str] = {} + if auth_header is not None: + headers["Authorization"] = auth_header + request.headers = headers + return request + + +def _make_valid_token(token_service: TokenService, user_id: UserId) -> str: + return token_service.create_access_token( + user_id=user_id, + identity=ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789"), + ) + + +def _make_role_repo(assignments: list[RoleAssignment] | None = None) -> AsyncMock: + repo = AsyncMock() + repo.get_by_user_id.return_value = assignments or [] + return repo + + +class TestGetIdentity: + @pytest.mark.asyncio + async def test_valid_jwt_returns_principal_with_roles(self) -> None: + token_service = _make_token_service() + user_id = UserId.generate() + token = _make_valid_token(token_service, user_id) + request = _make_request(f"Bearer {token}") + + assignment = RoleAssignment( + id=RoleAssignmentId.generate(), + user_id=user_id, + role=Role.CURATOR, + assigned_by=UserId.generate(), + assigned_at=datetime.now(UTC), + ) + role_repo = _make_role_repo([assignment]) + + provider = AuthProvider() + identity = await provider.get_identity(request, token_service, role_repo) + + assert isinstance(identity, Principal) + assert identity.user_id == user_id + assert identity.roles == frozenset({Role.CURATOR}) + + @pytest.mark.asyncio + async def test_expired_jwt_returns_anonymous(self) -> None: + config = _make_jwt_config() + token_service = _make_token_service(config) + user_id = UserId.generate() + + # Create an expired token manually + payload = { + "sub": str(user_id), + "provider": "orcid", + "external_id": "0000-0001-2345-6789", + "exp": datetime(2020, 1, 1, tzinfo=UTC), + } + token = pyjwt.encode(payload, config.secret, algorithm=config.algorithm) + request = _make_request(f"Bearer {token}") + role_repo = _make_role_repo() + + provider = AuthProvider() + identity = await provider.get_identity(request, token_service, role_repo) + + assert isinstance(identity, Anonymous) + + @pytest.mark.asyncio + async def test_invalid_jwt_returns_anonymous(self) -> None: + token_service = _make_token_service() + request = _make_request("Bearer not-a-valid-jwt") + role_repo = _make_role_repo() + + provider = AuthProvider() + identity = await provider.get_identity(request, token_service, role_repo) + + assert isinstance(identity, Anonymous) + + @pytest.mark.asyncio + async def test_no_auth_header_returns_anonymous(self) -> None: + token_service = _make_token_service() + request = _make_request() + role_repo = _make_role_repo() + + provider = AuthProvider() + identity = await provider.get_identity(request, token_service, role_repo) + + assert isinstance(identity, Anonymous) + + +class TestGetPrincipal: + def test_get_principal_with_principal_returns_it(self) -> None: + principal = Principal( + user_id=UserId.generate(), + provider_identity=ProviderIdentity(provider="orcid", external_id="ext"), + roles=frozenset({Role.DEPOSITOR}), + ) + + provider = AuthProvider() + result = provider.get_principal(principal) + + assert result is principal + + def test_get_principal_with_anonymous_raises_missing_token(self) -> None: + provider = AuthProvider() + + with pytest.raises(AuthorizationError) as exc_info: + provider.get_principal(Anonymous()) + assert exc_info.value.code == "missing_token" diff --git a/server/tests/unit/domain/auth/test_auth_service.py b/server/tests/unit/domain/auth/test_auth_service.py index bef7f98..28e6fd4 100644 --- a/server/tests/unit/domain/auth/test_auth_service.py +++ b/server/tests/unit/domain/auth/test_auth_service.py @@ -7,7 +7,7 @@ import pytest from osa.config import JwtConfig -from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.linked_account import LinkedAccount from osa.domain.auth.model.token import RefreshToken from osa.domain.auth.model.user import User from osa.domain.auth.model.value import IdentityId, RefreshTokenId, TokenFamilyId, UserId @@ -19,7 +19,7 @@ def make_auth_service( user_repo: AsyncMock | None = None, - identity_repo: AsyncMock | None = None, + linked_account_repo: AsyncMock | None = None, refresh_token_repo: AsyncMock | None = None, token_service: TokenService | None = None, outbox: AsyncMock | None = None, @@ -27,8 +27,8 @@ def make_auth_service( """Create an AuthService with mocked dependencies.""" if user_repo is None: user_repo = AsyncMock() - if identity_repo is None: - identity_repo = AsyncMock() + if linked_account_repo is None: + linked_account_repo = AsyncMock() if refresh_token_repo is None: refresh_token_repo = AsyncMock() if token_service is None: @@ -44,7 +44,7 @@ def make_auth_service( return AuthService( _user_repo=user_repo, - _identity_repo=identity_repo, + _linked_account_repo=linked_account_repo, _refresh_token_repo=refresh_token_repo, _token_service=token_service, _outbox=outbox, @@ -101,14 +101,16 @@ async def test_complete_oauth_creates_new_user(self): user_repo = AsyncMock() user_repo.get.return_value = None # No existing user - identity_repo = AsyncMock() - identity_repo.get_by_provider_and_external_id.return_value = None # No existing identity + linked_account_repo = AsyncMock() + linked_account_repo.get_by_provider_and_external_id.return_value = ( + None # No existing identity + ) refresh_token_repo = AsyncMock() service = make_auth_service( user_repo=user_repo, - identity_repo=identity_repo, + linked_account_repo=linked_account_repo, refresh_token_repo=refresh_token_repo, ) provider = make_identity_provider() @@ -119,9 +121,9 @@ async def test_complete_oauth_creates_new_user(self): redirect_uri="http://localhost/callback", ) - # Should create user and identity + # Should create user and linked account user_repo.save.assert_called_once() - identity_repo.save.assert_called_once() + linked_account_repo.save.assert_called_once() refresh_token_repo.save.assert_called_once() # Should return valid data @@ -140,7 +142,7 @@ async def test_complete_oauth_returns_existing_user(self): created_at=datetime.now(UTC), updated_at=None, ) - existing_identity = Identity( + existing_linked_account = LinkedAccount( id=IdentityId(uuid4()), user_id=existing_user.id, provider="orcid", @@ -152,14 +154,14 @@ async def test_complete_oauth_returns_existing_user(self): user_repo = AsyncMock() user_repo.get.return_value = existing_user - identity_repo = AsyncMock() - identity_repo.get_by_provider_and_external_id.return_value = existing_identity + linked_account_repo = AsyncMock() + linked_account_repo.get_by_provider_and_external_id.return_value = existing_linked_account refresh_token_repo = AsyncMock() service = make_auth_service( user_repo=user_repo, - identity_repo=identity_repo, + linked_account_repo=linked_account_repo, refresh_token_repo=refresh_token_repo, ) provider = make_identity_provider() @@ -170,13 +172,13 @@ async def test_complete_oauth_returns_existing_user(self): redirect_uri="http://localhost/callback", ) - # Should NOT create new user/identity + # Should NOT create new user/linked account user_repo.save.assert_not_called() - identity_repo.save.assert_not_called() + linked_account_repo.save.assert_not_called() # Should return existing user assert user.id == existing_user.id - assert identity.id == existing_identity.id + assert identity.id == existing_linked_account.id class TestAuthServiceRefreshTokens: @@ -191,7 +193,7 @@ async def test_refresh_tokens_issues_new_tokens(self): created_at=datetime.now(UTC), updated_at=None, ) - identity = Identity( + linked_account = LinkedAccount( id=IdentityId(uuid4()), user_id=user.id, provider="orcid", @@ -212,15 +214,15 @@ async def test_refresh_tokens_issues_new_tokens(self): user_repo = AsyncMock() user_repo.get.return_value = user - identity_repo = AsyncMock() - identity_repo.get_by_user_id.return_value = [identity] + linked_account_repo = AsyncMock() + linked_account_repo.get_by_user_id.return_value = [linked_account] refresh_token_repo = AsyncMock() refresh_token_repo.get_by_token_hash.return_value = old_token service = make_auth_service( user_repo=user_repo, - identity_repo=identity_repo, + linked_account_repo=linked_account_repo, refresh_token_repo=refresh_token_repo, ) @@ -245,7 +247,7 @@ async def test_refresh_tokens_revokes_old_token(self): created_at=datetime.now(UTC), updated_at=None, ) - identity = Identity( + linked_account = LinkedAccount( id=IdentityId(uuid4()), user_id=user.id, provider="orcid", @@ -266,15 +268,15 @@ async def test_refresh_tokens_revokes_old_token(self): user_repo = AsyncMock() user_repo.get.return_value = user - identity_repo = AsyncMock() - identity_repo.get_by_user_id.return_value = [identity] + linked_account_repo = AsyncMock() + linked_account_repo.get_by_user_id.return_value = [linked_account] refresh_token_repo = AsyncMock() refresh_token_repo.get_by_token_hash.return_value = old_token service = make_auth_service( user_repo=user_repo, - identity_repo=identity_repo, + linked_account_repo=linked_account_repo, refresh_token_repo=refresh_token_repo, ) diff --git a/server/tests/unit/domain/auth/test_command_handlers.py b/server/tests/unit/domain/auth/test_command_handlers.py index 9179e15..0dc60b0 100644 --- a/server/tests/unit/domain/auth/test_command_handlers.py +++ b/server/tests/unit/domain/auth/test_command_handlers.py @@ -20,7 +20,7 @@ RefreshTokensHandler, ) from osa.domain.auth.event import UserAuthenticated, UserLoggedOut -from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.linked_account import LinkedAccount from osa.domain.auth.model.user import User from osa.domain.auth.model.value import IdentityId, UserId from osa.domain.auth.service.token import TokenService @@ -128,7 +128,7 @@ async def test_run_emits_user_authenticated_event(self): created_at=datetime.now(UTC), updated_at=None, ) - identity = Identity( + linked_account = LinkedAccount( id=IdentityId(uuid4()), user_id=user.id, provider="orcid", @@ -140,7 +140,7 @@ async def test_run_emits_user_authenticated_event(self): auth_service = AsyncMock() auth_service.complete_oauth.return_value = ( user, - identity, + linked_account, "access-token", "refresh-token", ) @@ -181,7 +181,7 @@ async def test_run_returns_user_info_and_tokens(self): created_at=datetime.now(UTC), updated_at=None, ) - identity = Identity( + linked_account = LinkedAccount( id=IdentityId(uuid4()), user_id=user.id, provider="orcid", @@ -193,7 +193,7 @@ async def test_run_returns_user_info_and_tokens(self): auth_service = AsyncMock() auth_service.complete_oauth.return_value = ( user, - identity, + linked_account, "access-token", "refresh-token", ) diff --git a/server/tests/unit/domain/auth/test_handler_configs.py b/server/tests/unit/domain/auth/test_handler_configs.py new file mode 100644 index 0000000..57ce28d --- /dev/null +++ b/server/tests/unit/domain/auth/test_handler_configs.py @@ -0,0 +1,135 @@ +"""Tests for concrete handler auth configurations. + +Verifies that production handlers enforce their declared __auth__ gates +end-to-end (real handler classes, mocked services). +""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from osa.domain.auth.command.assign_role import ( + AssignRole, + AssignRoleHandler, +) +from osa.domain.auth.command.login import ( + InitiateLogin, + InitiateLoginHandler, +) +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.auth.service.token import TokenService +from osa.domain.deposition.command.create import ( + CreateDeposition, + CreateDepositionHandler, +) +from osa.domain.shared.error import AuthorizationError + + +def _make_principal( + roles: frozenset[Role], + user_id: UserId | None = None, +) -> Principal: + return Principal( + user_id=user_id or UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=roles, + ) + + +class TestCreateDepositionHandlerAuth: + @pytest.mark.asyncio + async def test_create_deposition_allows_depositor(self) -> None: + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + service = AsyncMock() + handler = CreateDepositionHandler( + principal=depositor, + deposition_service=service, + ) + + result = await handler.run(CreateDeposition()) + assert result.srn is not None + + @pytest.mark.asyncio + async def test_create_deposition_rejects_unauthenticated(self) -> None: + handler = CreateDepositionHandler.__new__(CreateDepositionHandler) + + with pytest.raises(AuthorizationError) as exc_info: + await handler.run(CreateDeposition()) + assert exc_info.value.code == "missing_token" + + +class TestAssignRoleHandlerAuth: + @pytest.mark.asyncio + async def test_assign_role_allows_superadmin(self) -> None: + superadmin = _make_principal(frozenset({Role.SUPERADMIN})) + service = AsyncMock() + # Mock the return value to match what the handler expects + from datetime import UTC, datetime + + from osa.domain.auth.model.role_assignment import RoleAssignment, RoleAssignmentId + + target_user_id = UserId.generate() + service.assign_role.return_value = RoleAssignment( + id=RoleAssignmentId.generate(), + user_id=target_user_id, + role=Role.CURATOR, + assigned_by=superadmin.user_id, + assigned_at=datetime.now(UTC), + ) + + handler = AssignRoleHandler( + principal=superadmin, + authorization_service=service, + ) + + result = await handler.run(AssignRole(user_id=str(target_user_id), role="curator")) + assert result.role == "curator" + + @pytest.mark.asyncio + async def test_assign_role_rejects_admin(self) -> None: + admin = _make_principal(frozenset({Role.ADMIN})) + service = AsyncMock() + handler = AssignRoleHandler( + principal=admin, + authorization_service=service, + ) + + with pytest.raises(AuthorizationError) as exc_info: + await handler.run(AssignRole(user_id=str(UserId.generate()), role="curator")) + assert exc_info.value.code == "access_denied" + + +class TestInitiateLoginHandlerAuth: + @pytest.mark.asyncio + async def test_public_login_handler_works_without_principal(self) -> None: + provider_registry = MagicMock() + identity_provider = MagicMock() + identity_provider.get_authorization_url.return_value = "https://example.com/auth" + provider_registry.get.return_value = identity_provider + + from osa.config import JwtConfig + + token_service = TokenService( + _config=JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=7, + ) + ) + + handler = InitiateLoginHandler( + provider_registry=provider_registry, + token_service=token_service, + ) + + result = await handler.run( + InitiateLogin( + callback_url="http://localhost/callback", + final_redirect_uri="http://localhost/dashboard", + provider="orcid", + ) + ) + assert result.authorization_url == "https://example.com/auth" diff --git a/server/tests/unit/domain/deposition/__init__.py b/server/tests/unit/domain/deposition/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/domain/shared/authorization/__init__.py b/server/tests/unit/domain/shared/authorization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/domain/shared/authorization/test_auth_gate.py b/server/tests/unit/domain/shared/authorization/test_auth_gate.py new file mode 100644 index 0000000..4f20d2a --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_auth_gate.py @@ -0,0 +1,228 @@ +"""Tests for handler __auth__ gate: metaclass wraps run() with auth check.""" + +import pytest + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.authorization.gate import at_least, public +from osa.domain.shared.command import Command, CommandHandler, Result +from osa.domain.shared.error import AuthorizationError, ConfigurationError +from osa.domain.shared.query import Query, QueryHandler +from osa.domain.shared.query import Result as QueryResult + + +def _make_principal(roles: frozenset[Role]) -> Principal: + return Principal( + user_id=UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="test-ext"), + roles=roles, + ) + + +# --- Test command DTOs --- + + +class AdminOnlyCommand(Command): + value: str = "test" + + +class AdminOnlyResult(Result): + value: str + + +class PublicCommand(Command): + value: str = "test" + + +class PublicResult(Result): + value: str + + +# --- Test handlers --- + + +class AdminOnlyHandler(CommandHandler[AdminOnlyCommand, AdminOnlyResult]): + __auth__ = at_least(Role.ADMIN) + principal: Principal + + async def run(self, cmd: AdminOnlyCommand) -> AdminOnlyResult: + return AdminOnlyResult(value=cmd.value) + + +class PublicHandler(CommandHandler[PublicCommand, PublicResult]): + __auth__ = public() + + async def run(self, cmd: PublicCommand) -> PublicResult: + return PublicResult(value=cmd.value) + + +class UnprotectedCommand(Command): + value: str = "test" + + +class UnprotectedResult(Result): + value: str + + +class UnprotectedHandler(CommandHandler[UnprotectedCommand, UnprotectedResult]): + async def run(self, cmd: UnprotectedCommand) -> UnprotectedResult: + return UnprotectedResult(value=cmd.value) + + +class UnprotectedQuery(Query): + value: str = "test" + + +class UnprotectedQueryResult(QueryResult): + value: str + + +class UnprotectedQueryHandler(QueryHandler[UnprotectedQuery, UnprotectedQueryResult]): + async def run(self, cmd: UnprotectedQuery) -> UnprotectedQueryResult: + return UnprotectedQueryResult(value=cmd.value) + + +# --- Tests --- + + +class TestAuthGateOnCommandHandler: + @pytest.mark.asyncio + async def test_admin_handler_rejects_depositor(self) -> None: + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + handler = AdminOnlyHandler(principal=depositor) + + with pytest.raises(AuthorizationError): + await handler.run(AdminOnlyCommand(value="test")) + + @pytest.mark.asyncio + async def test_admin_handler_allows_admin(self) -> None: + admin = _make_principal(frozenset({Role.ADMIN})) + handler = AdminOnlyHandler(principal=admin) + + result = await handler.run(AdminOnlyCommand(value="hello")) + assert result.value == "hello" + + @pytest.mark.asyncio + async def test_admin_handler_allows_superadmin(self) -> None: + superadmin = _make_principal(frozenset({Role.SUPERADMIN})) + handler = AdminOnlyHandler(principal=superadmin) + + result = await handler.run(AdminOnlyCommand(value="hello")) + assert result.value == "hello" + + @pytest.mark.asyncio + async def test_admin_handler_rejects_missing_principal(self) -> None: + # Principal field not provided — should raise AuthorizationError + handler = AdminOnlyHandler.__new__(AdminOnlyHandler) + + with pytest.raises(AuthorizationError): + await handler.run(AdminOnlyCommand(value="test")) + + @pytest.mark.asyncio + async def test_public_handler_skips_check(self) -> None: + handler = PublicHandler() + + result = await handler.run(PublicCommand(value="public")) + assert result.value == "public" + + @pytest.mark.asyncio + async def test_public_handler_works_with_principal(self) -> None: + # Public handlers work regardless of principal presence + handler = PublicHandler() + + result = await handler.run(PublicCommand(value="public")) + assert result.value == "public" + + @pytest.mark.asyncio + async def test_unprotected_command_handler_raises_configuration_error(self) -> None: + handler = UnprotectedHandler() + + with pytest.raises(ConfigurationError, match="UnprotectedHandler"): + await handler.run(UnprotectedCommand(value="test")) + + @pytest.mark.asyncio + async def test_unprotected_query_handler_raises_configuration_error(self) -> None: + handler = UnprotectedQueryHandler() + + with pytest.raises(ConfigurationError, match="UnprotectedQueryHandler"): + await handler.run(UnprotectedQuery(value="test")) + + +# --- Test query DTOs --- + + +class AdminOnlyQuery(Query): + value: str = "test" + + +class AdminOnlyQueryResult(QueryResult): + value: str + + +class PublicQuery(Query): + value: str = "test" + + +class PublicQueryResult(QueryResult): + value: str + + +# --- Test query handlers --- + + +class AdminOnlyQueryHandler(QueryHandler[AdminOnlyQuery, AdminOnlyQueryResult]): + __auth__ = at_least(Role.ADMIN) + principal: Principal + + async def run(self, cmd: AdminOnlyQuery) -> AdminOnlyQueryResult: + return AdminOnlyQueryResult(value=cmd.value) + + +class PublicQueryHandler(QueryHandler[PublicQuery, PublicQueryResult]): + __auth__ = public() + + async def run(self, cmd: PublicQuery) -> PublicQueryResult: + return PublicQueryResult(value=cmd.value) + + +class TestAuthGateOnQueryHandler: + @pytest.mark.asyncio + async def test_query_handler_rejects_insufficient_role(self) -> None: + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + handler = AdminOnlyQueryHandler(principal=depositor) + + with pytest.raises(AuthorizationError) as exc_info: + await handler.run(AdminOnlyQuery(value="test")) + assert exc_info.value.code == "access_denied" + + @pytest.mark.asyncio + async def test_query_handler_allows_matching_role(self) -> None: + admin = _make_principal(frozenset({Role.ADMIN})) + handler = AdminOnlyQueryHandler(principal=admin) + + result = await handler.run(AdminOnlyQuery(value="hello")) + assert result.value == "hello" + + @pytest.mark.asyncio + async def test_query_handler_allows_higher_role(self) -> None: + superadmin = _make_principal(frozenset({Role.SUPERADMIN})) + handler = AdminOnlyQueryHandler(principal=superadmin) + + result = await handler.run(AdminOnlyQuery(value="hello")) + assert result.value == "hello" + + @pytest.mark.asyncio + async def test_query_handler_rejects_missing_principal(self) -> None: + handler = AdminOnlyQueryHandler.__new__(AdminOnlyQueryHandler) + + with pytest.raises(AuthorizationError) as exc_info: + await handler.run(AdminOnlyQuery(value="test")) + assert exc_info.value.code == "missing_token" + + @pytest.mark.asyncio + async def test_public_query_handler_skips_check(self) -> None: + handler = PublicQueryHandler() + + result = await handler.run(PublicQuery(value="public")) + assert result.value == "public" diff --git a/server/tests/unit/domain/shared/authorization/test_decorators.py b/server/tests/unit/domain/shared/authorization/test_decorators.py new file mode 100644 index 0000000..c39e220 --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_decorators.py @@ -0,0 +1,184 @@ +"""Tests for @reads/@writes repo decorators.""" + +import pytest + +from osa.domain.auth.model.identity import Anonymous, System +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.authorization.decorators import reads, writes +from osa.domain.shared.authorization.resource import has_role, owner +from osa.domain.shared.error import AuthorizationError + + +def _make_principal( + roles: frozenset[Role], + user_id: UserId | None = None, +) -> Principal: + return Principal( + user_id=user_id or UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=roles, + ) + + +class _FakeResource: + def __init__(self, owner_id: UserId) -> None: + self.owner_id = owner_id + + +class _FakeRepo: + """Minimal repo with _identity and decorated methods.""" + + def __init__(self, identity, resource=None): + self._identity = identity + self._resource = resource + + @reads(owner() | has_role(Role.CURATOR)) + async def get(self, key: str): + return self._resource + + @writes(owner()) + async def save(self, resource) -> None: + self._resource = resource + + +class TestReadsDecorator: + @pytest.mark.asyncio + async def test_reads_allows_owner(self) -> None: + user_id = UserId.generate() + principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) + resource = _FakeResource(owner_id=user_id) + repo = _FakeRepo(identity=principal, resource=resource) + + result = await repo.get("key") + assert result is resource + + @pytest.mark.asyncio + async def test_reads_allows_curator(self) -> None: + curator = _make_principal(frozenset({Role.CURATOR})) + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=curator, resource=resource) + + result = await repo.get("key") + assert result is resource + + @pytest.mark.asyncio + async def test_reads_denies_non_owner_depositor(self) -> None: + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=depositor, resource=resource) + + with pytest.raises(AuthorizationError): + await repo.get("key") + + @pytest.mark.asyncio + async def test_reads_skips_check_when_none(self) -> None: + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + repo = _FakeRepo(identity=depositor, resource=None) + + result = await repo.get("key") + assert result is None + + @pytest.mark.asyncio + async def test_reads_allows_system(self) -> None: + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=System(), resource=resource) + + result = await repo.get("key") + assert result is resource + + @pytest.mark.asyncio + async def test_reads_denies_anonymous(self) -> None: + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=Anonymous(), resource=resource) + + with pytest.raises(AuthorizationError, match="Authentication required"): + await repo.get("key") + + @pytest.mark.asyncio + async def test_reads_denies_anonymous_with_missing_token_code(self) -> None: + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=Anonymous(), resource=resource) + + with pytest.raises(AuthorizationError) as exc_info: + await repo.get("key") + assert exc_info.value.code == "missing_token" + + @pytest.mark.asyncio + async def test_reads_allows_admin_via_role_hierarchy(self) -> None: + admin = _make_principal(frozenset({Role.ADMIN})) + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=admin, resource=resource) + + result = await repo.get("key") + assert result is resource + + +class TestWritesDecorator: + @pytest.mark.asyncio + async def test_writes_allows_owner(self) -> None: + user_id = UserId.generate() + principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) + resource = _FakeResource(owner_id=user_id) + repo = _FakeRepo(identity=principal) + + await repo.save(resource) + assert repo._resource is resource + + @pytest.mark.asyncio + async def test_writes_denies_non_owner(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=principal) + + with pytest.raises(AuthorizationError): + await repo.save(resource) + + @pytest.mark.asyncio + async def test_writes_checks_before_execution(self) -> None: + """Write check happens before the method body runs.""" + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=principal) + + with pytest.raises(AuthorizationError): + await repo.save(resource) + + # Method body never executed — _resource still None + assert repo._resource is None + + @pytest.mark.asyncio + async def test_writes_allows_system(self) -> None: + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=System()) + + await repo.save(resource) + assert repo._resource is resource + + @pytest.mark.asyncio + async def test_writes_denies_anonymous(self) -> None: + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=Anonymous()) + + with pytest.raises(AuthorizationError, match="Authentication required"): + await repo.save(resource) + + @pytest.mark.asyncio + async def test_writes_denies_anonymous_with_missing_token_code(self) -> None: + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=Anonymous()) + + with pytest.raises(AuthorizationError) as exc_info: + await repo.save(resource) + assert exc_info.value.code == "missing_token" + + @pytest.mark.asyncio + async def test_writes_denies_non_owner_with_access_denied_code(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + repo = _FakeRepo(identity=principal) + + with pytest.raises(AuthorizationError) as exc_info: + await repo.save(resource) + assert exc_info.value.code == "access_denied" diff --git a/server/tests/unit/domain/shared/authorization/test_error_codes.py b/server/tests/unit/domain/shared/authorization/test_error_codes.py new file mode 100644 index 0000000..243bf89 --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_error_codes.py @@ -0,0 +1,125 @@ +"""Tests for authorization error codes: pin 401 vs 403 mapping.""" + +import pytest + +from osa.application.api.v1.errors import map_osa_error +from osa.domain.auth.model.identity import Anonymous +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.authorization.gate import at_least +from osa.domain.shared.authorization.resource import has_role, owner +from osa.domain.shared.command import Command, CommandHandler, Result +from osa.domain.shared.error import AuthorizationError + + +def _make_principal( + roles: frozenset[Role], + user_id: UserId | None = None, +) -> Principal: + return Principal( + user_id=user_id or UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=roles, + ) + + +# --- Inline handler for gate-level tests --- + + +class _GatedCommand(Command): + value: str = "test" + + +class _GatedResult(Result): + value: str + + +class _AdminGatedHandler(CommandHandler[_GatedCommand, _GatedResult]): + __auth__ = at_least(Role.ADMIN) + principal: Principal + + async def run(self, cmd: _GatedCommand) -> _GatedResult: + return _GatedResult(value=cmd.value) + + +# --- Handler gate error codes --- + + +class TestHandlerGateErrorCodes: + @pytest.mark.asyncio + async def test_missing_principal_has_missing_token_code(self) -> None: + handler = _AdminGatedHandler.__new__(_AdminGatedHandler) + + with pytest.raises(AuthorizationError) as exc_info: + await handler.run(_GatedCommand(value="test")) + assert exc_info.value.code == "missing_token" + + @pytest.mark.asyncio + async def test_insufficient_role_has_access_denied_code(self) -> None: + depositor = _make_principal(frozenset({Role.DEPOSITOR})) + handler = _AdminGatedHandler(principal=depositor) + + with pytest.raises(AuthorizationError) as exc_info: + await handler.run(_GatedCommand(value="test")) + assert exc_info.value.code == "access_denied" + + +# --- Resource check error codes --- + + +class _FakeResource: + def __init__(self, owner_id: UserId) -> None: + self.owner_id = owner_id + + +class TestResourceCheckErrorCodes: + def test_anonymous_resource_check_has_missing_token_code(self) -> None: + check = has_role(Role.CURATOR) + resource = _FakeResource(owner_id=UserId.generate()) + + with pytest.raises(AuthorizationError) as exc_info: + check.evaluate(Anonymous(), resource) + assert exc_info.value.code == "missing_token" + + def test_owner_check_failure_has_access_denied_code(self) -> None: + check = owner() + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + + with pytest.raises(AuthorizationError) as exc_info: + check.evaluate(principal, resource) + assert exc_info.value.code == "access_denied" + + def test_has_role_check_failure_has_access_denied_code(self) -> None: + check = has_role(Role.ADMIN) + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + + with pytest.raises(AuthorizationError) as exc_info: + check.evaluate(principal, resource) + assert exc_info.value.code == "access_denied" + + def test_any_of_failure_has_access_denied_code(self) -> None: + check = owner() | has_role(Role.ADMIN) + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + + with pytest.raises(AuthorizationError) as exc_info: + check.evaluate(principal, resource) + assert exc_info.value.code == "access_denied" + + +# --- HTTP status code mapping --- + + +class TestErrorCodeToHttpMapping: + def test_map_osa_error_missing_token_returns_401(self) -> None: + error = AuthorizationError("Authentication required", code="missing_token") + http_exc = map_osa_error(error) + assert http_exc.status_code == 401 + + def test_map_osa_error_access_denied_returns_403(self) -> None: + error = AuthorizationError("Access denied", code="access_denied") + http_exc = map_osa_error(error) + assert http_exc.status_code == 403 diff --git a/server/tests/unit/domain/shared/authorization/test_gate.py b/server/tests/unit/domain/shared/authorization/test_gate.py new file mode 100644 index 0000000..53db48a --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_gate.py @@ -0,0 +1,47 @@ +"""Tests for gate module: Gate base class, Public, AtLeast, factory functions.""" + +from osa.domain.auth.model.role import Role +from osa.domain.shared.authorization.gate import AtLeast, Gate, Public, at_least, public + + +class TestGateHierarchy: + def test_public_is_gate(self) -> None: + assert isinstance(Public(), Gate) + + def test_at_least_is_gate(self) -> None: + assert isinstance(AtLeast(role=Role.ADMIN), Gate) + + +class TestPublic: + def test_public_returns_public_instance(self) -> None: + assert isinstance(public(), Public) + + def test_public_always_returns_same_object(self) -> None: + assert public() is public() + + +class TestAtLeast: + def test_at_least_creates_dataclass(self) -> None: + gate = at_least(Role.ADMIN) + assert isinstance(gate, AtLeast) + assert gate.role is Role.ADMIN + + def test_at_least_different_roles(self) -> None: + depositor_gate = at_least(Role.DEPOSITOR) + admin_gate = at_least(Role.ADMIN) + assert depositor_gate.role is Role.DEPOSITOR + assert admin_gate.role is Role.ADMIN + + def test_at_least_is_frozen(self) -> None: + gate = at_least(Role.ADMIN) + assert hash(gate) is not None # frozen dataclass is hashable + + def test_at_least_equality(self) -> None: + gate1 = at_least(Role.ADMIN) + gate2 = at_least(Role.ADMIN) + assert gate1 == gate2 + + def test_at_least_inequality(self) -> None: + gate1 = at_least(Role.ADMIN) + gate2 = at_least(Role.DEPOSITOR) + assert gate1 != gate2 diff --git a/server/tests/unit/domain/shared/authorization/test_identity.py b/server/tests/unit/domain/shared/authorization/test_identity.py new file mode 100644 index 0000000..1e8c326 --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_identity.py @@ -0,0 +1,40 @@ +"""Tests for Identity hierarchy: Anonymous, System, Principal subclassing.""" + +from osa.domain.auth.model.identity import Anonymous, Identity, System +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId + + +class TestIdentityHierarchy: + def test_anonymous_is_identity(self) -> None: + anon = Anonymous() + assert isinstance(anon, Identity) + + def test_system_is_identity(self) -> None: + system = System() + assert isinstance(system, Identity) + + def test_principal_is_identity(self) -> None: + principal = Principal( + user_id=UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=frozenset({Role.DEPOSITOR}), + ) + assert isinstance(principal, Identity) + + def test_anonymous_is_not_principal(self) -> None: + anon = Anonymous() + assert not isinstance(anon, Principal) + + def test_system_is_not_principal(self) -> None: + system = System() + assert not isinstance(system, Principal) + + def test_anonymous_is_frozen(self) -> None: + anon = Anonymous() + assert hash(anon) is not None # frozen dataclass is hashable + + def test_system_is_frozen(self) -> None: + system = System() + assert hash(system) is not None diff --git a/server/tests/unit/domain/shared/authorization/test_resource_check.py b/server/tests/unit/domain/shared/authorization/test_resource_check.py new file mode 100644 index 0000000..ad1acf4 --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_resource_check.py @@ -0,0 +1,139 @@ +"""Tests for ResourceCheck: evaluate() with System/Anonymous/Principal, OwnerCheck, HasRole, AnyOf.""" + +import pytest + +from osa.domain.auth.model.identity import Anonymous, System +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.authorization.resource import ( + AnyOf, + has_role, + owner, +) +from osa.domain.shared.error import AuthorizationError + + +def _make_principal( + roles: frozenset[Role], + user_id: UserId | None = None, +) -> Principal: + return Principal( + user_id=user_id or UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=roles, + ) + + +class _FakeResource: + def __init__(self, owner_id: UserId) -> None: + self.owner_id = owner_id + + +class TestResourceCheckSystemBypass: + def test_system_bypasses_owner_check(self) -> None: + check = owner() + resource = _FakeResource(owner_id=UserId.generate()) + # Should not raise + check.evaluate(System(), resource) + + def test_system_bypasses_has_role_check(self) -> None: + check = has_role(Role.SUPERADMIN) + resource = _FakeResource(owner_id=UserId.generate()) + check.evaluate(System(), resource) + + def test_system_bypasses_any_of_check(self) -> None: + check = owner() | has_role(Role.ADMIN) + resource = _FakeResource(owner_id=UserId.generate()) + check.evaluate(System(), resource) + + +class TestResourceCheckAnonymousRejection: + def test_anonymous_rejected_by_owner_check(self) -> None: + check = owner() + resource = _FakeResource(owner_id=UserId.generate()) + with pytest.raises(AuthorizationError, match="Authentication required"): + check.evaluate(Anonymous(), resource) + + def test_anonymous_rejected_by_has_role_check(self) -> None: + check = has_role(Role.DEPOSITOR) + resource = _FakeResource(owner_id=UserId.generate()) + with pytest.raises(AuthorizationError, match="Authentication required"): + check.evaluate(Anonymous(), resource) + + +class TestOwnerCheck: + def test_owner_passes(self) -> None: + user_id = UserId.generate() + principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) + resource = _FakeResource(owner_id=user_id) + owner().evaluate(principal, resource) + + def test_non_owner_denied(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + with pytest.raises(AuthorizationError, match="not resource owner"): + owner().evaluate(principal, resource) + + def test_resource_without_owner_id_denied(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + + class NoOwner: + pass + + with pytest.raises(AuthorizationError, match="not resource owner"): + owner().evaluate(principal, NoOwner()) + + +class TestHasRole: + def test_principal_with_sufficient_role_passes(self) -> None: + principal = _make_principal(frozenset({Role.ADMIN})) + resource = _FakeResource(owner_id=UserId.generate()) + has_role(Role.CURATOR).evaluate(principal, resource) + + def test_principal_with_exact_role_passes(self) -> None: + principal = _make_principal(frozenset({Role.CURATOR})) + resource = _FakeResource(owner_id=UserId.generate()) + has_role(Role.CURATOR).evaluate(principal, resource) + + def test_principal_with_insufficient_role_denied(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) + with pytest.raises(AuthorizationError, match="requires role CURATOR"): + has_role(Role.CURATOR).evaluate(principal, resource) + + +class TestAnyOf: + def test_passes_when_first_check_passes(self) -> None: + user_id = UserId.generate() + principal = _make_principal(frozenset({Role.DEPOSITOR}), user_id=user_id) + resource = _FakeResource(owner_id=user_id) + + check = owner() | has_role(Role.CURATOR) + check.evaluate(principal, resource) + + def test_passes_when_second_check_passes(self) -> None: + principal = _make_principal(frozenset({Role.CURATOR})) + resource = _FakeResource(owner_id=UserId.generate()) # not owner + + check = owner() | has_role(Role.CURATOR) + check.evaluate(principal, resource) + + def test_fails_when_no_check_passes(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR})) + resource = _FakeResource(owner_id=UserId.generate()) # not owner, not curator + + check = owner() | has_role(Role.CURATOR) + with pytest.raises(AuthorizationError, match="Access denied"): + check.evaluate(principal, resource) + + +class TestOrOperator: + def test_pipe_creates_any_of(self) -> None: + check = owner() | has_role(Role.CURATOR) + assert isinstance(check, AnyOf) + + def test_chained_pipe(self) -> None: + check = owner() | has_role(Role.CURATOR) | has_role(Role.ADMIN) + assert isinstance(check, AnyOf) + assert len(check.checks) == 3 diff --git a/server/tests/unit/domain/shared/authorization/test_role_hierarchy.py b/server/tests/unit/domain/shared/authorization/test_role_hierarchy.py new file mode 100644 index 0000000..f0e747c --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_role_hierarchy.py @@ -0,0 +1,129 @@ +"""Tests for Role hierarchy: T012 — numeric hierarchy comparison.""" + +import pytest + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.authorization.gate import at_least +from osa.domain.shared.authorization.resource import has_role +from osa.domain.shared.command import Command, CommandHandler, Result + + +class TestRoleHierarchy: + def test_admin_ge_curator(self) -> None: + assert Role.ADMIN >= Role.CURATOR + + def test_depositor_lt_admin(self) -> None: + assert not (Role.DEPOSITOR >= Role.ADMIN) + + def test_superadmin_gt_all(self) -> None: + for role in Role: + if role != Role.SUPERADMIN: + assert Role.SUPERADMIN > role + + def test_public_is_lowest(self) -> None: + for role in Role: + assert role >= Role.PUBLIC + + def test_ordering(self) -> None: + assert Role.PUBLIC < Role.DEPOSITOR < Role.CURATOR < Role.ADMIN < Role.SUPERADMIN + + +class TestPrincipalHasRole: + def test_has_role_uses_hierarchy(self) -> None: + principal = Principal( + user_id=UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=frozenset({Role.ADMIN}), + ) + + # Admin >= Curator, so has_role(CURATOR) should be True + assert principal.has_role(Role.CURATOR) is True + assert principal.has_role(Role.ADMIN) is True + assert principal.has_role(Role.SUPERADMIN) is False + + def test_has_role_depositor(self) -> None: + principal = Principal( + user_id=UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=frozenset({Role.DEPOSITOR}), + ) + + assert principal.has_role(Role.DEPOSITOR) is True + assert principal.has_role(Role.CURATOR) is False + assert principal.has_role(Role.ADMIN) is False + + def test_has_any_role(self) -> None: + principal = Principal( + user_id=UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=frozenset({Role.CURATOR}), + ) + + assert principal.has_any_role(Role.ADMIN, Role.CURATOR) is True + assert principal.has_any_role(Role.SUPERADMIN) is False + + +def _make_principal( + roles: frozenset[Role], + user_id: UserId | None = None, +) -> Principal: + return Principal( + user_id=user_id or UserId.generate(), + provider_identity=ProviderIdentity(provider="test", external_id="ext"), + roles=roles, + ) + + +# --- Inline handler for gate test --- + + +class _MultiRoleCommand(Command): + value: str = "test" + + +class _MultiRoleResult(Result): + value: str + + +class _CuratorGatedHandler(CommandHandler[_MultiRoleCommand, _MultiRoleResult]): + __auth__ = at_least(Role.CURATOR) + principal: Principal + + async def run(self, cmd: _MultiRoleCommand) -> _MultiRoleResult: + return _MultiRoleResult(value=cmd.value) + + +class _FakeResource: + def __init__(self, owner_id: UserId) -> None: + self.owner_id = owner_id + + +class TestMultiRolePrincipal: + def test_multi_role_principal_has_role_uses_highest(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR, Role.CURATOR})) + + assert principal.has_role(Role.CURATOR) is True + assert principal.has_role(Role.DEPOSITOR) is True + + def test_multi_role_principal_fails_above_highest(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR, Role.CURATOR})) + + assert principal.has_role(Role.ADMIN) is False + + @pytest.mark.asyncio + async def test_multi_role_principal_at_handler_gate(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR, Role.CURATOR})) + handler = _CuratorGatedHandler(principal=principal) + + result = await handler.run(_MultiRoleCommand(value="ok")) + assert result.value == "ok" + + def test_multi_role_principal_at_resource_check(self) -> None: + principal = _make_principal(frozenset({Role.DEPOSITOR, Role.CURATOR})) + check = has_role(Role.CURATOR) + resource = _FakeResource(owner_id=UserId.generate()) + + # Should not raise — CURATOR satisfies has_role(CURATOR) + check.evaluate(principal, resource) diff --git a/server/tests/unit/domain/shared/authorization/test_startup_validation.py b/server/tests/unit/domain/shared/authorization/test_startup_validation.py new file mode 100644 index 0000000..3d3f8c9 --- /dev/null +++ b/server/tests/unit/domain/shared/authorization/test_startup_validation.py @@ -0,0 +1,113 @@ +"""Tests for startup validation of handler __auth__ declarations. + +Tests that all handlers must declare __auth__ as public() or at_least(Role). +""" + +import pytest + +from osa.domain.auth.model.principal import Principal +from osa.domain.auth.model.role import Role +from osa.domain.shared.authorization.gate import at_least, public +from osa.domain.shared.command import Command, CommandHandler, Result +from osa.domain.shared.error import ConfigurationError +from osa.domain.shared.query import Query, QueryHandler +from osa.domain.shared.query import Result as QueryResult + + +class TestStartupValidation: + def test_validation_catches_missing_auth_on_command_handler(self) -> None: + """A CommandHandler without __auth__ should fail startup.""" + + class UnprotectedCommand(Command): + pass + + class UnprotectedResult(Result): + pass + + class UnprotectedHandler(CommandHandler[UnprotectedCommand, UnprotectedResult]): + async def run(self, cmd: UnprotectedCommand) -> UnprotectedResult: + return UnprotectedResult() + + from osa.domain.shared.authorization.startup import _check_handler_class + + with pytest.raises(ConfigurationError, match="UnprotectedHandler"): + _check_handler_class(UnprotectedHandler) + + def test_validation_passes_for_protected_handler(self) -> None: + """A handler with __auth__ = at_least(...) should pass validation.""" + + class ProtectedCommand(Command): + pass + + class ProtectedResult(Result): + pass + + class ProtectedHandler(CommandHandler[ProtectedCommand, ProtectedResult]): + __auth__ = at_least(Role.ADMIN) + principal: Principal + + async def run(self, cmd: ProtectedCommand) -> ProtectedResult: + return ProtectedResult() + + from osa.domain.shared.authorization.startup import _check_handler_class + + # Should not raise + _check_handler_class(ProtectedHandler) + + def test_validation_passes_for_public_handler(self) -> None: + """A handler with __auth__ = public() should pass validation.""" + + class PublicCommand(Command): + pass + + class PublicResult(Result): + pass + + class PublicHandler(CommandHandler[PublicCommand, PublicResult]): + __auth__ = public() + + async def run(self, cmd: PublicCommand) -> PublicResult: + return PublicResult() + + from osa.domain.shared.authorization.startup import _check_handler_class + + # Should not raise + _check_handler_class(PublicHandler) + + def test_validation_catches_missing_auth_on_query_handler(self) -> None: + """A QueryHandler without __auth__ should fail.""" + + class UnprotectedQuery(Query): + pass + + class UnprotectedQueryResult(QueryResult): + pass + + class UnprotectedQueryHandler(QueryHandler[UnprotectedQuery, UnprotectedQueryResult]): + async def run(self, cmd: UnprotectedQuery) -> UnprotectedQueryResult: + return UnprotectedQueryResult() + + from osa.domain.shared.authorization.startup import _check_handler_class + + with pytest.raises(ConfigurationError, match="UnprotectedQueryHandler"): + _check_handler_class(UnprotectedQueryHandler) + + def test_validation_catches_invalid_auth_type(self) -> None: + """A handler with a non-Gate __auth__ should fail validation.""" + + class BadCommand(Command): + pass + + class BadResult(Result): + pass + + class BadHandler(CommandHandler[BadCommand, BadResult]): + __auth__ = "not_a_valid_gate" + + async def run(self, cmd: BadCommand) -> BadResult: + return BadResult() + + from osa.domain.shared.authorization.startup import _check_handler_class + + with pytest.raises(ConfigurationError, match="no __auth__ declaration"): + _check_handler_class(BadHandler)