Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions server/migrations/versions/add_authorization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""add_authorization

Add role_assignments table and owner_id column to depositions.

Revision ID: add_authorization
Revises: add_auth_tables
Create Date: 2026-02-06

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "add_authorization"
down_revision: Union[str, Sequence[str], None] = "add_auth_tables"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
"""Add authorization tables and columns."""
# ROLE ASSIGNMENTS TABLE
op.create_table(
"role_assignments",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=False),
sa.Column("role", sa.String(32), nullable=False),
sa.Column("assigned_by", sa.String(), nullable=False),
sa.Column("assigned_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
ondelete="CASCADE",
),
sa.ForeignKeyConstraint(
["assigned_by"],
["users.id"],
),
sa.UniqueConstraint("user_id", "role", name="uq_role_assignments_user_role"),
)
op.create_index("ix_role_assignments_user_id", "role_assignments", ["user_id"])

# ADD owner_id TO DEPOSITIONS (nullable initially for existing data)
op.add_column(
"depositions",
sa.Column("owner_id", sa.String(), nullable=True),
)
op.create_foreign_key(
"fk_depositions_owner_id",
"depositions",
"users",
["owner_id"],
["id"],
)
op.create_index("idx_depositions_owner_id", "depositions", ["owner_id"])


def downgrade() -> None:
"""Remove authorization tables and columns."""
# DEPOSITIONS owner_id
op.drop_index("idx_depositions_owner_id", table_name="depositions")
op.drop_constraint("fk_depositions_owner_id", "depositions", type_="foreignkey")
op.drop_column("depositions", "owner_id")

# ROLE ASSIGNMENTS
op.drop_index("ix_role_assignments_user_id", table_name="role_assignments")
op.drop_table("role_assignments")
16 changes: 15 additions & 1 deletion server/osa/application/api/rest/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,19 @@
from fastapi.responses import JSONResponse

from osa.application.api.v1.errors import map_osa_error
from osa.application.api.v1.routes import auth, events, health, records, search, stats, validation
from osa.application.api.v1.routes import (
admin,
auth,
events,
health,
records,
search,
stats,
validation,
)
from osa.application.di import create_container
from osa.config import Config, configure_logging
from osa.domain.shared.authorization.startup import validate_all_handlers
from osa.domain.shared.error import OSAError
from osa.infrastructure.event.worker import WorkerPool
from osa.infrastructure.source.discovery import validate_sources_at_startup
Expand Down Expand Up @@ -42,6 +52,9 @@ def create_app() -> FastAPI:
# Validate source configs at startup (fail fast with clear errors)
validate_sources_at_startup(config.sources)

# Validate all handlers have authorization declarations (fail fast)
validate_all_handlers()

app_instance = FastAPI(
title=config.server.name,
description=config.server.description,
Expand All @@ -59,6 +72,7 @@ def create_app() -> FastAPI:

# Register v1 routes with /api/v1 prefix
app_instance.include_router(health.router, prefix="/api/v1")
app_instance.include_router(admin.router, prefix="/api/v1")
app_instance.include_router(auth.router, prefix="/api/v1")
app_instance.include_router(events.router, prefix="/api/v1")
app_instance.include_router(records.router, prefix="/api/v1")
Expand Down
7 changes: 7 additions & 0 deletions server/osa/application/api/v1/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ def map_osa_error(error: OSAError) -> HTTPException:
status_code = DOMAIN_ERROR_STATUS_MAP.get(type(error), 400)
if isinstance(error, ValidationError) and error.field is not None:
detail["field"] = error.field
# Distinguish 401 (unauthenticated) from 403 (unauthorized)
if isinstance(error, AuthorizationError) and error.code == "missing_token":
return HTTPException(
status_code=401,
detail=detail,
headers={"WWW-Authenticate": "Bearer"},
)
return HTTPException(status_code=status_code, detail=detail)

# Fallback for unknown OSAError subclasses
Expand Down
95 changes: 95 additions & 0 deletions server/osa/application/api/v1/routes/admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Admin routes for role management."""

from dishka.integrations.fastapi import DishkaRoute, FromDishka
from fastapi import APIRouter, Response
from pydantic import BaseModel

from osa.domain.auth.command.assign_role import (
AssignRole,
AssignRoleHandler,
)
from osa.domain.auth.command.revoke_role import (
RevokeRole,
RevokeRoleHandler,
)
from osa.domain.auth.query.get_user_roles import (
GetUserRoles,
GetUserRolesHandler,
)

router = APIRouter(prefix="/admin", tags=["Admin"], route_class=DishkaRoute)


class AssignRoleRequest(BaseModel):
"""Request body for assigning a role."""

role: str


class RoleAssignmentResponse(BaseModel):
"""Response for a single role assignment."""

id: str
user_id: str
role: str
assigned_by: str
assigned_at: str


class RoleAssignmentListResponse(BaseModel):
"""Response listing role assignments."""

roles: list[RoleAssignmentResponse]


@router.get("/users/{user_id}/roles", response_model=RoleAssignmentListResponse)
async def list_user_roles(
user_id: str,
handler: FromDishka[GetUserRolesHandler],
) -> RoleAssignmentListResponse:
"""List all roles assigned to a user. Requires SuperAdmin role."""
result = await handler.run(GetUserRoles(user_id=user_id))
return RoleAssignmentListResponse(
roles=[
RoleAssignmentResponse(
id=r.id,
user_id=r.user_id,
role=r.role,
assigned_by=r.assigned_by,
assigned_at=r.assigned_at.isoformat(),
)
for r in result.roles
]
)


@router.post(
"/users/{user_id}/roles",
response_model=RoleAssignmentResponse,
status_code=201,
)
async def assign_role(
user_id: str,
body: AssignRoleRequest,
handler: FromDishka[AssignRoleHandler],
) -> RoleAssignmentResponse:
"""Assign a role to a user. Requires SuperAdmin role."""
result = await handler.run(AssignRole(user_id=user_id, role=body.role))
return RoleAssignmentResponse(
id=result.id,
user_id=result.user_id,
role=result.role,
assigned_by=result.assigned_by,
assigned_at=result.assigned_at.isoformat(),
)


@router.delete("/users/{user_id}/roles/{role}", status_code=204)
async def revoke_role(
user_id: str,
role: str,
handler: FromDishka[RevokeRoleHandler],
) -> Response:
"""Revoke a role from a user. Requires SuperAdmin role."""
await handler.run(RevokeRole(user_id=user_id, role=role))
return Response(status_code=204)
11 changes: 9 additions & 2 deletions server/osa/application/api/v1/routes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from osa.domain.auth.model.value import CurrentUser
from osa.domain.auth.port.provider_registry import ProviderRegistry
from osa.domain.auth.port.role_repository import RoleAssignmentRepository
from osa.domain.auth.service.auth import AuthService
from osa.domain.auth.service.token import TokenService
from osa.domain.shared.error import InvalidStateError
Expand Down Expand Up @@ -62,12 +63,13 @@ class LogoutResponse(BaseModel):


class UserResponse(BaseModel):
"""Response containing user info."""
"""Response containing user info with roles."""

id: str
display_name: str | None
provider: str
external_id: str
roles: list[str]


@router.get("/login")
Expand Down Expand Up @@ -259,8 +261,9 @@ async def logout(
async def get_me(
current_user: FromDishka[CurrentUser],
auth_service: FromDishka[AuthService],
role_repo: FromDishka[RoleAssignmentRepository],
) -> UserResponse:
"""Get current authenticated user information."""
"""Get current authenticated user information with roles."""
user = await auth_service.get_user_by_id(current_user.user_id)

if user is None:
Expand All @@ -269,9 +272,13 @@ async def get_me(
detail={"code": "user_not_found", "message": "User not found"},
)

assignments = await role_repo.get_by_user_id(current_user.user_id)
roles = [a.role.name.lower() for a in assignments]

return UserResponse(
id=str(user.id),
display_name=user.display_name,
provider=current_user.identity.provider,
external_id=current_user.identity.external_id,
roles=roles,
)
49 changes: 49 additions & 0 deletions server/osa/domain/auth/command/assign_role.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""AssignRole command and handler."""

from datetime import datetime
from uuid import UUID

from osa.domain.auth.model.principal import Principal
from osa.domain.auth.model.role import Role
from osa.domain.auth.model.value import UserId
from osa.domain.auth.service.authorization import AuthorizationService
from osa.domain.shared.authorization.gate import at_least
from osa.domain.shared.command import Command, CommandHandler, Result


class AssignRole(Command):
"""Command to assign a role to a user."""

user_id: str # UUID as string from API
role: str # Role name from API


class AssignRoleResult(Result):
"""Result containing the created role assignment."""

id: str
user_id: str
role: str
assigned_by: str
assigned_at: datetime


class AssignRoleHandler(CommandHandler[AssignRole, AssignRoleResult]):
__auth__ = at_least(Role.SUPERADMIN)
principal: Principal
authorization_service: AuthorizationService

async def run(self, cmd: AssignRole) -> AssignRoleResult:
assignment = await self.authorization_service.assign_role(
user_id=UserId(UUID(cmd.user_id)),
role=Role[cmd.role.upper()],
assigned_by=self.principal.user_id,
)

return AssignRoleResult(
id=str(assignment.id),
user_id=str(assignment.user_id),
role=assignment.role.name.lower(),
assigned_by=str(assignment.assigned_by),
assigned_at=assignment.assigned_at,
)
15 changes: 10 additions & 5 deletions server/osa/domain/auth/command/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from osa.domain.auth.port.provider_registry import ProviderRegistry
from osa.domain.auth.service.auth import AuthService
from osa.domain.auth.service.token import TokenService
from osa.domain.shared.authorization.gate import public
from osa.domain.shared.command import Command, CommandHandler, Result
from osa.domain.shared.error import NotFoundError
from osa.domain.shared.event import EventId
Expand All @@ -31,6 +32,8 @@ class InitiateLoginResult(Result):
class InitiateLoginHandler(CommandHandler[InitiateLogin, InitiateLoginResult]):
"""Handler for InitiateLogin command."""

__auth__ = public()

provider_registry: ProviderRegistry
token_service: TokenService

Expand Down Expand Up @@ -80,6 +83,8 @@ class CompleteOAuthResult(Result):
class CompleteOAuthHandler(CommandHandler[CompleteOAuth, CompleteOAuthResult]):
"""Handler for CompleteOAuth command."""

__auth__ = public()

auth_service: AuthService
provider_registry: ProviderRegistry
token_service: TokenService
Expand All @@ -95,7 +100,7 @@ async def run(self, cmd: CompleteOAuth) -> CompleteOAuthResult:
code="unknown_provider",
)

user, identity, access_token, refresh_token = await self.auth_service.complete_oauth(
user, linked_account, access_token, refresh_token = await self.auth_service.complete_oauth(
provider=identity_provider,
code=cmd.code,
redirect_uri=cmd.callback_url,
Expand All @@ -106,16 +111,16 @@ async def run(self, cmd: CompleteOAuth) -> CompleteOAuthResult:
UserAuthenticated(
id=EventId(uuid4()),
user_id=str(user.id),
provider=identity.provider,
external_id=identity.external_id,
provider=linked_account.provider,
external_id=linked_account.external_id,
)
)

return CompleteOAuthResult(
user_id=str(user.id),
display_name=user.display_name,
provider=identity.provider,
external_id=identity.external_id,
provider=linked_account.provider,
external_id=linked_account.external_id,
access_token=access_token,
refresh_token=refresh_token,
expires_in=self.token_service.access_token_expire_seconds,
Expand Down
Loading
Loading