Authentication Error
+{error}
+ + Return to Home + +diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cc180ea..414bc0d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -35,3 +35,20 @@ repos:
types: [python]
pass_filenames: false
stages: [pre-commit]
+
+ - id: web-lint
+ name: web lint
+ entry: pnpm --dir web lint
+ language: system
+ files: ^web/
+ types_or: [javascript, jsx, ts, tsx, css]
+ pass_filenames: false
+
+ - id: web-build
+ name: web build
+ entry: pnpm --dir web build
+ language: system
+ files: ^web/
+ types_or: [javascript, jsx, ts, tsx, css, json]
+ pass_filenames: false
+ stages: [pre-commit]
diff --git a/Justfile b/Justfile
index 0add92f..d750323 100644
--- a/Justfile
+++ b/Justfile
@@ -20,30 +20,10 @@ up-attached:
down:
docker compose -f deploy/docker-compose.yml down
-# View logs from all services
-logs:
- docker compose -f deploy/docker-compose.yml logs -f
-
-# View logs from a specific service
-logs-service service:
+# View logs for a service (e.g., just logs server, just logs db)
+logs service:
docker compose -f deploy/docker-compose.yml logs -f {{service}}
-# View server logs
-server-logs:
- docker compose -f deploy/docker-compose.yml logs -f server
-
-# View last N lines of server logs (default: 100)
-server-logs-tail lines="100":
- docker compose -f deploy/docker-compose.yml logs --tail {{lines}} server
-
-# View server logs with timestamps
-server-logs-time:
- docker compose -f deploy/docker-compose.yml logs -f -t server
-
-# View server logs since a time (e.g., "10m", "1h", "2024-01-01")
-server-logs-since since:
- docker compose -f deploy/docker-compose.yml logs -f --since {{since}} server
-
# Shell into the server container
server-shell:
docker compose -f deploy/docker-compose.yml exec server bash
@@ -106,10 +86,6 @@ db-up:
db-down:
docker compose -f deploy/docker-compose.yml stop db
-# View database logs
-db-logs:
- docker compose -f deploy/docker-compose.yml logs -f db
-
# Connect to PostgreSQL
db-connect:
docker compose -f deploy/docker-compose.yml exec db psql -U postgres -d osa
diff --git a/deploy/docker-compose.dev.yml b/deploy/docker-compose.dev.yml
index 32fa427..94a9c2a 100644
--- a/deploy/docker-compose.dev.yml
+++ b/deploy/docker-compose.dev.yml
@@ -7,15 +7,19 @@ services:
context: ../server
dockerfile: Dockerfile
target: builder
+ ports:
+ - "8000:8000"
volumes:
- ../server:/app
- /app/.venv
environment:
OSA_DATABASE__URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-osa}@db:5432/${POSTGRES_DB:-osa}
OSA_DATA_DIR: /data
+ OSA_CONFIG_FILE: /app/osa.yaml
OSA_LOGGING__LEVEL: ${LOG_LEVEL:-DEBUG}
WATCHFILES_FORCE_POLLING: "true"
- command: uvicorn osa.application.api.rest.app:app --host 0.0.0.0 --port 8000 --reload
+ entrypoint: []
+ command: ["sh", "-c", "/app/.venv/bin/alembic upgrade head && /app/.venv/bin/uvicorn osa.application.api.rest.app:app --host 0.0.0.0 --port 8000 --reload"]
healthcheck:
test: ["CMD", "curl", "--fail", "http://localhost:8000/api/v1/health"]
interval: 10s
@@ -36,4 +40,5 @@ services:
API_URL: http://server:8000
ports:
- "3000:3000"
- command: pnpm dev
+ entrypoint: []
+ command: ["pnpm", "dev"]
diff --git a/server/.env.example b/server/.env.example
new file mode 100644
index 0000000..35bfe0a
--- /dev/null
+++ b/server/.env.example
@@ -0,0 +1,76 @@
+# OSA Server Configuration - Secrets and Environment-Specific Values
+# Copy this file to .env and fill in your values
+#
+# Relationship with osa.yaml:
+# - .env: SECRETS (credentials, JWT secrets) - never commit to git
+# - osa.yaml: APPLICATION CONFIG (sources, indexes) - can be version-controlled
+#
+# Priority (highest to lowest):
+# 1. Environment variables
+# 2. .env file
+# 3. OSA_CONFIG_FILE (osa.yaml)
+#
+# Auth secrets MUST go in .env (not osa.yaml) for security
+
+# =============================================================================
+# ORCiD OAuth Configuration
+# =============================================================================
+# Register at: https://sandbox.orcid.org/developer-tools (sandbox)
+# Register at: https://orcid.org/developer-tools (production)
+
+# ORCiD OAuth client ID (e.g., APP-XXXXXXXXXXXX)
+OSA_AUTH__ORCID__CLIENT_ID=
+
+# ORCiD OAuth client secret
+OSA_AUTH__ORCID__CLIENT_SECRET=
+
+# Use ORCiD sandbox for development (default: true)
+# Set to false for production
+OSA_AUTH__ORCID__SANDBOX=true
+
+# =============================================================================
+# JWT Configuration
+# =============================================================================
+
+# JWT signing secret - minimum 32 characters required
+# Generate with: openssl rand -hex 32
+OSA_AUTH__JWT__SECRET=
+
+# Access token expiry in minutes (default: 60)
+# OSA_AUTH__JWT__ACCESS_TOKEN_EXPIRE_MINUTES=60
+
+# Refresh token expiry in days (default: 7)
+# OSA_AUTH__JWT__REFRESH_TOKEN_EXPIRE_DAYS=7
+
+# =============================================================================
+# OAuth Callback URL
+# =============================================================================
+# Must match the redirect URI registered in ORCiD developer tools
+
+# For development (default derives from request URL)
+# OSA_AUTH__CALLBACK_URL=http://localhost:8000/api/v1/auth/callback
+
+# For production
+# OSA_AUTH__CALLBACK_URL=https://your-domain.com/api/v1/auth/callback
+
+# =============================================================================
+# Frontend URL
+# =============================================================================
+# URL of the frontend application for OAuth redirects
+
+# For development
+OSA_FRONTEND__URL=http://localhost:3000
+
+# For production
+# OSA_FRONTEND__URL=https://your-domain.com
+
+# =============================================================================
+# Database Configuration
+# =============================================================================
+# Default: SQLite in ~/.local/share/osa/osa.db
+
+# PostgreSQL example:
+# OSA_DATABASE__URL=postgresql+asyncpg://user:password@localhost:5432/osa
+
+# Echo SQL queries (for debugging)
+# OSA_DATABASE__ECHO=false
diff --git a/server/Dockerfile b/server/Dockerfile
index 6b0d381..ddd05b9 100644
--- a/server/Dockerfile
+++ b/server/Dockerfile
@@ -3,6 +3,9 @@
# Stage 1: Builder
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder
+# Install curl for healthchecks in dev mode
+RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/*
+
WORKDIR /app
# Enable bytecode compilation for faster startup
diff --git a/server/migrations/versions/add_auth_tables.py b/server/migrations/versions/add_auth_tables.py
new file mode 100644
index 0000000..b04ac0b
--- /dev/null
+++ b/server/migrations/versions/add_auth_tables.py
@@ -0,0 +1,89 @@
+"""add_auth_tables
+
+Add users, identities, and refresh_tokens tables for authentication.
+
+Revision ID: add_auth_tables
+Revises: add_worker_columns
+Create Date: 2026-02-04
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision: str = "add_auth_tables"
+down_revision: Union[str, Sequence[str], None] = "add_worker_columns"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ """Add authentication tables."""
+ # USERS TABLE
+ op.create_table(
+ "users",
+ sa.Column("id", sa.String(), nullable=False),
+ sa.Column("display_name", sa.String(255), nullable=True),
+ sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+ sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
+ )
+
+ # IDENTITIES TABLE
+ op.create_table(
+ "identities",
+ sa.Column("id", sa.String(), nullable=False),
+ sa.Column("user_id", sa.String(), nullable=False),
+ sa.Column("provider", sa.String(50), nullable=False),
+ sa.Column("external_id", sa.String(255), nullable=False),
+ sa.Column("metadata", sa.JSON(), nullable=True),
+ sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+ sa.PrimaryKeyConstraint("id"),
+ sa.ForeignKeyConstraint(
+ ["user_id"],
+ ["users.id"],
+ ondelete="CASCADE",
+ ),
+ sa.UniqueConstraint("provider", "external_id", name="uq_identity_provider_external"),
+ )
+ op.create_index("ix_identities_user_id", "identities", ["user_id"])
+
+ # REFRESH TOKENS TABLE
+ op.create_table(
+ "refresh_tokens",
+ sa.Column("id", sa.String(), nullable=False),
+ sa.Column("user_id", sa.String(), nullable=False),
+ sa.Column("token_hash", sa.String(64), nullable=False),
+ sa.Column("family_id", sa.String(), nullable=False),
+ sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
+ sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+ sa.Column("revoked_at", sa.DateTime(timezone=True), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
+ sa.ForeignKeyConstraint(
+ ["user_id"],
+ ["users.id"],
+ ondelete="CASCADE",
+ ),
+ )
+ op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"])
+ op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"])
+ op.create_index("ix_refresh_tokens_family_id", "refresh_tokens", ["family_id"])
+
+
+def downgrade() -> None:
+ """Remove authentication tables."""
+ # REFRESH TOKENS
+ op.drop_index("ix_refresh_tokens_family_id", table_name="refresh_tokens")
+ op.drop_index("ix_refresh_tokens_token_hash", table_name="refresh_tokens")
+ op.drop_index("ix_refresh_tokens_user_id", table_name="refresh_tokens")
+ op.drop_table("refresh_tokens")
+
+ # IDENTITIES
+ op.drop_index("ix_identities_user_id", table_name="identities")
+ op.drop_table("identities")
+
+ # USERS
+ op.drop_table("users")
diff --git a/server/osa/application/api/rest/app.py b/server/osa/application/api/rest/app.py
index f997a6f..735e836 100644
--- a/server/osa/application/api/rest/app.py
+++ b/server/osa/application/api/rest/app.py
@@ -6,7 +6,7 @@
from fastapi.responses import JSONResponse
from osa.application.api.v1.errors import map_osa_error
-from osa.application.api.v1.routes import events, health, records, search, stats, validation
+from osa.application.api.v1.routes import auth, events, health, records, search, stats, validation
from osa.application.di import create_container
from osa.config import Config, configure_logging
from osa.domain.shared.error import OSAError
@@ -32,7 +32,8 @@ async def lifespan(app: FastAPI):
def create_app() -> FastAPI:
"""Create FastAPI application."""
- config = Config()
+ # Pydantic Settings populates from env vars at runtime
+ config = Config() # type: ignore[call-arg]
# Configure logging early
configure_logging(config.logging)
@@ -58,6 +59,7 @@ def create_app() -> FastAPI:
# Register v1 routes with /api/v1 prefix
app_instance.include_router(health.router, prefix="/api/v1")
+ app_instance.include_router(auth.router, prefix="/api/v1")
app_instance.include_router(events.router, prefix="/api/v1")
app_instance.include_router(records.router, prefix="/api/v1")
app_instance.include_router(search.router, prefix="/api/v1")
diff --git a/server/osa/application/api/v1/routes/auth.py b/server/osa/application/api/v1/routes/auth.py
new file mode 100644
index 0000000..de22b0d
--- /dev/null
+++ b/server/osa/application/api/v1/routes/auth.py
@@ -0,0 +1,277 @@
+"""Authentication routes for OAuth login flow."""
+
+import logging
+from typing import Annotated
+from urllib.parse import urlencode
+
+from dishka import FromDishka
+from dishka.integrations.fastapi import DishkaRoute
+from fastapi import APIRouter, HTTPException, Query, Request, Response
+from fastapi.responses import RedirectResponse
+from pydantic import BaseModel
+
+from osa.config import Config
+from osa.domain.auth.command.login import (
+ CompleteOAuth,
+ CompleteOAuthHandler,
+ InitiateLogin,
+ InitiateLoginHandler,
+)
+from osa.domain.auth.command.token import (
+ Logout,
+ LogoutHandler,
+ RefreshTokens,
+ RefreshTokensHandler,
+)
+from osa.domain.auth.model.value import CurrentUser
+from osa.domain.auth.port.provider_registry import ProviderRegistry
+from osa.domain.auth.service.auth import AuthService
+from osa.domain.auth.service.token import TokenService
+from osa.domain.shared.error import InvalidStateError
+
+logger = logging.getLogger(__name__)
+
+router = APIRouter(prefix="/auth", tags=["Authentication"], route_class=DishkaRoute)
+
+
+class RefreshTokenRequest(BaseModel):
+ """Request body for token refresh."""
+
+ refresh_token: str
+
+
+class LogoutRequest(BaseModel):
+ """Request body for logout."""
+
+ refresh_token: str
+
+
+class TokenResponse(BaseModel):
+ """Response containing tokens."""
+
+ access_token: str
+ refresh_token: str
+ token_type: str = "Bearer"
+ expires_in: int
+
+
+class LogoutResponse(BaseModel):
+ """Response for logout."""
+
+ success: bool
+
+
+class UserResponse(BaseModel):
+ """Response containing user info."""
+
+ id: str
+ display_name: str | None
+ provider: str
+ external_id: str
+
+
+@router.get("/login")
+async def initiate_login(
+ request: Request,
+ config: FromDishka[Config],
+ handler: FromDishka[InitiateLoginHandler],
+ registry: FromDishka[ProviderRegistry],
+ provider: Annotated[str, Query()],
+ redirect_uri: Annotated[str | None, Query()] = None,
+) -> Response:
+ """Initiate OAuth login flow.
+
+ Redirects to identity provider's authorization page.
+ """
+ # Validate provider is configured
+ if not registry.is_available(provider):
+ available = registry.available_providers()
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "code": "unknown_provider",
+ "message": f"Unknown provider: {provider}. Available: {', '.join(available) or 'none'}",
+ },
+ )
+
+ # Determine callback URL
+ callback_url = config.auth.callback_url
+ if not callback_url:
+ callback_url = str(request.url_for("handle_oauth_callback"))
+
+ # Determine final redirect URI
+ final_redirect = redirect_uri or config.frontend.url
+
+ result = await handler.run(
+ InitiateLogin(
+ callback_url=callback_url,
+ final_redirect_uri=final_redirect,
+ provider=provider,
+ )
+ )
+
+ logger.info("OAuth login initiated for provider=%s, redirecting to IdP", provider)
+ return RedirectResponse(url=result.authorization_url, status_code=302)
+
+
+@router.get("/callback")
+async def handle_oauth_callback(
+ request: Request,
+ config: FromDishka[Config],
+ handler: FromDishka[CompleteOAuthHandler],
+ token_service: FromDishka[TokenService],
+ code: Annotated[str | None, Query()] = None,
+ state: Annotated[str | None, Query()] = None,
+ error: Annotated[str | None, Query()] = None,
+ error_description: Annotated[str | None, Query()] = None,
+) -> Response:
+ """Handle OAuth callback from identity provider.
+
+ Exchanges authorization code for tokens and redirects to frontend.
+ """
+ frontend_url = config.frontend.url
+
+ # Check for OAuth errors
+ if error:
+ logger.warning("OAuth error: %s - %s", error, error_description)
+ error_params = urlencode(
+ {
+ "error": error,
+ "error_description": error_description or "Authentication failed",
+ }
+ )
+ return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}")
+
+ # Validate signed state token
+ if not state:
+ logger.warning("OAuth state missing")
+ error_params = urlencode(
+ {
+ "error": "oauth_state_missing",
+ "error_description": "Missing state parameter",
+ }
+ )
+ return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}")
+
+ state_data = token_service.verify_oauth_state(state)
+ if state_data is None:
+ logger.warning("OAuth state invalid or expired")
+ error_params = urlencode(
+ {
+ "error": "oauth_state_invalid",
+ "error_description": "Invalid or expired state parameter",
+ }
+ )
+ return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}")
+
+ final_redirect, provider = state_data
+
+ if not code:
+ logger.warning("OAuth callback missing code")
+ error_params = urlencode(
+ {
+ "error": "missing_code",
+ "error_description": "Authorization code not provided",
+ }
+ )
+ return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}")
+
+ try:
+ # Determine callback URL (must match what was used in authorization)
+ callback_url = config.auth.callback_url
+ if not callback_url:
+ callback_url = str(request.url_for("handle_oauth_callback"))
+
+ # Complete OAuth flow via handler
+ result = await handler.run(
+ CompleteOAuth(
+ code=code,
+ callback_url=callback_url,
+ provider=provider,
+ )
+ )
+
+ # Build redirect URL with tokens in fragment
+ token_params = urlencode(
+ {
+ "access_token": result.access_token,
+ "refresh_token": result.refresh_token,
+ "token_type": "Bearer",
+ "expires_in": result.expires_in,
+ "user_id": result.user_id,
+ "display_name": result.display_name or "",
+ "provider": result.provider,
+ "external_id": result.external_id,
+ }
+ )
+
+ redirect_url = f"{final_redirect}#auth={token_params}"
+ logger.info(
+ "OAuth complete, user authenticated: user_id=%s, provider=%s", result.user_id, provider
+ )
+ return RedirectResponse(url=redirect_url, status_code=302)
+
+ except Exception as e:
+ logger.exception("OAuth callback failed: %s", e)
+ error_params = urlencode(
+ {
+ "error": "oauth_error",
+ "error_description": "Authentication failed. Please try again.",
+ }
+ )
+ return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}")
+
+
+@router.post("/refresh", response_model=TokenResponse)
+async def refresh_token(
+ body: RefreshTokenRequest,
+ handler: FromDishka[RefreshTokensHandler],
+) -> TokenResponse:
+ """Refresh access token using refresh token."""
+ try:
+ result = await handler.run(RefreshTokens(refresh_token=body.refresh_token))
+ return TokenResponse(
+ access_token=result.access_token,
+ refresh_token=result.refresh_token,
+ expires_in=result.expires_in,
+ )
+ except InvalidStateError as e:
+ raise HTTPException(
+ status_code=401,
+ detail={
+ "code": e.code,
+ "message": e.message,
+ },
+ ) from e
+
+
+@router.post("/logout", response_model=LogoutResponse)
+async def logout(
+ body: LogoutRequest,
+ handler: FromDishka[LogoutHandler],
+) -> LogoutResponse:
+ """Logout and revoke refresh token."""
+ result = await handler.run(Logout(refresh_token=body.refresh_token))
+ return LogoutResponse(success=result.success)
+
+
+@router.get("/me", response_model=UserResponse)
+async def get_me(
+ current_user: FromDishka[CurrentUser],
+ auth_service: FromDishka[AuthService],
+) -> UserResponse:
+ """Get current authenticated user information."""
+ user = await auth_service.get_user_by_id(current_user.user_id)
+
+ if user is None:
+ raise HTTPException(
+ status_code=401,
+ detail={"code": "user_not_found", "message": "User not found"},
+ )
+
+ return UserResponse(
+ id=str(user.id),
+ display_name=user.display_name,
+ provider=current_user.identity.provider,
+ external_id=current_user.identity.external_id,
+ )
diff --git a/server/osa/application/di.py b/server/osa/application/di.py
index 9f70b33..a7b0c7c 100644
--- a/server/osa/application/di.py
+++ b/server/osa/application/di.py
@@ -2,8 +2,10 @@
from osa.cli.util.paths import OSAPaths
from osa.config import Config
+from osa.domain.auth.util.di import AuthProvider
from osa.domain.deposition.util.di import DepositionProvider
from osa.domain.validation.util.di import ValidationProvider
+from osa.infrastructure.auth import AuthInfraProvider
from osa.infrastructure.event.di import EventProvider
from osa.infrastructure.index.di import IndexProvider
from osa.infrastructure.source.di import SourceProvider
@@ -13,7 +15,8 @@
def create_container() -> AsyncContainer:
- config = Config()
+ # Pydantic Settings populates from env vars at runtime
+ config = Config() # type: ignore[call-arg]
# OSAPaths reads OSA_DATA_DIR from environment automatically
paths = OSAPaths()
@@ -26,6 +29,8 @@ def create_container() -> AsyncContainer:
EventProvider(),
DepositionProvider(),
ValidationProvider(),
+ AuthProvider(),
+ AuthInfraProvider(),
context={Config: config, OSAPaths: paths},
scopes=Scope, # type: ignore[arg-type] # Custom scope class
)
diff --git a/server/osa/cli/util/daemon.py b/server/osa/cli/util/daemon.py
index e7d99e2..1ff80fe 100644
--- a/server/osa/cli/util/daemon.py
+++ b/server/osa/cli/util/daemon.py
@@ -123,7 +123,8 @@ def start(
os.environ["OSA_CONFIG_FILE"] = config_file
try:
- app_config = Config()
+ # Pydantic Settings populates from env vars at runtime
+ app_config = Config() # type: ignore[call-arg]
except ValidationError as e:
details = []
for err in e.errors():
diff --git a/server/osa/config.py b/server/osa/config.py
index feebefd..feab7ea 100644
--- a/server/osa/config.py
+++ b/server/osa/config.py
@@ -150,6 +150,51 @@ class WorkerConfig(BaseModel):
batch_size: int = 100 # Maximum events to fetch per poll cycle
+# =============================================================================
+# Authentication Configuration
+# =============================================================================
+
+
+class OrcidConfig(BaseModel):
+ """ORCiD OAuth configuration."""
+
+ client_id: str = ""
+ client_secret: str = ""
+ sandbox: bool = True # Use sandbox.orcid.org by default
+
+ @property
+ def base_url(self) -> str:
+ """Get base URL for ORCiD API based on sandbox setting."""
+ return "https://sandbox.orcid.org" if self.sandbox else "https://orcid.org"
+
+
+class JwtConfig(BaseModel):
+ """JWT configuration."""
+
+ secret: str # Required - set via OSA_AUTH__JWT__SECRET
+ algorithm: str = "HS256"
+ access_token_expire_minutes: int = 60 # 1 hour
+ refresh_token_expire_days: int = 7
+
+ @model_validator(mode="after")
+ def validate_secret_length(self) -> Self:
+ """Ensure JWT secret has sufficient length."""
+ if len(self.secret) < 32:
+ raise ValueError(
+ "JWT secret must be at least 32 characters for security. "
+ "Generate with: openssl rand -hex 32"
+ )
+ return self
+
+
+class AuthConfig(BaseModel):
+ """Authentication configuration."""
+
+ orcid: OrcidConfig = OrcidConfig()
+ jwt: JwtConfig # Required - no default, must be configured via env vars
+ callback_url: str = "" # Full callback URL (e.g., https://myarchive.org/api/v1/auth/callback)
+
+
class Config(BaseSettings):
# These are BaseModel, so env_nested_delimiter handles their env vars
server: Server = Server()
@@ -157,6 +202,7 @@ class Config(BaseSettings):
database: DatabaseConfig = DatabaseConfig()
logging: LoggingConfig = LoggingConfig()
worker: WorkerConfig = WorkerConfig() # Background worker settings
+ auth: AuthConfig # Required - set via OSA_AUTH__JWT__SECRET env var
indexes: list[IndexConfig] = [] # list of index configs
sources: list[SourceConfig] = [] # list of source configs
diff --git a/server/osa/domain/auth/command/__init__.py b/server/osa/domain/auth/command/__init__.py
index e69de29..ddfc4f4 100644
--- a/server/osa/domain/auth/command/__init__.py
+++ b/server/osa/domain/auth/command/__init__.py
@@ -0,0 +1,33 @@
+"""Auth domain commands."""
+
+from .login import (
+ CompleteOAuth,
+ CompleteOAuthHandler,
+ CompleteOAuthResult,
+ InitiateLogin,
+ InitiateLoginHandler,
+ InitiateLoginResult,
+)
+from .token import (
+ Logout,
+ LogoutHandler,
+ LogoutResult,
+ RefreshTokens,
+ RefreshTokensHandler,
+ RefreshTokensResult,
+)
+
+__all__ = [
+ "CompleteOAuth",
+ "CompleteOAuthHandler",
+ "CompleteOAuthResult",
+ "InitiateLogin",
+ "InitiateLoginHandler",
+ "InitiateLoginResult",
+ "Logout",
+ "LogoutHandler",
+ "LogoutResult",
+ "RefreshTokens",
+ "RefreshTokensHandler",
+ "RefreshTokensResult",
+]
diff --git a/server/osa/domain/auth/command/login.py b/server/osa/domain/auth/command/login.py
new file mode 100644
index 0000000..a54d423
--- /dev/null
+++ b/server/osa/domain/auth/command/login.py
@@ -0,0 +1,122 @@
+"""Login commands for OAuth authentication flow."""
+
+from dataclasses import dataclass
+from uuid import uuid4
+
+from osa.domain.auth.event import UserAuthenticated
+from osa.domain.auth.port.provider_registry import ProviderRegistry
+from osa.domain.auth.service.auth import AuthService
+from osa.domain.auth.service.token import TokenService
+from osa.domain.shared.command import Command, CommandHandler, Result
+from osa.domain.shared.error import NotFoundError
+from osa.domain.shared.event import EventId
+from osa.domain.shared.outbox import Outbox
+
+
+class InitiateLogin(Command):
+ """Command to start OAuth login flow."""
+
+ callback_url: str # OAuth callback URL (where IdP redirects after auth)
+ final_redirect_uri: str # Where to redirect user after OAuth completes
+ provider: str
+
+
+class InitiateLoginResult(Result):
+ """Result containing authorization URL."""
+
+ authorization_url: str
+
+
+@dataclass
+class InitiateLoginHandler(CommandHandler[InitiateLogin, InitiateLoginResult]):
+ """Handler for InitiateLogin command."""
+
+ provider_registry: ProviderRegistry
+ token_service: TokenService
+
+ async def run(self, cmd: InitiateLogin) -> InitiateLoginResult:
+ """Generate authorization URL for OAuth login."""
+ # Look up the identity provider
+ identity_provider = self.provider_registry.get(cmd.provider)
+ if identity_provider is None:
+ raise NotFoundError(
+ f"Unknown identity provider: {cmd.provider}",
+ code="unknown_provider",
+ )
+
+ # Create signed state token (includes redirect_uri, provider, expiry, and nonce)
+ state = self.token_service.create_oauth_state(cmd.final_redirect_uri, cmd.provider)
+
+ # Get authorization URL from identity provider
+ authorization_url = identity_provider.get_authorization_url(
+ state=state,
+ redirect_uri=cmd.callback_url,
+ )
+
+ return InitiateLoginResult(authorization_url=authorization_url)
+
+
+class CompleteOAuth(Command):
+ """Command to complete OAuth flow with authorization code."""
+
+ code: str
+ callback_url: str # Must match the one used in authorization
+ provider: str # The identity provider name (from verified state)
+
+
+class CompleteOAuthResult(Result):
+ """Result containing user info and tokens."""
+
+ user_id: str
+ display_name: str | None
+ provider: str
+ external_id: str
+ access_token: str
+ refresh_token: str
+ expires_in: int # Seconds until access token expires
+
+
+@dataclass
+class CompleteOAuthHandler(CommandHandler[CompleteOAuth, CompleteOAuthResult]):
+ """Handler for CompleteOAuth command."""
+
+ auth_service: AuthService
+ provider_registry: ProviderRegistry
+ token_service: TokenService
+ outbox: Outbox
+
+ async def run(self, cmd: CompleteOAuth) -> CompleteOAuthResult:
+ """Exchange authorization code for tokens and create/update user."""
+ # Look up the identity provider
+ identity_provider = self.provider_registry.get(cmd.provider)
+ if identity_provider is None:
+ raise NotFoundError(
+ f"Unknown identity provider: {cmd.provider}",
+ code="unknown_provider",
+ )
+
+ user, identity, access_token, refresh_token = await self.auth_service.complete_oauth(
+ provider=identity_provider,
+ code=cmd.code,
+ redirect_uri=cmd.callback_url,
+ )
+
+ # Emit UserAuthenticated event
+ await self.outbox.append(
+ UserAuthenticated(
+ id=EventId(uuid4()),
+ user_id=str(user.id),
+ provider=identity.provider,
+ external_id=identity.external_id,
+ )
+ )
+
+ return CompleteOAuthResult(
+ user_id=str(user.id),
+ display_name=user.display_name,
+ provider=identity.provider,
+ external_id=identity.external_id,
+ access_token=access_token,
+ refresh_token=refresh_token,
+ expires_in=self.token_service.access_token_expire_seconds,
+ )
diff --git a/server/osa/domain/auth/command/token.py b/server/osa/domain/auth/command/token.py
new file mode 100644
index 0000000..876c800
--- /dev/null
+++ b/server/osa/domain/auth/command/token.py
@@ -0,0 +1,84 @@
+"""Token commands for refresh and logout operations."""
+
+from dataclasses import dataclass
+from uuid import uuid4
+
+from osa.domain.auth.event import UserLoggedOut
+from osa.domain.auth.service.auth import AuthService
+from osa.domain.auth.service.token import TokenService
+from osa.domain.shared.command import Command, CommandHandler, Result
+from osa.domain.shared.event import EventId
+from osa.domain.shared.outbox import Outbox
+
+
+class RefreshTokens(Command):
+ """Command to refresh access token using refresh token."""
+
+ refresh_token: str
+
+
+class RefreshTokensResult(Result):
+ """Result containing new tokens."""
+
+ access_token: str
+ refresh_token: str
+ expires_in: int
+
+
+@dataclass
+class RefreshTokensHandler(CommandHandler[RefreshTokens, RefreshTokensResult]):
+ """Handler for RefreshTokens command."""
+
+ auth_service: AuthService
+ token_service: TokenService
+
+ async def run(self, cmd: RefreshTokens) -> RefreshTokensResult:
+ """Refresh tokens using refresh token rotation."""
+ _user, access_token, new_refresh_token = await self.auth_service.refresh_tokens(
+ cmd.refresh_token
+ )
+
+ return RefreshTokensResult(
+ access_token=access_token,
+ refresh_token=new_refresh_token,
+ expires_in=self.token_service.access_token_expire_seconds,
+ )
+
+
+class Logout(Command):
+ """Command to logout and revoke refresh token family."""
+
+ refresh_token: str
+
+
+class LogoutResult(Result):
+ """Result for logout operation."""
+
+ success: bool
+
+
+@dataclass
+class LogoutHandler(CommandHandler[Logout, LogoutResult]):
+ """Handler for Logout command."""
+
+ auth_service: AuthService
+ outbox: Outbox
+
+ async def run(self, cmd: Logout) -> LogoutResult:
+ """Logout by revoking refresh token family."""
+ # Get user_id before revoking (for event emission)
+ user_id = await self.auth_service.get_user_id_from_refresh_token(cmd.refresh_token)
+
+ # Revoke tokens
+ success = await self.auth_service.logout(cmd.refresh_token)
+
+ # Emit UserLoggedOut event if we had a valid user
+ if user_id is not None:
+ await self.outbox.append(
+ UserLoggedOut(
+ id=EventId(uuid4()),
+ user_id=str(user_id),
+ )
+ )
+
+ return LogoutResult(success=success)
diff --git a/server/osa/domain/auth/event/__init__.py b/server/osa/domain/auth/event/__init__.py
index e69de29..b5e31c1 100644
--- a/server/osa/domain/auth/event/__init__.py
+++ b/server/osa/domain/auth/event/__init__.py
@@ -0,0 +1,5 @@
+"""Auth domain events."""
+
+from .events import UserAuthenticated, UserLoggedOut
+
+__all__ = ["UserAuthenticated", "UserLoggedOut"]
diff --git a/server/osa/domain/auth/event/events.py b/server/osa/domain/auth/event/events.py
new file mode 100644
index 0000000..63504fe
--- /dev/null
+++ b/server/osa/domain/auth/event/events.py
@@ -0,0 +1,19 @@
+"""Domain events for the auth domain."""
+
+from osa.domain.shared.event import Event, EventId
+
+
+class UserAuthenticated(Event):
+ """Emitted when a user successfully authenticates."""
+
+ id: EventId
+ user_id: str
+ provider: str
+ external_id: str
+
+
+class UserLoggedOut(Event):
+ """Emitted when a user logs out."""
+
+ id: EventId
+ user_id: str
diff --git a/server/osa/domain/auth/model/__init__.py b/server/osa/domain/auth/model/__init__.py
index e69de29..e59c5bf 100644
--- a/server/osa/domain/auth/model/__init__.py
+++ b/server/osa/domain/auth/model/__init__.py
@@ -0,0 +1,17 @@
+"""Auth domain models."""
+
+from .identity import Identity
+from .token import RefreshToken
+from .user import User
+from .value import IdentityId, OrcidId, RefreshTokenId, TokenFamilyId, UserId
+
+__all__ = [
+ "Identity",
+ "IdentityId",
+ "OrcidId",
+ "RefreshToken",
+ "RefreshTokenId",
+ "TokenFamilyId",
+ "User",
+ "UserId",
+]
diff --git a/server/osa/domain/auth/model/identity.py b/server/osa/domain/auth/model/identity.py
new file mode 100644
index 0000000..f86090b
--- /dev/null
+++ b/server/osa/domain/auth/model/identity.py
@@ -0,0 +1,46 @@
+"""Identity entity for the auth domain."""
+
+from datetime import UTC, datetime
+from typing import Any
+
+from osa.domain.auth.model.value import IdentityId, UserId
+from osa.domain.shared.model.entity import Entity
+
+
+class Identity(Entity):
+ """A link between a User and an external identity provider.
+
+ Examples:
+ - ORCiD: provider="orcid", external_id="0000-0001-2345-6789"
+ - SAML: provider="saml:university.edu", external_id="jdoe@university.edu"
+
+ Invariants:
+ - `(provider, external_id)` is globally unique
+ - `user_id` is immutable after creation
+ - `provider` and `external_id` are immutable after creation
+ """
+
+ id: IdentityId
+ user_id: UserId
+ provider: str
+ external_id: str
+ metadata: dict[str, Any] | None = None # Provider-specific data (name, email)
+ created_at: datetime
+
+ @classmethod
+ def create(
+ cls,
+ user_id: UserId,
+ provider: str,
+ external_id: str,
+ metadata: dict[str, Any] | None = None,
+ ) -> "Identity":
+ """Create a new identity link."""
+ return cls(
+ id=IdentityId.generate(),
+ user_id=user_id,
+ provider=provider,
+ external_id=external_id,
+ metadata=metadata,
+ created_at=datetime.now(UTC),
+ )
diff --git a/server/osa/domain/auth/model/token.py b/server/osa/domain/auth/model/token.py
new file mode 100644
index 0000000..409214c
--- /dev/null
+++ b/server/osa/domain/auth/model/token.py
@@ -0,0 +1,68 @@
+"""RefreshToken entity for the auth domain."""
+
+from datetime import UTC, datetime, timedelta
+
+from osa.domain.auth.model.value import RefreshTokenId, TokenFamilyId, UserId
+from osa.domain.shared.model.entity import Entity
+
+
+class RefreshToken(Entity):
+ """An opaque refresh token for session management.
+
+ Tokens belong to a "family" for theft detection. When a token is refreshed,
+ the new token inherits the family_id. If a revoked token is reused, the
+ entire family is revoked (indicating potential theft).
+
+ Invariants:
+ - `token_hash` is a SHA256 hash (64 hex characters)
+ - `expires_at` is always in the future at creation time
+ - Once `revoked_at` is set, it cannot be unset
+ """
+
+ id: RefreshTokenId
+ user_id: UserId
+ token_hash: str # SHA256 hash of the actual token value
+ family_id: TokenFamilyId
+ expires_at: datetime
+ created_at: datetime
+ revoked_at: datetime | None = None
+
+ @property
+ def is_valid(self) -> bool:
+ """Token is valid if not revoked and not expired."""
+ return self.revoked_at is None and self.expires_at > datetime.now(UTC)
+
+ @property
+ def is_revoked(self) -> bool:
+ """Check if the token has been revoked."""
+ return self.revoked_at is not None
+
+ @property
+ def is_expired(self) -> bool:
+ """Check if the token has expired."""
+ return self.expires_at <= datetime.now(UTC)
+
+ def revoke(self) -> None:
+ """Mark this token as revoked."""
+ if self.revoked_at is None:
+ self.revoked_at = datetime.now(UTC)
+
+ @classmethod
+ def create(
+ cls,
+ user_id: UserId,
+ token_hash: str,
+ family_id: TokenFamilyId,
+ expires_in_days: int = 7,
+ ) -> "RefreshToken":
+ """Create a new refresh token."""
+ now = datetime.now(UTC)
+ return cls(
+ id=RefreshTokenId.generate(),
+ user_id=user_id,
+ token_hash=token_hash,
+ family_id=family_id,
+ expires_at=now + timedelta(days=expires_in_days),
+ created_at=now,
+ revoked_at=None,
+ )
diff --git a/server/osa/domain/auth/model/user.py b/server/osa/domain/auth/model/user.py
new file mode 100644
index 0000000..95c8ca0
--- /dev/null
+++ b/server/osa/domain/auth/model/user.py
@@ -0,0 +1,40 @@
+"""User aggregate for the auth domain."""
+
+from datetime import UTC, datetime
+
+from osa.domain.auth.model.value import UserId
+from osa.domain.shared.model.aggregate import Aggregate
+
+
+class User(Aggregate):
+ """An authenticated user in the OSA system.
+
+ Users are created on first authentication via any identity provider.
+ A user may have multiple linked identities (e.g., ORCiD + institutional SAML).
+
+ Invariants:
+ - `id` is immutable after creation
+ - `created_at` is immutable after creation
+ - `updated_at` is set on any modification
+ """
+
+ id: UserId
+ display_name: str | None
+ created_at: datetime
+ updated_at: datetime | None = None
+
+ @classmethod
+ def create(cls, display_name: str | None = None) -> "User":
+ """Create a new user."""
+ now = datetime.now(UTC)
+ return cls(
+ id=UserId.generate(),
+ display_name=display_name,
+ created_at=now,
+ updated_at=None,
+ )
+
+ def update_display_name(self, display_name: str | None) -> None:
+ """Update the user's display name."""
+ self.display_name = display_name
+ self.updated_at = datetime.now(UTC)
diff --git a/server/osa/domain/auth/model/value.py b/server/osa/domain/auth/model/value.py
new file mode 100644
index 0000000..5c01782
--- /dev/null
+++ b/server/osa/domain/auth/model/value.py
@@ -0,0 +1,111 @@
+"""Value objects for the auth domain."""
+
+import re
+from dataclasses import dataclass
+from uuid import UUID, uuid4
+
+from pydantic import RootModel, field_validator
+
+
+class UserId(RootModel[UUID]):
+ """Unique identifier for a User."""
+
+ @classmethod
+ def generate(cls) -> "UserId":
+ return cls(uuid4())
+
+ def __str__(self) -> str:
+ return str(self.root)
+
+ def __hash__(self) -> int:
+ return hash(self.root)
+
+
+class IdentityId(RootModel[UUID]):
+ """Unique identifier for an Identity."""
+
+ @classmethod
+ def generate(cls) -> "IdentityId":
+ return cls(uuid4())
+
+ def __str__(self) -> str:
+ return str(self.root)
+
+ def __hash__(self) -> int:
+ return hash(self.root)
+
+
+class RefreshTokenId(RootModel[UUID]):
+ """Unique identifier for a RefreshToken."""
+
+ @classmethod
+ def generate(cls) -> "RefreshTokenId":
+ return cls(uuid4())
+
+ def __str__(self) -> str:
+ return str(self.root)
+
+ def __hash__(self) -> int:
+ return hash(self.root)
+
+
+class TokenFamilyId(RootModel[UUID]):
+ """Identifier for a token family.
+
+ All refresh tokens from a single login session share a family_id.
+ Used for theft detection: if a revoked token is reused, the entire
+ family is invalidated.
+ """
+
+ @classmethod
+ def generate(cls) -> "TokenFamilyId":
+ return cls(uuid4())
+
+ def __str__(self) -> str:
+ return str(self.root)
+
+ def __hash__(self) -> int:
+ return hash(self.root)
+
+
+ORCID_PATTERN = re.compile(r"^\d{4}-\d{4}-\d{4}-\d{3}[\dX]$")
+
+
+@dataclass(frozen=True)
+class ProviderIdentity:
+ """An external identity from an identity provider.
+
+ Encapsulates provider + external_id together since they're always used as a pair.
+ """
+
+ provider: str # e.g., "orcid", "google"
+ external_id: str # Provider-specific user ID
+
+
+@dataclass(frozen=True)
+class CurrentUser:
+ """Authenticated user context extracted from JWT token."""
+
+ user_id: "UserId"
+ identity: ProviderIdentity
+
+
+class OrcidId(RootModel[str]):
+ """An ORCiD identifier (e.g., 0000-0001-2345-6789).
+
+ ORCiD IDs are 16-digit numbers displayed as four groups of four,
+ with a checksum character (digit or X) at the end.
+ """
+
+ @field_validator("root")
+ @classmethod
+ def validate_orcid_format(cls, v: str) -> str:
+ if not ORCID_PATTERN.match(v):
+ raise ValueError(f"Invalid ORCiD format: {v}")
+ return v
+
+ def __str__(self) -> str:
+ return self.root
+
+ def __hash__(self) -> int:
+ return hash(self.root)
diff --git a/server/osa/domain/auth/port/__init__.py b/server/osa/domain/auth/port/__init__.py
index e69de29..9f67041 100644
--- a/server/osa/domain/auth/port/__init__.py
+++ b/server/osa/domain/auth/port/__init__.py
@@ -0,0 +1,12 @@
+"""Auth domain ports."""
+
+from .identity_provider import IdentityInfo, IdentityProvider
+from .repository import IdentityRepository, RefreshTokenRepository, UserRepository
+
+__all__ = [
+ "IdentityInfo",
+ "IdentityProvider",
+ "IdentityRepository",
+ "RefreshTokenRepository",
+ "UserRepository",
+]
diff --git a/server/osa/domain/auth/port/identity_provider.py b/server/osa/domain/auth/port/identity_provider.py
new file mode 100644
index 0000000..24c7305
--- /dev/null
+++ b/server/osa/domain/auth/port/identity_provider.py
@@ -0,0 +1,64 @@
+"""Identity provider port for the auth domain."""
+
+from abc import abstractmethod
+from dataclasses import dataclass
+from typing import Any, Protocol
+
+from osa.domain.shared.port import Port
+
+
+@dataclass(frozen=True)
+class IdentityInfo:
+ """Information returned by an identity provider after successful auth."""
+
+ provider: str # e.g., "orcid", "google", "saml"
+ external_id: str # Provider-specific user ID
+ display_name: str | None
+ email: str | None # May not be available from all providers
+ raw_data: dict[str, Any] # Full provider response for extensibility
+
+
+class IdentityProvider(Port, Protocol):
+ """Port for external identity provider integrations.
+
+ Implementations are adapters in infrastructure/ (e.g., OrcidIdentityProvider).
+ """
+
+ @property
+ @abstractmethod
+ def provider_name(self) -> str:
+ """Unique identifier for this provider (e.g., 'orcid')."""
+ ...
+
+ @abstractmethod
+ def get_authorization_url(self, state: str, redirect_uri: str) -> str:
+ """Generate URL to redirect user for authentication.
+
+ Args:
+ state: CSRF protection token (random, stored in session)
+ redirect_uri: Where the IdP should redirect after auth
+
+ Returns:
+ Full URL to redirect the user to
+ """
+ ...
+
+ @abstractmethod
+ async def exchange_code(
+ self,
+ code: str,
+ redirect_uri: str,
+ ) -> IdentityInfo:
+ """Exchange authorization code for identity information.
+
+ Args:
+ code: Authorization code from IdP callback
+ redirect_uri: Must match the redirect_uri used in authorization URL
+
+ Returns:
+ IdentityInfo with user details from the provider
+
+ Raises:
+ ExternalServiceError: If the IdP request fails
+ """
+ ...
diff --git a/server/osa/domain/auth/port/provider_registry.py b/server/osa/domain/auth/port/provider_registry.py
new file mode 100644
index 0000000..4b718c4
--- /dev/null
+++ b/server/osa/domain/auth/port/provider_registry.py
@@ -0,0 +1,47 @@
+"""Provider registry port for the auth domain."""
+
+from abc import abstractmethod
+from typing import Protocol
+
+from osa.domain.auth.port.identity_provider import IdentityProvider
+from osa.domain.shared.port import Port
+
+
+class ProviderRegistry(Port, Protocol):
+ """Registry of available identity providers.
+
+ Allows looking up identity providers by name and checking
+ which providers are configured/available.
+ """
+
+ @abstractmethod
+ def get(self, provider: str) -> IdentityProvider | None:
+ """Get an identity provider by name.
+
+ Args:
+ provider: The provider name (e.g., "orcid", "google")
+
+ Returns:
+ The identity provider if available, None otherwise
+ """
+ ...
+
+ @abstractmethod
+ def available_providers(self) -> list[str]:
+ """Get list of available provider names.
+
+ Returns:
+ List of provider names that can be used for authentication
+ """
+ ...
+
+ def is_available(self, provider: str) -> bool:
+ """Check if a provider is available.
+
+ Args:
+ provider: The provider name to check
+
+ Returns:
+ True if the provider is available
+ """
+ return provider in self.available_providers()
diff --git a/server/osa/domain/auth/port/repository.py b/server/osa/domain/auth/port/repository.py
new file mode 100644
index 0000000..a2ca4a0
--- /dev/null
+++ b/server/osa/domain/auth/port/repository.py
@@ -0,0 +1,88 @@
+"""Repository ports for the auth domain."""
+
+from abc import abstractmethod
+from typing import Protocol
+
+from osa.domain.auth.model.identity import Identity
+from osa.domain.auth.model.token import RefreshToken
+from osa.domain.auth.model.user import User
+from osa.domain.auth.model.value import (
+ IdentityId,
+ RefreshTokenId,
+ TokenFamilyId,
+ UserId,
+)
+from osa.domain.shared.port import Port
+
+
+class UserRepository(Port, Protocol):
+ """Repository for User aggregate persistence."""
+
+ @abstractmethod
+ async def get(self, user_id: UserId) -> User | None:
+ """Get a user by ID."""
+ ...
+
+ @abstractmethod
+ async def save(self, user: User) -> None:
+ """Save a user (create or update)."""
+ ...
+
+
+class IdentityRepository(Port, Protocol):
+ """Repository for Identity entity persistence."""
+
+ @abstractmethod
+ async def get(self, identity_id: IdentityId) -> Identity | None:
+ """Get an identity by ID."""
+ ...
+
+ @abstractmethod
+ async def get_by_provider_and_external_id(
+ self, provider: str, external_id: str
+ ) -> Identity | None:
+ """Get an identity by provider and external ID."""
+ ...
+
+ @abstractmethod
+ async def get_by_user_id(self, user_id: UserId) -> list[Identity]:
+ """Get all identities for a user."""
+ ...
+
+ @abstractmethod
+ async def save(self, identity: Identity) -> None:
+ """Save an identity."""
+ ...
+
+
+class RefreshTokenRepository(Port, Protocol):
+ """Repository for RefreshToken entity persistence."""
+
+ @abstractmethod
+ async def get(self, token_id: RefreshTokenId) -> RefreshToken | None:
+ """Get a refresh token by ID."""
+ ...
+
+ @abstractmethod
+ async def get_by_token_hash(
+ self, token_hash: str, *, for_update: bool = False
+ ) -> RefreshToken | None:
+ """Get a refresh token by its hash.
+
+ Args:
+ token_hash: The hash of the token to find.
+ for_update: If True, acquire a row-level lock (SELECT FOR UPDATE)
+ to prevent concurrent modifications. Use this when the token
+ will be modified after retrieval (e.g., during refresh).
+ """
+ ...
+
+ @abstractmethod
+ async def save(self, token: RefreshToken) -> None:
+ """Save a refresh token."""
+ ...
+
+ @abstractmethod
+ async def revoke_family(self, family_id: TokenFamilyId) -> int:
+ """Revoke all tokens in a family. Returns count of revoked tokens."""
+ ...
diff --git a/server/osa/domain/auth/service/__init__.py b/server/osa/domain/auth/service/__init__.py
index e69de29..0e01a6e 100644
--- a/server/osa/domain/auth/service/__init__.py
+++ b/server/osa/domain/auth/service/__init__.py
@@ -0,0 +1,6 @@
+"""Auth domain services."""
+
+from .auth import AuthService
+from .token import TokenService
+
+__all__ = ["AuthService", "TokenService"]
diff --git a/server/osa/domain/auth/service/auth.py b/server/osa/domain/auth/service/auth.py
new file mode 100644
index 0000000..cf1c25b
--- /dev/null
+++ b/server/osa/domain/auth/service/auth.py
@@ -0,0 +1,276 @@
+"""Auth service for orchestrating authentication flows."""
+
+import logging
+
+from osa.domain.auth.model.identity import Identity
+from osa.domain.auth.model.token import RefreshToken
+from osa.domain.auth.model.user import User
+from osa.domain.auth.model.value import ProviderIdentity, TokenFamilyId, UserId
+from osa.domain.auth.port.identity_provider import IdentityInfo, IdentityProvider
+from osa.domain.auth.port.repository import (
+ IdentityRepository,
+ RefreshTokenRepository,
+ UserRepository,
+)
+from osa.domain.auth.service.token import TokenService
+from osa.domain.shared.outbox import Outbox
+from osa.domain.shared.service import Service
+
+logger = logging.getLogger(__name__)
+
+
+class AuthService(Service):
+ """Orchestrates authentication flows.
+
+ - initiate_login: Generate authorization URL
+ - complete_oauth: Exchange code for tokens, create/update user
+ - refresh_tokens: Issue new tokens from refresh token
+ - logout: Revoke refresh token family
+ """
+
+ _user_repo: UserRepository
+ _identity_repo: IdentityRepository
+ _refresh_token_repo: RefreshTokenRepository
+ _token_service: TokenService
+ _outbox: Outbox
+
+ async def initiate_login(
+ self,
+ provider: IdentityProvider,
+ state: str,
+ redirect_uri: str,
+ ) -> str:
+ """Generate the authorization URL for OAuth login.
+
+ Args:
+ provider: The identity provider to use
+ state: CSRF protection token (caller should store this)
+ redirect_uri: Where the IdP should redirect after auth
+
+ Returns:
+ Authorization URL to redirect the user to
+ """
+ return provider.get_authorization_url(state, redirect_uri)
+
+ async def complete_oauth(
+ self,
+ provider: IdentityProvider,
+ code: str,
+ redirect_uri: str,
+ ) -> tuple[User, Identity, str, str]:
+ """Complete OAuth flow and issue tokens.
+
+ Args:
+ provider: The identity provider
+ code: Authorization code from callback
+ redirect_uri: Must match the one used in authorization
+
+ Returns:
+ Tuple of (user, identity, access_token, refresh_token)
+ """
+ # Exchange code for identity info
+ identity_info = await provider.exchange_code(code, redirect_uri)
+
+ # Find or create user and identity
+ user, identity = await self._find_or_create_user(identity_info)
+
+ # Create tokens
+ access_token, refresh_token = await self._create_tokens(user, identity)
+
+ logger.info(
+ "User authenticated: user_id=%s, provider=%s, external_id=%s",
+ user.id,
+ identity.provider,
+ identity.external_id,
+ )
+
+ return user, identity, access_token, refresh_token
+
+ async def refresh_tokens(
+ self,
+ refresh_token_raw: str,
+ ) -> tuple[User, str, str]:
+ """Refresh access token using refresh token.
+
+ Implements token rotation: old refresh token is revoked,
+ new one issued in same family.
+
+ Args:
+ refresh_token_raw: The raw refresh token from client
+
+ Returns:
+ Tuple of (user, new_access_token, new_refresh_token)
+
+ Raises:
+ InvalidStateError: If refresh token is invalid, expired, or revoked
+ """
+ from osa.domain.shared.error import InvalidStateError
+
+ token_hash = self._token_service.hash_token(refresh_token_raw)
+ # Lock the row to prevent concurrent refresh attempts (race condition)
+ stored_token = await self._refresh_token_repo.get_by_token_hash(token_hash, for_update=True)
+
+ if stored_token is None:
+ raise InvalidStateError("Invalid refresh token", code="invalid_refresh_token")
+
+ if stored_token.is_revoked:
+ # Potential theft detected - revoke entire family
+ await self._refresh_token_repo.revoke_family(stored_token.family_id)
+ logger.warning(
+ "Refresh token reuse detected, family revoked: family_id=%s",
+ stored_token.family_id,
+ )
+ raise InvalidStateError(
+ "Token family revoked - please login again",
+ code="token_family_revoked",
+ )
+
+ if stored_token.is_expired:
+ raise InvalidStateError("Refresh token expired", code="refresh_token_expired")
+
+ # Revoke old token
+ stored_token.revoke()
+ await self._refresh_token_repo.save(stored_token)
+
+ # Get user and their primary identity
+ user = await self._user_repo.get(stored_token.user_id)
+ if user is None:
+ raise InvalidStateError("User not found", code="user_not_found")
+
+ primary_identity = await self.get_primary_identity(user.id)
+
+ if primary_identity is None:
+ raise InvalidStateError("User has no identity", code="no_identity")
+
+ # Issue new tokens in same family
+ raw_token, token_hash = self._token_service.create_refresh_token()
+ new_refresh_token = RefreshToken.create(
+ user_id=user.id,
+ token_hash=token_hash,
+ family_id=stored_token.family_id,
+ expires_in_days=self._token_service.refresh_token_expire_days,
+ )
+ await self._refresh_token_repo.save(new_refresh_token)
+
+ access_token = self._token_service.create_access_token(
+ user_id=user.id,
+ identity=primary_identity,
+ )
+
+ logger.info("Tokens refreshed: user_id=%s", user.id)
+
+ return user, access_token, raw_token
+
+ async def logout(self, refresh_token_raw: str) -> bool:
+ """Logout by revoking refresh token family.
+
+ Args:
+ refresh_token_raw: The raw refresh token
+
+ Returns:
+ True if tokens were revoked
+ """
+ token_hash = self._token_service.hash_token(refresh_token_raw)
+ stored_token = await self._refresh_token_repo.get_by_token_hash(token_hash)
+
+ if stored_token is None:
+ # Token not found, but logout succeeds anyway
+ return True
+
+ revoked_count = await self._refresh_token_repo.revoke_family(stored_token.family_id)
+ logger.info(
+ "User logged out: user_id=%s, revoked_tokens=%d",
+ stored_token.user_id,
+ revoked_count,
+ )
+
+ return True
+
+ async def get_user_by_id(self, user_id: UserId) -> User | None:
+ """Get a user by their ID."""
+ return await self._user_repo.get(user_id)
+
+ async def get_primary_identity(self, user_id: UserId) -> ProviderIdentity | None:
+ """Get the primary identity for a user.
+
+ Returns the first identity found for the user. In the future,
+ this could be extended to support multiple identities with a
+ designated primary.
+ """
+ identities = await self._identity_repo.get_by_user_id(user_id)
+ if not identities:
+ return None
+ first = identities[0]
+ return ProviderIdentity(provider=first.provider, external_id=first.external_id)
+
+ async def get_user_id_from_refresh_token(self, raw_token: str) -> UserId | None:
+ """Get the user ID associated with a refresh token.
+
+ Args:
+ raw_token: The raw refresh token string
+
+ Returns:
+ The user ID if token exists, None otherwise
+ """
+ token_hash = self._token_service.hash_token(raw_token)
+ stored = await self._refresh_token_repo.get_by_token_hash(token_hash)
+ return stored.user_id if stored else None
+
+ async def _find_or_create_user(self, identity_info: IdentityInfo) -> tuple[User, Identity]:
+ """Find existing user by identity or create new one."""
+ # Check if identity already exists
+ existing_identity = await self._identity_repo.get_by_provider_and_external_id(
+ identity_info.provider, identity_info.external_id
+ )
+
+ if existing_identity:
+ # User exists, return them
+ user = await self._user_repo.get(existing_identity.user_id)
+ if user is None:
+ # Orphaned identity - shouldn't happen with CASCADE
+ raise RuntimeError(f"Identity exists without user: {existing_identity.id}")
+ return user, existing_identity
+
+ # Create new user and identity
+ user = User.create(display_name=identity_info.display_name)
+ await self._user_repo.save(user)
+
+ identity = Identity.create(
+ user_id=user.id,
+ provider=identity_info.provider,
+ external_id=identity_info.external_id,
+ metadata=identity_info.raw_data,
+ )
+ await self._identity_repo.save(identity)
+
+ logger.info(
+ "New user created: user_id=%s, provider=%s",
+ user.id,
+ identity_info.provider,
+ )
+
+ return user, identity
+
+ async def _create_tokens(self, user: User, identity: Identity) -> tuple[str, str]:
+ """Create access and refresh tokens for a user."""
+ # Create refresh token
+ raw_token, token_hash = self._token_service.create_refresh_token()
+ refresh_token = RefreshToken.create(
+ user_id=user.id,
+ token_hash=token_hash,
+ family_id=TokenFamilyId.generate(),
+ expires_in_days=self._token_service.refresh_token_expire_days,
+ )
+ await self._refresh_token_repo.save(refresh_token)
+
+ # Create access token
+ provider_identity = ProviderIdentity(
+ provider=identity.provider,
+ external_id=identity.external_id,
+ )
+ access_token = self._token_service.create_access_token(
+ user_id=user.id,
+ identity=provider_identity,
+ )
+
+ return access_token, raw_token
diff --git a/server/osa/domain/auth/service/token.py b/server/osa/domain/auth/service/token.py
new file mode 100644
index 0000000..f8948bd
--- /dev/null
+++ b/server/osa/domain/auth/service/token.py
@@ -0,0 +1,197 @@
+"""Token service for JWT creation and validation."""
+
+import hashlib
+import hmac
+import json
+import logging
+import secrets
+import time
+from base64 import urlsafe_b64decode, urlsafe_b64encode
+from datetime import UTC, datetime, timedelta
+from typing import Any
+
+import jwt
+
+from osa.config import JwtConfig
+from osa.domain.auth.model.value import ProviderIdentity, UserId
+from osa.domain.shared.service import Service
+
+logger = logging.getLogger(__name__)
+
+# OAuth state validity period (5 minutes)
+STATE_EXPIRY_SECONDS = 300
+
+
+class TokenService(Service):
+ """Service for JWT access token and refresh token operations.
+
+ - Access tokens are JWTs (HS256) with user claims
+ - Refresh tokens are opaque random strings, stored as hashes in the database
+ - OAuth state tokens are signed payloads for CSRF protection
+ """
+
+ _config: JwtConfig
+
+ def create_access_token(
+ self,
+ user_id: UserId,
+ identity: ProviderIdentity,
+ additional_claims: dict[str, Any] | None = None,
+ ) -> str:
+ """Create a JWT access token.
+
+ Args:
+ user_id: The user's internal ID
+ identity: The user's external identity (provider + external_id)
+ additional_claims: Optional extra claims to include
+
+ Returns:
+ Encoded JWT string
+ """
+ now = datetime.now(UTC)
+ expires_at = now + timedelta(minutes=self._config.access_token_expire_minutes)
+
+ payload = {
+ "sub": str(user_id),
+ "provider": identity.provider,
+ "external_id": identity.external_id,
+ "aud": "authenticated",
+ "iat": int(now.timestamp()),
+ "exp": int(expires_at.timestamp()),
+ "jti": secrets.token_hex(16),
+ }
+
+ if additional_claims:
+ payload.update(additional_claims)
+
+ return jwt.encode(
+ payload,
+ self._config.secret,
+ algorithm=self._config.algorithm,
+ )
+
+ def validate_access_token(self, token: str) -> dict[str, Any]:
+ """Validate and decode a JWT access token.
+
+ Args:
+ token: The JWT string to validate
+
+ Returns:
+ Decoded payload dict
+
+ Raises:
+ jwt.InvalidTokenError: If token is invalid or expired
+ """
+ return jwt.decode(
+ token,
+ self._config.secret,
+ algorithms=[self._config.algorithm],
+ audience="authenticated",
+ )
+
+ def create_refresh_token(self) -> tuple[str, str]:
+ """Create a new refresh token.
+
+ Returns:
+ Tuple of (raw_token, token_hash)
+ - raw_token: Send to client
+ - token_hash: Store in database
+ """
+ raw_token = secrets.token_urlsafe(32)
+ token_hash = self.hash_token(raw_token)
+ return raw_token, token_hash
+
+ @staticmethod
+ def hash_token(raw_token: str) -> str:
+ """Create SHA256 hash of a token.
+
+ Args:
+ raw_token: The raw token string
+
+ Returns:
+ Hex-encoded SHA256 hash (64 characters)
+ """
+ return hashlib.sha256(raw_token.encode()).hexdigest()
+
+ @property
+ def access_token_expire_seconds(self) -> int:
+ """Get access token expiry in seconds."""
+ return self._config.access_token_expire_minutes * 60
+
+ @property
+ def refresh_token_expire_days(self) -> int:
+ """Get refresh token expiry in days."""
+ return self._config.refresh_token_expire_days
+
+ def create_oauth_state(self, redirect_uri: str, provider: str) -> str:
+ """Create a signed, self-verifying OAuth state token.
+
+ The state contains: nonce, redirect_uri, provider, expiry timestamp.
+ Signed with HMAC-SHA256 using the JWT secret.
+
+ Args:
+ redirect_uri: The URI to redirect to after OAuth completes
+ provider: The identity provider name (e.g., "orcid")
+
+ Returns:
+ URL-safe signed state token in format: payload.signature
+ """
+ payload = {
+ "nonce": secrets.token_urlsafe(16),
+ "redirect_uri": redirect_uri,
+ "provider": provider,
+ "exp": int(time.time()) + STATE_EXPIRY_SECONDS,
+ }
+ payload_bytes = json.dumps(payload, separators=(",", ":")).encode()
+ payload_b64 = urlsafe_b64encode(payload_bytes).rstrip(b"=").decode()
+
+ signature = hmac.new(self._config.secret.encode(), payload_bytes, hashlib.sha256).digest()
+ signature_b64 = urlsafe_b64encode(signature).rstrip(b"=").decode()
+
+ return f"{payload_b64}.{signature_b64}"
+
+ def verify_oauth_state(self, state: str) -> tuple[str, str] | None:
+ """Verify a signed state token and return the redirect_uri and provider if valid.
+
+ Args:
+ state: The signed state token to verify
+
+ Returns:
+ Tuple of (redirect_uri, provider) if valid, None if invalid or expired
+ """
+ try:
+ parts = state.split(".")
+ if len(parts) != 2:
+ return None
+
+ payload_b64, signature_b64 = parts
+
+ # Restore base64 padding
+ payload_bytes = urlsafe_b64decode(payload_b64 + "==")
+ signature = urlsafe_b64decode(signature_b64 + "==")
+
+ # Verify signature
+ expected_sig = hmac.new(
+ self._config.secret.encode(), payload_bytes, hashlib.sha256
+ ).digest()
+ if not hmac.compare_digest(signature, expected_sig):
+ logger.warning("OAuth state signature verification failed")
+ return None
+
+ # Parse and check expiry
+ payload = json.loads(payload_bytes)
+ if payload.get("exp", 0) < time.time():
+ logger.warning("OAuth state expired")
+ return None
+
+ redirect_uri = payload.get("redirect_uri")
+ provider = payload.get("provider")
+ if not redirect_uri or not provider:
+ logger.warning("OAuth state missing redirect_uri or provider")
+ return None
+
+ return redirect_uri, provider
+
+ except Exception as e:
+ logger.warning("OAuth state verification error: %s", e)
+ return None
diff --git a/server/osa/domain/auth/adapter/__init__.py b/server/osa/domain/auth/util/__init__.py
similarity index 100%
rename from server/osa/domain/auth/adapter/__init__.py
rename to server/osa/domain/auth/util/__init__.py
diff --git a/server/osa/domain/auth/util/di/__init__.py b/server/osa/domain/auth/util/di/__init__.py
new file mode 100644
index 0000000..22bb012
--- /dev/null
+++ b/server/osa/domain/auth/util/di/__init__.py
@@ -0,0 +1,5 @@
+"""DI providers for auth domain."""
+
+from .provider import AuthProvider
+
+__all__ = ["AuthProvider"]
diff --git a/server/osa/domain/auth/util/di/provider.py b/server/osa/domain/auth/util/di/provider.py
new file mode 100644
index 0000000..8cce9b2
--- /dev/null
+++ b/server/osa/domain/auth/util/di/provider.py
@@ -0,0 +1,104 @@
+"""DI provider for auth domain."""
+
+from uuid import UUID
+
+import jwt
+from dishka import from_context, provide
+from fastapi import HTTPException
+from starlette.requests import Request
+
+from osa.config import Config
+from osa.domain.auth.command.login import (
+ CompleteOAuthHandler,
+ InitiateLoginHandler,
+)
+from osa.domain.auth.command.token import LogoutHandler, RefreshTokensHandler
+from osa.domain.auth.model.value import CurrentUser, ProviderIdentity, UserId
+from osa.domain.auth.port.repository import (
+ IdentityRepository,
+ RefreshTokenRepository,
+ UserRepository,
+)
+from osa.domain.auth.service.auth import AuthService
+from osa.domain.auth.service.token import TokenService
+from osa.domain.shared.outbox import Outbox
+from osa.util.di.base import Provider
+from osa.util.di.scope import Scope
+
+
+class AuthProvider(Provider):
+ """DI provider for auth domain services and handlers."""
+
+ request = from_context(provides=Request, scope=Scope.UOW)
+
+ # Command Handlers
+ initiate_login_handler = provide(InitiateLoginHandler, scope=Scope.UOW)
+ complete_oauth_handler = provide(CompleteOAuthHandler, scope=Scope.UOW)
+ refresh_tokens_handler = provide(RefreshTokensHandler, scope=Scope.UOW)
+ logout_handler = provide(LogoutHandler, scope=Scope.UOW)
+
+ @provide(scope=Scope.UOW)
+ def get_token_service(self, config: Config) -> TokenService:
+ """Provide TokenService."""
+ return TokenService(_config=config.auth.jwt)
+
+ @provide(scope=Scope.UOW)
+ def get_auth_service(
+ self,
+ user_repo: UserRepository,
+ identity_repo: IdentityRepository,
+ refresh_token_repo: RefreshTokenRepository,
+ token_service: TokenService,
+ outbox: Outbox,
+ ) -> AuthService:
+ """Provide AuthService."""
+ return AuthService(
+ _user_repo=user_repo,
+ _identity_repo=identity_repo,
+ _refresh_token_repo=refresh_token_repo,
+ _token_service=token_service,
+ _outbox=outbox,
+ )
+
+ @provide(scope=Scope.UOW)
+ def get_current_user(
+ self,
+ request: Request,
+ token_service: TokenService,
+ ) -> CurrentUser:
+ """Extract and validate CurrentUser from JWT in Authorization header.
+
+ Raises:
+ HTTPException: If token is missing, expired, or invalid
+ """
+ auth_header = request.headers.get("Authorization")
+ if not auth_header or not auth_header.startswith("Bearer "):
+ raise HTTPException(
+ status_code=401,
+ detail={"code": "missing_token", "message": "Authorization header required"},
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+
+ token = auth_header[7:] # Remove "Bearer " prefix
+
+ try:
+ payload = token_service.validate_access_token(token)
+ return CurrentUser(
+ user_id=UserId(UUID(payload["sub"])),
+ identity=ProviderIdentity(
+ provider=payload["provider"],
+ external_id=payload["external_id"],
+ ),
+ )
+ except jwt.ExpiredSignatureError as e:
+ raise HTTPException(
+ status_code=401,
+ detail={"code": "token_expired", "message": "Token has expired"},
+ headers={"WWW-Authenticate": "Bearer"},
+ ) from e
+ except jwt.InvalidTokenError as e:
+ raise HTTPException(
+ status_code=401,
+ detail={"code": "invalid_token", "message": "Invalid token"},
+ headers={"WWW-Authenticate": "Bearer"},
+ ) from e
diff --git a/server/osa/infrastructure/auth/__init__.py b/server/osa/infrastructure/auth/__init__.py
new file mode 100644
index 0000000..fc36309
--- /dev/null
+++ b/server/osa/infrastructure/auth/__init__.py
@@ -0,0 +1,5 @@
+"""Auth infrastructure adapters."""
+
+from .di import AuthInfraProvider
+
+__all__ = ["AuthInfraProvider"]
diff --git a/server/osa/infrastructure/auth/di.py b/server/osa/infrastructure/auth/di.py
new file mode 100644
index 0000000..40a354e
--- /dev/null
+++ b/server/osa/infrastructure/auth/di.py
@@ -0,0 +1,71 @@
+"""DI provider for auth infrastructure."""
+
+import httpx
+from dishka import provide
+
+from osa.config import Config
+from osa.domain.auth.port.identity_provider import IdentityProvider
+from osa.domain.auth.port.provider_registry import ProviderRegistry
+from osa.domain.auth.port.repository import (
+ IdentityRepository,
+ RefreshTokenRepository,
+ UserRepository,
+)
+from osa.infrastructure.auth.orcid import OrcidIdentityProvider
+from osa.infrastructure.auth.provider_registry import InMemoryProviderRegistry
+from osa.infrastructure.persistence.repository.auth import (
+ PostgresIdentityRepository,
+ PostgresRefreshTokenRepository,
+ PostgresUserRepository,
+)
+from osa.util.di.base import Provider
+from osa.util.di.scope import Scope
+
+# HTTP client timeout configuration
+_HTTP_TIMEOUT = httpx.Timeout(
+ connect=5.0, # Connection timeout
+ read=10.0, # Read timeout
+ write=5.0, # Write timeout
+ pool=5.0, # Pool timeout
+)
+
+
+class AuthInfraProvider(Provider):
+ """DI provider for auth infrastructure adapters."""
+
+ # Repository adapters
+ user_repo = provide(
+ PostgresUserRepository,
+ scope=Scope.UOW,
+ provides=UserRepository,
+ )
+ identity_repo = provide(
+ PostgresIdentityRepository,
+ scope=Scope.UOW,
+ provides=IdentityRepository,
+ )
+ refresh_token_repo = provide(
+ PostgresRefreshTokenRepository,
+ scope=Scope.UOW,
+ provides=RefreshTokenRepository,
+ )
+
+ @provide(scope=Scope.APP)
+ def get_auth_http_client(self) -> httpx.AsyncClient:
+ """Shared HTTP client for auth operations (connection pooling)."""
+ return httpx.AsyncClient(timeout=_HTTP_TIMEOUT)
+
+ @provide(scope=Scope.APP)
+ def get_provider_registry(
+ self, config: Config, http_client: httpx.AsyncClient
+ ) -> ProviderRegistry:
+ """Provide ProviderRegistry with configured identity providers."""
+ providers: dict[str, IdentityProvider] = {}
+
+ # Register ORCID if configured
+ if config.auth.orcid.client_id:
+ providers["orcid"] = OrcidIdentityProvider(
+ config=config.auth.orcid, http_client=http_client
+ )
+
+ return InMemoryProviderRegistry(providers)
diff --git a/server/osa/infrastructure/auth/orcid.py b/server/osa/infrastructure/auth/orcid.py
new file mode 100644
index 0000000..a5db378
--- /dev/null
+++ b/server/osa/infrastructure/auth/orcid.py
@@ -0,0 +1,101 @@
+"""ORCiD identity provider adapter."""
+
+import logging
+from urllib.parse import urlencode
+
+import httpx
+
+from osa.config import OrcidConfig
+from osa.domain.auth.port.identity_provider import IdentityInfo, IdentityProvider
+from osa.domain.shared.error import ExternalServiceError
+
+logger = logging.getLogger(__name__)
+
+
+class OrcidIdentityProvider(IdentityProvider):
+ """IdentityProvider implementation for ORCiD OAuth."""
+
+ def __init__(self, config: OrcidConfig, http_client: httpx.AsyncClient) -> None:
+ self._config = config
+ self._http = http_client
+
+ @property
+ def provider_name(self) -> str:
+ return "orcid"
+
+ def get_authorization_url(self, state: str, redirect_uri: str) -> str:
+ """Generate ORCiD authorization URL."""
+ params = {
+ "client_id": self._config.client_id,
+ "response_type": "code",
+ "scope": "/authenticate",
+ "redirect_uri": redirect_uri,
+ "state": state,
+ }
+ return f"{self._config.base_url}/oauth/authorize?{urlencode(params)}"
+
+ async def exchange_code(
+ self,
+ code: str,
+ redirect_uri: str,
+ ) -> IdentityInfo:
+ """Exchange authorization code for identity information."""
+ token_url = f"{self._config.base_url}/oauth/token"
+
+ data = {
+ "client_id": self._config.client_id,
+ "client_secret": self._config.client_secret,
+ "grant_type": "authorization_code",
+ "code": code,
+ "redirect_uri": redirect_uri,
+ }
+
+ try:
+ response = await self._http.post(
+ token_url,
+ data=data,
+ headers={"Accept": "application/json"},
+ )
+
+ if response.status_code != 200:
+ logger.error(
+ "ORCiD token exchange failed: status=%d, body=%s",
+ response.status_code,
+ response.text,
+ )
+ raise ExternalServiceError(
+ f"ORCiD token exchange failed: {response.status_code}",
+ code="idp_unavailable",
+ )
+
+ token_data = response.json()
+
+ except httpx.RequestError as e:
+ logger.exception("ORCiD request failed: %s", e)
+ raise ExternalServiceError(
+ "Failed to connect to ORCiD",
+ code="idp_unavailable",
+ ) from e
+
+ # ORCiD returns user info directly in token response
+ # {
+ # "access_token": "...",
+ # "token_type": "bearer",
+ # "scope": "/authenticate",
+ # "name": "Jane Doe",
+ # "orcid": "0000-0001-2345-6789"
+ # }
+ orcid_id = token_data.get("orcid")
+ if not orcid_id:
+ raise ExternalServiceError(
+ "ORCiD response missing orcid field",
+ code="oauth_error",
+ )
+
+ return IdentityInfo(
+ provider="orcid",
+ external_id=orcid_id,
+ display_name=token_data.get("name"),
+ email=None, # ORCiD doesn't return email in basic auth
+ raw_data=token_data,
+ )
diff --git a/server/osa/infrastructure/auth/provider_registry.py b/server/osa/infrastructure/auth/provider_registry.py
new file mode 100644
index 0000000..0dde031
--- /dev/null
+++ b/server/osa/infrastructure/auth/provider_registry.py
@@ -0,0 +1,37 @@
+"""Provider registry implementation."""
+
+from osa.domain.auth.port.identity_provider import IdentityProvider
+from osa.domain.auth.port.provider_registry import ProviderRegistry
+
+
+class InMemoryProviderRegistry(ProviderRegistry):
+ """In-memory provider registry.
+
+ Stores a mapping of provider names to their implementations.
+ Providers are registered at application startup via DI.
+ """
+
+ def __init__(self, providers: dict[str, IdentityProvider] | None = None) -> None:
+ """Initialize registry with optional initial providers.
+
+ Args:
+ providers: Optional dict mapping provider names to implementations
+ """
+ self._providers: dict[str, IdentityProvider] = providers or {}
+
+ def get(self, provider: str) -> IdentityProvider | None:
+ """Get an identity provider by name."""
+ return self._providers.get(provider)
+
+ def available_providers(self) -> list[str]:
+ """Get list of available provider names."""
+ return list(self._providers.keys())
+
+ def register(self, name: str, provider: IdentityProvider) -> None:
+ """Register a provider.
+
+ Args:
+ name: The provider name
+ provider: The provider implementation
+ """
+ self._providers[name] = provider
diff --git a/server/osa/infrastructure/persistence/repository/auth.py b/server/osa/infrastructure/persistence/repository/auth.py
new file mode 100644
index 0000000..ff46522
--- /dev/null
+++ b/server/osa/infrastructure/persistence/repository/auth.py
@@ -0,0 +1,222 @@
+"""PostgreSQL repository implementations for auth domain."""
+
+from datetime import UTC, datetime
+from uuid import UUID
+
+from sqlalchemy import insert, select, update
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from osa.domain.auth.model.identity import Identity
+from osa.domain.auth.model.token import RefreshToken
+from osa.domain.auth.model.user import User
+from osa.domain.auth.model.value import (
+ IdentityId,
+ RefreshTokenId,
+ TokenFamilyId,
+ UserId,
+)
+from osa.domain.auth.port.repository import (
+ IdentityRepository,
+ RefreshTokenRepository,
+ UserRepository,
+)
+from osa.infrastructure.persistence.tables import (
+ identities_table,
+ refresh_tokens_table,
+ users_table,
+)
+
+
+def _row_to_user(row: dict) -> User:
+ """Convert a database row to a User model."""
+ return User(
+ id=UserId(UUID(row["id"])),
+ display_name=row["display_name"],
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ )
+
+
+def _user_to_dict(user: User) -> dict:
+ """Convert a User model to a database row dict."""
+ return {
+ "id": str(user.id),
+ "display_name": user.display_name,
+ "created_at": user.created_at,
+ "updated_at": user.updated_at,
+ }
+
+
+def _row_to_identity(row: dict) -> Identity:
+ """Convert a database row to an Identity model."""
+ return Identity(
+ id=IdentityId(UUID(row["id"])),
+ user_id=UserId(UUID(row["user_id"])),
+ provider=row["provider"],
+ external_id=row["external_id"],
+ metadata=row["metadata"],
+ created_at=row["created_at"],
+ )
+
+
+def _identity_to_dict(identity: Identity) -> dict:
+ """Convert an Identity model to a database row dict."""
+ return {
+ "id": str(identity.id),
+ "user_id": str(identity.user_id),
+ "provider": identity.provider,
+ "external_id": identity.external_id,
+ "metadata": identity.metadata,
+ "created_at": identity.created_at,
+ }
+
+
+def _row_to_refresh_token(row: dict) -> RefreshToken:
+ """Convert a database row to a RefreshToken model."""
+ return RefreshToken(
+ id=RefreshTokenId(UUID(row["id"])),
+ user_id=UserId(UUID(row["user_id"])),
+ token_hash=row["token_hash"],
+ family_id=TokenFamilyId(UUID(row["family_id"])),
+ expires_at=row["expires_at"],
+ created_at=row["created_at"],
+ revoked_at=row["revoked_at"],
+ )
+
+
+def _refresh_token_to_dict(token: RefreshToken) -> dict:
+ """Convert a RefreshToken model to a database row dict."""
+ return {
+ "id": str(token.id),
+ "user_id": str(token.user_id),
+ "token_hash": token.token_hash,
+ "family_id": str(token.family_id),
+ "expires_at": token.expires_at,
+ "created_at": token.created_at,
+ "revoked_at": token.revoked_at,
+ }
+
+
+class PostgresUserRepository(UserRepository):
+ """PostgreSQL implementation of UserRepository."""
+
+ def __init__(self, session: AsyncSession) -> None:
+ self.session = session
+
+ async def get(self, user_id: UserId) -> User | None:
+ stmt = select(users_table).where(users_table.c.id == str(user_id))
+ result = await self.session.execute(stmt)
+ row = result.mappings().first()
+ return _row_to_user(dict(row)) if row else None
+
+ async def save(self, user: User) -> None:
+ user_dict = _user_to_dict(user)
+ existing = await self.get(user.id)
+
+ if existing:
+ stmt = update(users_table).where(users_table.c.id == str(user.id)).values(**user_dict)
+ else:
+ stmt = insert(users_table).values(**user_dict)
+
+ await self.session.execute(stmt)
+ await self.session.flush()
+
+
+class PostgresIdentityRepository(IdentityRepository):
+ """PostgreSQL implementation of IdentityRepository."""
+
+ def __init__(self, session: AsyncSession) -> None:
+ self.session = session
+
+ async def get(self, identity_id: IdentityId) -> Identity | None:
+ stmt = select(identities_table).where(identities_table.c.id == str(identity_id))
+ result = await self.session.execute(stmt)
+ row = result.mappings().first()
+ return _row_to_identity(dict(row)) if row else None
+
+ async def get_by_provider_and_external_id(
+ self, provider: str, external_id: str
+ ) -> Identity | None:
+ stmt = select(identities_table).where(
+ identities_table.c.provider == provider,
+ identities_table.c.external_id == external_id,
+ )
+ result = await self.session.execute(stmt)
+ row = result.mappings().first()
+ return _row_to_identity(dict(row)) if row else None
+
+ async def get_by_user_id(self, user_id: UserId) -> list[Identity]:
+ stmt = select(identities_table).where(identities_table.c.user_id == str(user_id))
+ result = await self.session.execute(stmt)
+ rows = result.mappings().all()
+ return [_row_to_identity(dict(row)) for row in rows]
+
+ async def save(self, identity: Identity) -> None:
+ identity_dict = _identity_to_dict(identity)
+ existing = await self.get(identity.id)
+
+ if existing:
+ stmt = (
+ update(identities_table)
+ .where(identities_table.c.id == str(identity.id))
+ .values(**identity_dict)
+ )
+ else:
+ stmt = insert(identities_table).values(**identity_dict)
+
+ await self.session.execute(stmt)
+ await self.session.flush()
+
+
+class PostgresRefreshTokenRepository(RefreshTokenRepository):
+ """PostgreSQL implementation of RefreshTokenRepository."""
+
+ def __init__(self, session: AsyncSession) -> None:
+ self.session = session
+
+ async def get(self, token_id: RefreshTokenId) -> RefreshToken | None:
+ stmt = select(refresh_tokens_table).where(refresh_tokens_table.c.id == str(token_id))
+ result = await self.session.execute(stmt)
+ row = result.mappings().first()
+ return _row_to_refresh_token(dict(row)) if row else None
+
+ async def get_by_token_hash(
+ self, token_hash: str, *, for_update: bool = False
+ ) -> RefreshToken | None:
+ stmt = select(refresh_tokens_table).where(refresh_tokens_table.c.token_hash == token_hash)
+ if for_update:
+ stmt = stmt.with_for_update()
+ result = await self.session.execute(stmt)
+ row = result.mappings().first()
+ return _row_to_refresh_token(dict(row)) if row else None
+
+ async def save(self, token: RefreshToken) -> None:
+ token_dict = _refresh_token_to_dict(token)
+ existing = await self.get(token.id)
+
+ if existing:
+ stmt = (
+ update(refresh_tokens_table)
+ .where(refresh_tokens_table.c.id == str(token.id))
+ .values(**token_dict)
+ )
+ else:
+ stmt = insert(refresh_tokens_table).values(**token_dict)
+
+ await self.session.execute(stmt)
+ await self.session.flush()
+
+ async def revoke_family(self, family_id: TokenFamilyId) -> int:
+ """Revoke all tokens in a family. Returns count of revoked tokens."""
+ now = datetime.now(UTC)
+ stmt = (
+ update(refresh_tokens_table)
+ .where(
+ refresh_tokens_table.c.family_id == str(family_id),
+ refresh_tokens_table.c.revoked_at.is_(None),
+ )
+ .values(revoked_at=now)
+ )
+ result = await self.session.execute(stmt)
+ await self.session.flush()
+ return result.rowcount
diff --git a/server/osa/infrastructure/persistence/tables.py b/server/osa/infrastructure/persistence/tables.py
index 6b66044..f22ee64 100644
--- a/server/osa/infrastructure/persistence/tables.py
+++ b/server/osa/infrastructure/persistence/tables.py
@@ -3,12 +3,14 @@
from sqlalchemy import (
Column,
DateTime,
+ ForeignKey,
Index,
Integer,
MetaData,
String,
Table,
Text,
+ UniqueConstraint,
text,
)
from sqlalchemy.types import JSON
@@ -122,3 +124,54 @@
events_table.c.created_at,
postgresql_where=text("delivery_status = 'failed'"),
)
+
+
+# ============================================================================
+# USERS TABLE (Authentication)
+# ============================================================================
+users_table = Table(
+ "users",
+ metadata,
+ Column("id", String, primary_key=True), # UUID as string
+ Column("display_name", String(255), nullable=True),
+ Column("created_at", DateTime(timezone=True), nullable=False),
+ Column("updated_at", DateTime(timezone=True), nullable=True),
+)
+
+
+# ============================================================================
+# IDENTITIES TABLE (Authentication)
+# ============================================================================
+identities_table = Table(
+ "identities",
+ metadata,
+ Column("id", String, primary_key=True), # UUID as string
+ Column("user_id", String, ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
+ Column("provider", String(50), nullable=False), # "orcid", "google", etc.
+ Column("external_id", String(255), nullable=False), # ORCiD ID, Google ID, etc.
+ Column("metadata", JSON, nullable=True), # Provider-specific data (name, email)
+ Column("created_at", DateTime(timezone=True), nullable=False),
+ UniqueConstraint("provider", "external_id", name="uq_identity_provider_external"),
+)
+
+Index("ix_identities_user_id", identities_table.c.user_id)
+
+
+# ============================================================================
+# REFRESH TOKENS TABLE (Authentication)
+# ============================================================================
+refresh_tokens_table = Table(
+ "refresh_tokens",
+ metadata,
+ Column("id", String, primary_key=True), # UUID as string
+ Column("user_id", String, ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
+ Column("token_hash", String(64), nullable=False), # SHA256 hash
+ Column("family_id", String, nullable=False), # UUID - for theft detection
+ Column("expires_at", DateTime(timezone=True), nullable=False),
+ Column("created_at", DateTime(timezone=True), nullable=False),
+ Column("revoked_at", DateTime(timezone=True), nullable=True),
+)
+
+Index("ix_refresh_tokens_user_id", refresh_tokens_table.c.user_id)
+Index("ix_refresh_tokens_token_hash", refresh_tokens_table.c.token_hash)
+Index("ix_refresh_tokens_family_id", refresh_tokens_table.c.family_id)
diff --git a/server/pyproject.toml b/server/pyproject.toml
index be6bb29..09609ca 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
"rich>=14.2.0",
"asyncpg>=0.31.0",
"psycopg2-binary>=2.9.11",
+ "pyjwt>=2.11.0",
]
[project.scripts]
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index e69de29..cf9c15f 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -0,0 +1,7 @@
+"""Global test fixtures."""
+
+import os
+
+# Set JWT secret before any test modules import Config
+# This must happen at module load time, not in a fixture
+os.environ.setdefault("OSA_AUTH__JWT__SECRET", "test-secret-for-unit-tests-min-32")
diff --git a/server/tests/unit/application/api/v1/routes/test_auth_state.py b/server/tests/unit/application/api/v1/routes/test_auth_state.py
new file mode 100644
index 0000000..ebe63ca
--- /dev/null
+++ b/server/tests/unit/application/api/v1/routes/test_auth_state.py
@@ -0,0 +1,157 @@
+"""Unit tests for OAuth state token signing/verification."""
+
+import time
+
+import pytest
+
+from osa.config import JwtConfig
+from osa.domain.auth.service.token import STATE_EXPIRY_SECONDS, TokenService
+
+
+@pytest.fixture
+def token_service() -> TokenService:
+ """Create a TokenService with test config."""
+ config = JwtConfig(
+ secret="test-secret-key-for-signing-min-32",
+ algorithm="HS256",
+ access_token_expire_minutes=15,
+ refresh_token_expire_days=7,
+ )
+ return TokenService(_config=config)
+
+
+@pytest.fixture
+def token_service_alt_secret() -> TokenService:
+ """Create a TokenService with a different secret."""
+ config = JwtConfig(
+ secret="different-secret-key-for-testing",
+ algorithm="HS256",
+ access_token_expire_minutes=15,
+ refresh_token_expire_days=7,
+ )
+ return TokenService(_config=config)
+
+
+class TestSignedStateCreation:
+ """Tests for TokenService.create_oauth_state."""
+
+ def test_creates_state_with_redirect_uri(self, token_service: TokenService):
+ """Should create a signed state containing the redirect URI and provider."""
+ redirect_uri = "https://example.com/callback"
+ provider = "orcid"
+
+ state = token_service.create_oauth_state(redirect_uri, provider)
+
+ # State should be format: payload.signature
+ assert "." in state
+ parts = state.split(".")
+ assert len(parts) == 2
+
+ def test_different_nonces_produce_different_states(self, token_service: TokenService):
+ """Each state should have a unique nonce."""
+ redirect_uri = "https://example.com"
+
+ state1 = token_service.create_oauth_state(redirect_uri, "orcid")
+ state2 = token_service.create_oauth_state(redirect_uri, "orcid")
+
+ assert state1 != state2
+
+ def test_state_is_url_safe(self, token_service: TokenService):
+ """State should only contain URL-safe characters."""
+ redirect_uri = "https://example.com/path?query=value"
+
+ state = token_service.create_oauth_state(redirect_uri, "orcid")
+
+ # URL-safe base64 uses only these characters
+ allowed = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.")
+ assert all(c in allowed for c in state)
+
+
+class TestSignedStateVerification:
+ """Tests for TokenService.verify_oauth_state."""
+
+ def test_verifies_valid_state(self, token_service: TokenService):
+ """Should return (redirect_uri, provider) for valid state."""
+ redirect_uri = "https://example.com/after-login"
+ provider = "orcid"
+
+ state = token_service.create_oauth_state(redirect_uri, provider)
+ result = token_service.verify_oauth_state(state)
+
+ assert result is not None
+ assert result == (redirect_uri, provider)
+
+ def test_rejects_tampered_payload(self, token_service: TokenService):
+ """Should reject state with tampered payload."""
+ state = token_service.create_oauth_state("https://example.com", "orcid")
+
+ # Tamper with the payload (change a character)
+ parts = state.split(".")
+ tampered_payload = "x" + parts[0][1:]
+ tampered_state = f"{tampered_payload}.{parts[1]}"
+
+ result = token_service.verify_oauth_state(tampered_state)
+ assert result is None
+
+ def test_rejects_tampered_signature(self, token_service: TokenService):
+ """Should reject state with tampered signature."""
+ state = token_service.create_oauth_state("https://example.com", "orcid")
+
+ # Tamper with the signature
+ parts = state.split(".")
+ tampered_sig = "x" + parts[1][1:]
+ tampered_state = f"{parts[0]}.{tampered_sig}"
+
+ result = token_service.verify_oauth_state(tampered_state)
+ assert result is None
+
+ def test_rejects_wrong_secret(
+ self, token_service: TokenService, token_service_alt_secret: TokenService
+ ):
+ """Should reject state signed with different secret."""
+ state = token_service.create_oauth_state("https://example.com", "orcid")
+
+ result = token_service_alt_secret.verify_oauth_state(state)
+ assert result is None
+
+ def test_rejects_expired_state(self, token_service: TokenService, monkeypatch):
+ """Should reject expired state."""
+ state = token_service.create_oauth_state("https://example.com", "orcid")
+
+ # Fast-forward time past expiry
+ future_time = time.time() + STATE_EXPIRY_SECONDS + 1
+ monkeypatch.setattr(time, "time", lambda: future_time)
+
+ result = token_service.verify_oauth_state(state)
+ assert result is None
+
+ def test_rejects_malformed_state(self, token_service: TokenService):
+ """Should reject malformed state strings."""
+ # No dot separator
+ assert token_service.verify_oauth_state("nodot") is None
+
+ # Empty parts
+ assert token_service.verify_oauth_state(".") is None
+ assert token_service.verify_oauth_state("payload.") is None
+ assert token_service.verify_oauth_state(".signature") is None
+
+ # Too many parts
+ assert token_service.verify_oauth_state("a.b.c") is None
+
+ # Invalid base64
+ assert token_service.verify_oauth_state("!!!.???") is None
+
+ def test_rejects_empty_state(self, token_service: TokenService):
+ """Should reject empty state."""
+ assert token_service.verify_oauth_state("") is None
+
+ def test_handles_special_characters_in_redirect_uri(self, token_service: TokenService):
+ """Should handle redirect URIs with special characters."""
+ redirect_uri = "https://example.com/path?foo=bar&baz=qux#fragment"
+ provider = "orcid"
+
+ state = token_service.create_oauth_state(redirect_uri, provider)
+ result = token_service.verify_oauth_state(state)
+
+ assert result is not None
+ assert result == (redirect_uri, provider)
diff --git a/server/tests/unit/domain/auth/__init__.py b/server/tests/unit/domain/auth/__init__.py
new file mode 100644
index 0000000..0392972
--- /dev/null
+++ b/server/tests/unit/domain/auth/__init__.py
@@ -0,0 +1 @@
+"""Unit tests for auth domain."""
diff --git a/server/tests/unit/domain/auth/test_auth_service.py b/server/tests/unit/domain/auth/test_auth_service.py
new file mode 100644
index 0000000..bef7f98
--- /dev/null
+++ b/server/tests/unit/domain/auth/test_auth_service.py
@@ -0,0 +1,393 @@
+"""Unit tests for AuthService."""
+
+from datetime import UTC, datetime, timedelta
+from unittest.mock import AsyncMock, MagicMock
+from uuid import uuid4
+
+import pytest
+
+from osa.config import JwtConfig
+from osa.domain.auth.model.identity import Identity
+from osa.domain.auth.model.token import RefreshToken
+from osa.domain.auth.model.user import User
+from osa.domain.auth.model.value import IdentityId, RefreshTokenId, TokenFamilyId, UserId
+from osa.domain.auth.port.identity_provider import IdentityInfo
+from osa.domain.auth.service.auth import AuthService
+from osa.domain.auth.service.token import TokenService
+from osa.domain.shared.error import InvalidStateError
+
+
+def make_auth_service(
+ user_repo: AsyncMock | None = None,
+ identity_repo: AsyncMock | None = None,
+ refresh_token_repo: AsyncMock | None = None,
+ token_service: TokenService | None = None,
+ outbox: AsyncMock | None = None,
+) -> AuthService:
+ """Create an AuthService with mocked dependencies."""
+ if user_repo is None:
+ user_repo = AsyncMock()
+ if identity_repo is None:
+ identity_repo = AsyncMock()
+ if refresh_token_repo is None:
+ refresh_token_repo = AsyncMock()
+ if token_service is None:
+ config = JwtConfig(
+ secret="test-secret-key-256-bits-long-xx",
+ algorithm="HS256",
+ access_token_expire_minutes=60,
+ refresh_token_expire_days=7,
+ )
+ token_service = TokenService(_config=config)
+ if outbox is None:
+ outbox = AsyncMock()
+
+ return AuthService(
+ _user_repo=user_repo,
+ _identity_repo=identity_repo,
+ _refresh_token_repo=refresh_token_repo,
+ _token_service=token_service,
+ _outbox=outbox,
+ )
+
+
+def make_identity_provider(identity_info: IdentityInfo | None = None) -> MagicMock:
+ """Create a mock identity provider."""
+ provider = MagicMock()
+ provider.provider_name = "orcid"
+
+ if identity_info is None:
+ identity_info = IdentityInfo(
+ provider="orcid",
+ external_id="0000-0001-2345-6789",
+ display_name="Jane Doe",
+ email=None,
+ raw_data={"name": "Jane Doe", "orcid": "0000-0001-2345-6789"},
+ )
+
+ provider.exchange_code = AsyncMock(return_value=identity_info)
+ provider.get_authorization_url = MagicMock(return_value="https://orcid.org/oauth/authorize?...")
+ return provider
+
+
+class TestAuthServiceInitiateLogin:
+ """Tests for AuthService.initiate_login."""
+
+ @pytest.mark.asyncio
+ async def test_initiate_login_returns_authorization_url(self):
+ """initiate_login should return the provider's authorization URL."""
+ service = make_auth_service()
+ provider = make_identity_provider()
+ provider.get_authorization_url.return_value = "https://orcid.org/oauth?state=abc"
+
+ url = await service.initiate_login(
+ provider=provider,
+ state="test-state",
+ redirect_uri="http://localhost/callback",
+ )
+
+ assert url == "https://orcid.org/oauth?state=abc"
+ provider.get_authorization_url.assert_called_once_with(
+ "test-state", "http://localhost/callback"
+ )
+
+
+class TestAuthServiceCompleteOAuth:
+ """Tests for AuthService.complete_oauth."""
+
+ @pytest.mark.asyncio
+ async def test_complete_oauth_creates_new_user(self):
+ """complete_oauth should create user and identity for new user."""
+ user_repo = AsyncMock()
+ user_repo.get.return_value = None # No existing user
+
+ identity_repo = AsyncMock()
+ identity_repo.get_by_provider_and_external_id.return_value = None # No existing identity
+
+ refresh_token_repo = AsyncMock()
+
+ service = make_auth_service(
+ user_repo=user_repo,
+ identity_repo=identity_repo,
+ refresh_token_repo=refresh_token_repo,
+ )
+ provider = make_identity_provider()
+
+ user, identity, access_token, refresh_token = await service.complete_oauth(
+ provider=provider,
+ code="auth-code",
+ redirect_uri="http://localhost/callback",
+ )
+
+ # Should create user and identity
+ user_repo.save.assert_called_once()
+ identity_repo.save.assert_called_once()
+ refresh_token_repo.save.assert_called_once()
+
+ # Should return valid data
+ assert user.display_name == "Jane Doe"
+ assert identity.provider == "orcid"
+ assert identity.external_id == "0000-0001-2345-6789"
+ assert isinstance(access_token, str)
+ assert isinstance(refresh_token, str)
+
+ @pytest.mark.asyncio
+ async def test_complete_oauth_returns_existing_user(self):
+ """complete_oauth should return existing user if identity exists."""
+ existing_user = User(
+ id=UserId(uuid4()),
+ display_name="Existing User",
+ created_at=datetime.now(UTC),
+ updated_at=None,
+ )
+ existing_identity = Identity(
+ id=IdentityId(uuid4()),
+ user_id=existing_user.id,
+ provider="orcid",
+ external_id="0000-0001-2345-6789",
+ metadata=None,
+ created_at=datetime.now(UTC),
+ )
+
+ user_repo = AsyncMock()
+ user_repo.get.return_value = existing_user
+
+ identity_repo = AsyncMock()
+ identity_repo.get_by_provider_and_external_id.return_value = existing_identity
+
+ refresh_token_repo = AsyncMock()
+
+ service = make_auth_service(
+ user_repo=user_repo,
+ identity_repo=identity_repo,
+ refresh_token_repo=refresh_token_repo,
+ )
+ provider = make_identity_provider()
+
+ user, identity, _, _ = await service.complete_oauth(
+ provider=provider,
+ code="auth-code",
+ redirect_uri="http://localhost/callback",
+ )
+
+ # Should NOT create new user/identity
+ user_repo.save.assert_not_called()
+ identity_repo.save.assert_not_called()
+
+ # Should return existing user
+ assert user.id == existing_user.id
+ assert identity.id == existing_identity.id
+
+
+class TestAuthServiceRefreshTokens:
+ """Tests for AuthService.refresh_tokens."""
+
+ @pytest.mark.asyncio
+ async def test_refresh_tokens_issues_new_tokens(self):
+ """refresh_tokens should issue new access and refresh tokens."""
+ user = User(
+ id=UserId(uuid4()),
+ display_name="Test User",
+ created_at=datetime.now(UTC),
+ updated_at=None,
+ )
+ identity = Identity(
+ id=IdentityId(uuid4()),
+ user_id=user.id,
+ provider="orcid",
+ external_id="0000-0001-2345-6789",
+ metadata=None,
+ created_at=datetime.now(UTC),
+ )
+ old_token = RefreshToken(
+ id=RefreshTokenId(uuid4()),
+ user_id=user.id,
+ token_hash="old-hash",
+ family_id=TokenFamilyId(uuid4()),
+ expires_at=datetime.now(UTC) + timedelta(days=7),
+ created_at=datetime.now(UTC),
+ revoked_at=None,
+ )
+
+ user_repo = AsyncMock()
+ user_repo.get.return_value = user
+
+ identity_repo = AsyncMock()
+ identity_repo.get_by_user_id.return_value = [identity]
+
+ refresh_token_repo = AsyncMock()
+ refresh_token_repo.get_by_token_hash.return_value = old_token
+
+ service = make_auth_service(
+ user_repo=user_repo,
+ identity_repo=identity_repo,
+ refresh_token_repo=refresh_token_repo,
+ )
+
+ returned_user, access_token, new_refresh_token = await service.refresh_tokens(
+ "raw-refresh-token"
+ )
+
+ # Should save new refresh token
+ assert refresh_token_repo.save.call_count == 2 # Once for revoking old, once for new
+
+ # Should return valid data
+ assert returned_user.id == user.id
+ assert isinstance(access_token, str)
+ assert isinstance(new_refresh_token, str)
+
+ @pytest.mark.asyncio
+ async def test_refresh_tokens_revokes_old_token(self):
+ """refresh_tokens should revoke the old refresh token."""
+ user = User(
+ id=UserId(uuid4()),
+ display_name="Test User",
+ created_at=datetime.now(UTC),
+ updated_at=None,
+ )
+ identity = Identity(
+ id=IdentityId(uuid4()),
+ user_id=user.id,
+ provider="orcid",
+ external_id="0000-0001-2345-6789",
+ metadata=None,
+ created_at=datetime.now(UTC),
+ )
+ old_token = RefreshToken(
+ id=RefreshTokenId(uuid4()),
+ user_id=user.id,
+ token_hash="old-hash",
+ family_id=TokenFamilyId(uuid4()),
+ expires_at=datetime.now(UTC) + timedelta(days=7),
+ created_at=datetime.now(UTC),
+ revoked_at=None,
+ )
+
+ user_repo = AsyncMock()
+ user_repo.get.return_value = user
+
+ identity_repo = AsyncMock()
+ identity_repo.get_by_user_id.return_value = [identity]
+
+ refresh_token_repo = AsyncMock()
+ refresh_token_repo.get_by_token_hash.return_value = old_token
+
+ service = make_auth_service(
+ user_repo=user_repo,
+ identity_repo=identity_repo,
+ refresh_token_repo=refresh_token_repo,
+ )
+
+ await service.refresh_tokens("raw-refresh-token")
+
+ # The old token should be revoked
+ assert old_token.is_revoked is True
+
+ @pytest.mark.asyncio
+ async def test_refresh_tokens_rejects_invalid_token(self):
+ """refresh_tokens should raise for unknown refresh token."""
+ refresh_token_repo = AsyncMock()
+ refresh_token_repo.get_by_token_hash.return_value = None
+
+ service = make_auth_service(refresh_token_repo=refresh_token_repo)
+
+ with pytest.raises(InvalidStateError) as exc_info:
+ await service.refresh_tokens("invalid-token")
+
+ assert exc_info.value.code == "invalid_refresh_token"
+
+ @pytest.mark.asyncio
+ async def test_refresh_tokens_detects_reuse_and_revokes_family(self):
+ """refresh_tokens should revoke entire family if revoked token is reused."""
+ user_id = UserId(uuid4())
+ family_id = TokenFamilyId(uuid4())
+
+ # Token that was already revoked (potential theft)
+ revoked_token = RefreshToken(
+ id=RefreshTokenId(uuid4()),
+ user_id=user_id,
+ token_hash="revoked-hash",
+ family_id=family_id,
+ expires_at=datetime.now(UTC) + timedelta(days=7),
+ created_at=datetime.now(UTC),
+ revoked_at=datetime.now(UTC) - timedelta(hours=1), # Already revoked
+ )
+
+ refresh_token_repo = AsyncMock()
+ refresh_token_repo.get_by_token_hash.return_value = revoked_token
+ refresh_token_repo.revoke_family.return_value = 3 # 3 tokens revoked
+
+ service = make_auth_service(refresh_token_repo=refresh_token_repo)
+
+ with pytest.raises(InvalidStateError) as exc_info:
+ await service.refresh_tokens("stolen-token")
+
+ assert exc_info.value.code == "token_family_revoked"
+
+ # Should revoke entire family
+ refresh_token_repo.revoke_family.assert_called_once_with(family_id)
+
+ @pytest.mark.asyncio
+ async def test_refresh_tokens_rejects_expired_token(self):
+ """refresh_tokens should raise for expired refresh token."""
+ expired_token = RefreshToken(
+ id=RefreshTokenId(uuid4()),
+ user_id=UserId(uuid4()),
+ token_hash="expired-hash",
+ family_id=TokenFamilyId(uuid4()),
+ expires_at=datetime.now(UTC) - timedelta(hours=1), # Expired
+ created_at=datetime.now(UTC) - timedelta(days=8),
+ revoked_at=None,
+ )
+
+ refresh_token_repo = AsyncMock()
+ refresh_token_repo.get_by_token_hash.return_value = expired_token
+
+ service = make_auth_service(refresh_token_repo=refresh_token_repo)
+
+ with pytest.raises(InvalidStateError) as exc_info:
+ await service.refresh_tokens("expired-token")
+
+ assert exc_info.value.code == "refresh_token_expired"
+
+
+class TestAuthServiceLogout:
+ """Tests for AuthService.logout."""
+
+ @pytest.mark.asyncio
+ async def test_logout_revokes_token_family(self):
+ """logout should revoke the entire token family."""
+ family_id = TokenFamilyId(uuid4())
+ token = RefreshToken(
+ id=RefreshTokenId(uuid4()),
+ user_id=UserId(uuid4()),
+ token_hash="token-hash",
+ family_id=family_id,
+ expires_at=datetime.now(UTC) + timedelta(days=7),
+ created_at=datetime.now(UTC),
+ revoked_at=None,
+ )
+
+ refresh_token_repo = AsyncMock()
+ refresh_token_repo.get_by_token_hash.return_value = token
+ refresh_token_repo.revoke_family.return_value = 1
+
+ service = make_auth_service(refresh_token_repo=refresh_token_repo)
+
+ result = await service.logout("raw-refresh-token")
+
+ assert result is True
+ refresh_token_repo.revoke_family.assert_called_once_with(family_id)
+
+ @pytest.mark.asyncio
+ async def test_logout_succeeds_for_unknown_token(self):
+ """logout should succeed even if token is not found."""
+ refresh_token_repo = AsyncMock()
+ refresh_token_repo.get_by_token_hash.return_value = None
+
+ service = make_auth_service(refresh_token_repo=refresh_token_repo)
+
+ result = await service.logout("unknown-token")
+
+ assert result is True
+ refresh_token_repo.revoke_family.assert_not_called()
diff --git a/server/tests/unit/domain/auth/test_command_handlers.py b/server/tests/unit/domain/auth/test_command_handlers.py
new file mode 100644
index 0000000..9179e15
--- /dev/null
+++ b/server/tests/unit/domain/auth/test_command_handlers.py
@@ -0,0 +1,332 @@
+"""Unit tests for auth command handlers."""
+
+from datetime import UTC, datetime
+from unittest.mock import AsyncMock, MagicMock
+from uuid import uuid4
+
+import pytest
+
+from osa.config import JwtConfig
+from osa.domain.auth.command.login import (
+ CompleteOAuth,
+ CompleteOAuthHandler,
+ InitiateLogin,
+ InitiateLoginHandler,
+)
+from osa.domain.auth.command.token import (
+ Logout,
+ LogoutHandler,
+ RefreshTokens,
+ RefreshTokensHandler,
+)
+from osa.domain.auth.event import UserAuthenticated, UserLoggedOut
+from osa.domain.auth.model.identity import Identity
+from osa.domain.auth.model.user import User
+from osa.domain.auth.model.value import IdentityId, UserId
+from osa.domain.auth.service.token import TokenService
+
+
+def make_token_service() -> TokenService:
+ """Create a TokenService with test config."""
+ config = JwtConfig(
+ secret="test-secret-key-256-bits-long-xx",
+ algorithm="HS256",
+ access_token_expire_minutes=60,
+ refresh_token_expire_days=7,
+ )
+ return TokenService(_config=config)
+
+
+def make_identity_provider() -> MagicMock:
+ """Create a mock identity provider."""
+ provider = MagicMock()
+ provider.provider_name = "orcid"
+ provider.get_authorization_url = MagicMock(
+ return_value="https://orcid.org/oauth/authorize?state=xyz"
+ )
+ return provider
+
+
+def make_provider_registry(identity_provider: MagicMock | None = None) -> MagicMock:
+ """Create a mock provider registry."""
+ if identity_provider is None:
+ identity_provider = make_identity_provider()
+ registry = MagicMock()
+ registry.get.return_value = identity_provider
+ registry.is_available.return_value = True
+ registry.available_providers.return_value = ["orcid"]
+ return registry
+
+
+class TestInitiateLoginHandler:
+ """Tests for InitiateLoginHandler."""
+
+ @pytest.mark.asyncio
+ async def test_run_returns_authorization_url(self):
+ """Handler should return authorization URL from identity provider."""
+ identity_provider = make_identity_provider()
+ provider_registry = make_provider_registry(identity_provider)
+ token_service = make_token_service()
+
+ handler = InitiateLoginHandler(
+ provider_registry=provider_registry,
+ token_service=token_service,
+ )
+
+ result = await handler.run(
+ InitiateLogin(
+ callback_url="http://localhost/callback",
+ final_redirect_uri="http://localhost/dashboard",
+ provider="orcid",
+ )
+ )
+
+ assert result.authorization_url == "https://orcid.org/oauth/authorize?state=xyz"
+ identity_provider.get_authorization_url.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_run_creates_signed_state(self):
+ """Handler should create signed state token with final redirect URI and provider."""
+ identity_provider = make_identity_provider()
+ provider_registry = make_provider_registry(identity_provider)
+ token_service = make_token_service()
+
+ handler = InitiateLoginHandler(
+ provider_registry=provider_registry,
+ token_service=token_service,
+ )
+
+ await handler.run(
+ InitiateLogin(
+ callback_url="http://localhost/callback",
+ final_redirect_uri="http://localhost/dashboard",
+ provider="orcid",
+ )
+ )
+
+ # Verify state was passed to identity provider
+ call_args = identity_provider.get_authorization_url.call_args
+ state = call_args[1]["state"] if "state" in call_args[1] else call_args[0][0]
+
+ # Verify state can be decoded to get back the redirect URI and provider
+ result = token_service.verify_oauth_state(state)
+ assert result is not None
+ redirect_uri, provider = result
+ assert redirect_uri == "http://localhost/dashboard"
+ assert provider == "orcid"
+
+
+class TestCompleteOAuthHandler:
+ """Tests for CompleteOAuthHandler."""
+
+ @pytest.mark.asyncio
+ async def test_run_emits_user_authenticated_event(self):
+ """Handler should emit UserAuthenticated event on successful OAuth."""
+ user = User(
+ id=UserId(uuid4()),
+ display_name="Jane Doe",
+ created_at=datetime.now(UTC),
+ updated_at=None,
+ )
+ identity = Identity(
+ id=IdentityId(uuid4()),
+ user_id=user.id,
+ provider="orcid",
+ external_id="0000-0001-2345-6789",
+ metadata=None,
+ created_at=datetime.now(UTC),
+ )
+
+ auth_service = AsyncMock()
+ auth_service.complete_oauth.return_value = (
+ user,
+ identity,
+ "access-token",
+ "refresh-token",
+ )
+
+ provider_registry = make_provider_registry()
+ token_service = make_token_service()
+ outbox = AsyncMock()
+
+ handler = CompleteOAuthHandler(
+ auth_service=auth_service,
+ provider_registry=provider_registry,
+ token_service=token_service,
+ outbox=outbox,
+ )
+
+ await handler.run(
+ CompleteOAuth(
+ code="auth-code",
+ callback_url="http://localhost/callback",
+ provider="orcid",
+ )
+ )
+
+ # Verify UserAuthenticated event was emitted
+ outbox.append.assert_called_once()
+ event = outbox.append.call_args[0][0]
+ assert isinstance(event, UserAuthenticated)
+ assert event.user_id == str(user.id)
+ assert event.provider == "orcid"
+ assert event.external_id == "0000-0001-2345-6789"
+
+ @pytest.mark.asyncio
+ async def test_run_returns_user_info_and_tokens(self):
+ """Handler should return user info and tokens."""
+ user = User(
+ id=UserId(uuid4()),
+ display_name="Jane Doe",
+ created_at=datetime.now(UTC),
+ updated_at=None,
+ )
+ identity = Identity(
+ id=IdentityId(uuid4()),
+ user_id=user.id,
+ provider="orcid",
+ external_id="0000-0001-2345-6789",
+ metadata=None,
+ created_at=datetime.now(UTC),
+ )
+
+ auth_service = AsyncMock()
+ auth_service.complete_oauth.return_value = (
+ user,
+ identity,
+ "access-token",
+ "refresh-token",
+ )
+
+ provider_registry = make_provider_registry()
+ token_service = make_token_service()
+ outbox = AsyncMock()
+
+ handler = CompleteOAuthHandler(
+ auth_service=auth_service,
+ provider_registry=provider_registry,
+ token_service=token_service,
+ outbox=outbox,
+ )
+
+ result = await handler.run(
+ CompleteOAuth(
+ code="auth-code",
+ callback_url="http://localhost/callback",
+ provider="orcid",
+ )
+ )
+
+ assert result.user_id == str(user.id)
+ assert result.display_name == "Jane Doe"
+ assert result.provider == "orcid"
+ assert result.external_id == "0000-0001-2345-6789"
+ assert result.access_token == "access-token"
+ assert result.refresh_token == "refresh-token"
+ assert result.expires_in == 60 * 60 # 60 minutes in seconds
+
+
+class TestRefreshTokensHandler:
+ """Tests for RefreshTokensHandler."""
+
+ @pytest.mark.asyncio
+ async def test_run_returns_new_tokens(self):
+ """Handler should return new tokens from auth service."""
+ user = User(
+ id=UserId(uuid4()),
+ display_name="Test User",
+ created_at=datetime.now(UTC),
+ updated_at=None,
+ )
+
+ auth_service = AsyncMock()
+ auth_service.refresh_tokens.return_value = (
+ user,
+ "new-access-token",
+ "new-refresh-token",
+ )
+
+ token_service = make_token_service()
+
+ handler = RefreshTokensHandler(
+ auth_service=auth_service,
+ token_service=token_service,
+ )
+
+ result = await handler.run(RefreshTokens(refresh_token="old-refresh-token"))
+
+ assert result.access_token == "new-access-token"
+ assert result.refresh_token == "new-refresh-token"
+ assert result.expires_in == 60 * 60 # 60 minutes in seconds
+
+ auth_service.refresh_tokens.assert_called_once_with("old-refresh-token")
+
+
+class TestLogoutHandler:
+ """Tests for LogoutHandler."""
+
+ @pytest.mark.asyncio
+ async def test_run_emits_user_logged_out_event(self):
+ """Handler should emit UserLoggedOut event when user has valid token."""
+ user_id = UserId(uuid4())
+
+ auth_service = AsyncMock()
+ auth_service.get_user_id_from_refresh_token.return_value = user_id
+ auth_service.logout.return_value = True
+
+ outbox = AsyncMock()
+
+ handler = LogoutHandler(
+ auth_service=auth_service,
+ outbox=outbox,
+ )
+
+ result = await handler.run(Logout(refresh_token="refresh-token"))
+
+ assert result.success is True
+
+ # Verify UserLoggedOut event was emitted
+ outbox.append.assert_called_once()
+ event = outbox.append.call_args[0][0]
+ assert isinstance(event, UserLoggedOut)
+ assert event.user_id == str(user_id)
+
+ @pytest.mark.asyncio
+ async def test_run_does_not_emit_event_for_unknown_token(self):
+ """Handler should not emit event if token is unknown."""
+ auth_service = AsyncMock()
+ auth_service.get_user_id_from_refresh_token.return_value = None
+ auth_service.logout.return_value = True
+
+ outbox = AsyncMock()
+
+ handler = LogoutHandler(
+ auth_service=auth_service,
+ outbox=outbox,
+ )
+
+ result = await handler.run(Logout(refresh_token="unknown-token"))
+
+ assert result.success is True
+
+ # Should NOT emit event for unknown token
+ outbox.append.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_run_returns_success(self):
+ """Handler should return success status from auth service."""
+ auth_service = AsyncMock()
+ auth_service.get_user_id_from_refresh_token.return_value = UserId(uuid4())
+ auth_service.logout.return_value = True
+
+ outbox = AsyncMock()
+
+ handler = LogoutHandler(
+ auth_service=auth_service,
+ outbox=outbox,
+ )
+
+ result = await handler.run(Logout(refresh_token="refresh-token"))
+
+ assert result.success is True
+ auth_service.logout.assert_called_once_with("refresh-token")
diff --git a/server/tests/unit/domain/auth/test_refresh_token.py b/server/tests/unit/domain/auth/test_refresh_token.py
new file mode 100644
index 0000000..6179fbf
--- /dev/null
+++ b/server/tests/unit/domain/auth/test_refresh_token.py
@@ -0,0 +1,227 @@
+"""Unit tests for RefreshToken entity."""
+
+from datetime import UTC, datetime, timedelta
+from uuid import uuid4
+
+
+from osa.domain.auth.model.token import RefreshToken
+from osa.domain.auth.model.value import RefreshTokenId, TokenFamilyId, UserId
+
+
+class TestRefreshTokenCreate:
+ """Tests for RefreshToken.create factory method."""
+
+ def test_create_sets_all_fields(self):
+ """create should set all required fields."""
+ user_id = UserId(uuid4())
+ token_hash = "a" * 64
+ family_id = TokenFamilyId(uuid4())
+
+ token = RefreshToken.create(
+ user_id=user_id,
+ token_hash=token_hash,
+ family_id=family_id,
+ expires_in_days=7,
+ )
+
+ assert token.user_id == user_id
+ assert token.token_hash == token_hash
+ assert token.family_id == family_id
+ assert token.revoked_at is None
+
+ def test_create_generates_id(self):
+ """create should generate a unique ID."""
+ user_id = UserId(uuid4())
+
+ token1 = RefreshToken.create(user_id, "a" * 64, TokenFamilyId(uuid4()))
+ token2 = RefreshToken.create(user_id, "b" * 64, TokenFamilyId(uuid4()))
+
+ assert token1.id != token2.id
+
+ def test_create_sets_expiry_in_future(self):
+ """create should set expires_at in the future."""
+ user_id = UserId(uuid4())
+ now = datetime.now(UTC)
+
+ token = RefreshToken.create(
+ user_id=user_id,
+ token_hash="a" * 64,
+ family_id=TokenFamilyId(uuid4()),
+ expires_in_days=7,
+ )
+
+ assert token.expires_at > now
+ # Should be roughly 7 days from now (allowing small margin)
+ expected = now + timedelta(days=7)
+ assert abs((token.expires_at - expected).total_seconds()) < 1
+
+ def test_create_sets_created_at(self):
+ """create should set created_at to current time."""
+ now = datetime.now(UTC)
+
+ token = RefreshToken.create(
+ user_id=UserId(uuid4()),
+ token_hash="a" * 64,
+ family_id=TokenFamilyId(uuid4()),
+ )
+
+ assert abs((token.created_at - now).total_seconds()) < 1
+
+
+class TestRefreshTokenIsValid:
+ """Tests for RefreshToken.is_valid property."""
+
+ def make_token(
+ self,
+ expires_at: datetime | None = None,
+ revoked_at: datetime | None = None,
+ ) -> RefreshToken:
+ """Create a token with specified expiry and revocation."""
+ if expires_at is None:
+ expires_at = datetime.now(UTC) + timedelta(days=7)
+
+ return RefreshToken(
+ id=RefreshTokenId(uuid4()),
+ user_id=UserId(uuid4()),
+ token_hash="a" * 64,
+ family_id=TokenFamilyId(uuid4()),
+ expires_at=expires_at,
+ created_at=datetime.now(UTC),
+ revoked_at=revoked_at,
+ )
+
+ def test_is_valid_true_for_fresh_token(self):
+ """is_valid should be True for non-expired, non-revoked token."""
+ token = self.make_token()
+
+ assert token.is_valid is True
+
+ def test_is_valid_false_when_expired(self):
+ """is_valid should be False when token is expired."""
+ expired_at = datetime.now(UTC) - timedelta(hours=1)
+ token = self.make_token(expires_at=expired_at)
+
+ assert token.is_valid is False
+
+ def test_is_valid_false_when_revoked(self):
+ """is_valid should be False when token is revoked."""
+ token = self.make_token(revoked_at=datetime.now(UTC))
+
+ assert token.is_valid is False
+
+ def test_is_valid_false_when_both_expired_and_revoked(self):
+ """is_valid should be False when both expired and revoked."""
+ token = self.make_token(
+ expires_at=datetime.now(UTC) - timedelta(hours=1),
+ revoked_at=datetime.now(UTC) - timedelta(hours=2),
+ )
+
+ assert token.is_valid is False
+
+
+class TestRefreshTokenIsRevoked:
+ """Tests for RefreshToken.is_revoked property."""
+
+ def test_is_revoked_false_initially(self):
+ """is_revoked should be False when revoked_at is None."""
+ token = RefreshToken.create(
+ user_id=UserId(uuid4()),
+ token_hash="a" * 64,
+ family_id=TokenFamilyId(uuid4()),
+ )
+
+ assert token.is_revoked is False
+
+ def test_is_revoked_true_when_set(self):
+ """is_revoked should be True when revoked_at is set."""
+ token = RefreshToken(
+ id=RefreshTokenId(uuid4()),
+ user_id=UserId(uuid4()),
+ token_hash="a" * 64,
+ family_id=TokenFamilyId(uuid4()),
+ expires_at=datetime.now(UTC) + timedelta(days=7),
+ created_at=datetime.now(UTC),
+ revoked_at=datetime.now(UTC),
+ )
+
+ assert token.is_revoked is True
+
+
+class TestRefreshTokenIsExpired:
+ """Tests for RefreshToken.is_expired property."""
+
+ def test_is_expired_false_for_future_expiry(self):
+ """is_expired should be False when expires_at is in the future."""
+ token = RefreshToken(
+ id=RefreshTokenId(uuid4()),
+ user_id=UserId(uuid4()),
+ token_hash="a" * 64,
+ family_id=TokenFamilyId(uuid4()),
+ expires_at=datetime.now(UTC) + timedelta(days=7),
+ created_at=datetime.now(UTC),
+ revoked_at=None,
+ )
+
+ assert token.is_expired is False
+
+ def test_is_expired_true_for_past_expiry(self):
+ """is_expired should be True when expires_at is in the past."""
+ token = RefreshToken(
+ id=RefreshTokenId(uuid4()),
+ user_id=UserId(uuid4()),
+ token_hash="a" * 64,
+ family_id=TokenFamilyId(uuid4()),
+ expires_at=datetime.now(UTC) - timedelta(hours=1),
+ created_at=datetime.now(UTC) - timedelta(days=8),
+ revoked_at=None,
+ )
+
+ assert token.is_expired is True
+
+
+class TestRefreshTokenRevoke:
+ """Tests for RefreshToken.revoke method."""
+
+ def test_revoke_sets_revoked_at(self):
+ """revoke should set revoked_at to current time."""
+ token = RefreshToken.create(
+ user_id=UserId(uuid4()),
+ token_hash="a" * 64,
+ family_id=TokenFamilyId(uuid4()),
+ )
+ assert token.revoked_at is None
+
+ now = datetime.now(UTC)
+ token.revoke()
+
+ assert token.revoked_at is not None
+ assert abs((token.revoked_at - now).total_seconds()) < 1
+
+ def test_revoke_idempotent(self):
+ """revoke should not change revoked_at if already revoked."""
+ token = RefreshToken.create(
+ user_id=UserId(uuid4()),
+ token_hash="a" * 64,
+ family_id=TokenFamilyId(uuid4()),
+ )
+
+ token.revoke()
+ first_revoked_at = token.revoked_at
+
+ # Second revoke should not change the timestamp
+ token.revoke()
+
+ assert token.revoked_at == first_revoked_at
+
+ def test_revoke_makes_token_invalid(self):
+ """revoke should make is_valid return False."""
+ token = RefreshToken.create(
+ user_id=UserId(uuid4()),
+ token_hash="a" * 64,
+ family_id=TokenFamilyId(uuid4()),
+ )
+ assert token.is_valid is True
+
+ token.revoke()
+
+ assert token.is_valid is False
diff --git a/server/tests/unit/domain/auth/test_token_service.py b/server/tests/unit/domain/auth/test_token_service.py
new file mode 100644
index 0000000..5776511
--- /dev/null
+++ b/server/tests/unit/domain/auth/test_token_service.py
@@ -0,0 +1,250 @@
+"""Unit tests for TokenService JWT creation and validation."""
+
+import time
+from uuid import uuid4
+
+import jwt
+import pytest
+
+from osa.config import JwtConfig
+from osa.domain.auth.model.value import ProviderIdentity, UserId
+from osa.domain.auth.service.token import TokenService
+
+
+class TestTokenServiceAccessToken:
+ """Tests for JWT access token creation and validation."""
+
+ def make_service(self, secret: str = "test-secret-key-256-bits-long-xx") -> TokenService:
+ """Create a TokenService with test config."""
+ config = JwtConfig(
+ secret=secret,
+ algorithm="HS256",
+ access_token_expire_minutes=60,
+ refresh_token_expire_days=7,
+ )
+ return TokenService(_config=config)
+
+ def test_create_access_token_returns_valid_jwt(self):
+ """create_access_token should return a decodable JWT."""
+ service = self.make_service()
+ user_id = UserId(uuid4())
+ identity = ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789")
+
+ token = service.create_access_token(user_id, identity)
+
+ # Should be decodable
+ payload = jwt.decode(
+ token,
+ "test-secret-key-256-bits-long-xx",
+ algorithms=["HS256"],
+ audience="authenticated",
+ )
+ assert payload["sub"] == str(user_id)
+ assert payload["provider"] == "orcid"
+ assert payload["external_id"] == "0000-0001-2345-6789"
+ assert payload["aud"] == "authenticated"
+
+ def test_create_access_token_includes_expiry(self):
+ """create_access_token should set exp claim."""
+ service = self.make_service()
+ user_id = UserId(uuid4())
+ identity = ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789")
+
+ token = service.create_access_token(user_id, identity)
+
+ payload = jwt.decode(
+ token,
+ "test-secret-key-256-bits-long-xx",
+ algorithms=["HS256"],
+ audience="authenticated",
+ )
+ assert "exp" in payload
+ assert "iat" in payload
+ # Expiry should be ~60 minutes from now
+ assert payload["exp"] > payload["iat"]
+ assert payload["exp"] - payload["iat"] == 60 * 60 # 60 minutes in seconds
+
+ def test_create_access_token_includes_jti(self):
+ """create_access_token should include unique jti claim."""
+ service = self.make_service()
+ user_id = UserId(uuid4())
+ identity = ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789")
+
+ token1 = service.create_access_token(user_id, identity)
+ token2 = service.create_access_token(user_id, identity)
+
+ payload1 = jwt.decode(
+ token1,
+ "test-secret-key-256-bits-long-xx",
+ algorithms=["HS256"],
+ audience="authenticated",
+ )
+ payload2 = jwt.decode(
+ token2,
+ "test-secret-key-256-bits-long-xx",
+ algorithms=["HS256"],
+ audience="authenticated",
+ )
+
+ assert "jti" in payload1
+ assert "jti" in payload2
+ assert payload1["jti"] != payload2["jti"]
+
+ def test_create_access_token_with_additional_claims(self):
+ """create_access_token should include additional claims if provided."""
+ service = self.make_service()
+ user_id = UserId(uuid4())
+ identity = ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789")
+
+ token = service.create_access_token(
+ user_id,
+ identity,
+ additional_claims={"custom": "value"},
+ )
+
+ payload = jwt.decode(
+ token,
+ "test-secret-key-256-bits-long-xx",
+ algorithms=["HS256"],
+ audience="authenticated",
+ )
+ assert payload["custom"] == "value"
+
+ def test_validate_access_token_returns_payload(self):
+ """validate_access_token should return decoded payload."""
+ service = self.make_service()
+ user_id = UserId(uuid4())
+ identity = ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789")
+
+ token = service.create_access_token(user_id, identity)
+ payload = service.validate_access_token(token)
+
+ assert payload["sub"] == str(user_id)
+ assert payload["provider"] == "orcid"
+ assert payload["external_id"] == "0000-0001-2345-6789"
+
+ def test_validate_access_token_rejects_invalid_token(self):
+ """validate_access_token should raise for invalid tokens."""
+ service = self.make_service()
+
+ with pytest.raises(jwt.InvalidTokenError):
+ service.validate_access_token("invalid-token")
+
+ def test_validate_access_token_rejects_wrong_secret(self):
+ """validate_access_token should reject tokens signed with wrong secret."""
+ service1 = self.make_service(secret="secret-one-that-is-long-enough!!")
+ service2 = self.make_service(secret="secret-two-that-is-long-enough!!")
+ identity = ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789")
+
+ token = service1.create_access_token(UserId(uuid4()), identity)
+
+ with pytest.raises(jwt.InvalidTokenError):
+ service2.validate_access_token(token)
+
+ def test_validate_access_token_rejects_expired_token(self):
+ """validate_access_token should reject expired tokens."""
+ config = JwtConfig(
+ secret="test-secret-key-256-bits-long-xx",
+ algorithm="HS256",
+ access_token_expire_minutes=0, # Immediate expiry
+ refresh_token_expire_days=7,
+ )
+ service = TokenService(_config=config)
+ identity = ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789")
+
+ # Create token that's already expired
+ token = service.create_access_token(UserId(uuid4()), identity)
+
+ # Small delay to ensure expiry
+ time.sleep(0.1)
+
+ with pytest.raises(jwt.ExpiredSignatureError):
+ service.validate_access_token(token)
+
+
+class TestTokenServiceRefreshToken:
+ """Tests for opaque refresh token creation."""
+
+ def make_service(self) -> TokenService:
+ """Create a TokenService with test config."""
+ config = JwtConfig(
+ secret="test-secret-key-256-bits-long-xx",
+ algorithm="HS256",
+ access_token_expire_minutes=60,
+ refresh_token_expire_days=7,
+ )
+ return TokenService(_config=config)
+
+ def test_create_refresh_token_returns_tuple(self):
+ """create_refresh_token should return (raw_token, token_hash)."""
+ service = self.make_service()
+
+ raw_token, token_hash = service.create_refresh_token()
+
+ assert isinstance(raw_token, str)
+ assert isinstance(token_hash, str)
+ assert len(raw_token) > 0
+ assert len(token_hash) == 64 # SHA256 hex = 64 chars
+
+ def test_create_refresh_token_unique_each_time(self):
+ """create_refresh_token should generate unique tokens."""
+ service = self.make_service()
+
+ raw1, hash1 = service.create_refresh_token()
+ raw2, hash2 = service.create_refresh_token()
+
+ assert raw1 != raw2
+ assert hash1 != hash2
+
+ def test_hash_token_consistent(self):
+ """hash_token should produce consistent hashes."""
+ raw_token = "test-token-value"
+
+ hash1 = TokenService.hash_token(raw_token)
+ hash2 = TokenService.hash_token(raw_token)
+
+ assert hash1 == hash2
+ assert len(hash1) == 64
+
+ def test_hash_token_different_for_different_tokens(self):
+ """hash_token should produce different hashes for different tokens."""
+ hash1 = TokenService.hash_token("token-one")
+ hash2 = TokenService.hash_token("token-two")
+
+ assert hash1 != hash2
+
+ def test_created_refresh_token_hash_matches(self):
+ """The hash from create_refresh_token should match hash_token(raw)."""
+ service = self.make_service()
+
+ raw_token, token_hash = service.create_refresh_token()
+
+ assert TokenService.hash_token(raw_token) == token_hash
+
+
+class TestTokenServiceProperties:
+ """Tests for TokenService property accessors."""
+
+ def test_access_token_expire_seconds(self):
+ """access_token_expire_seconds should convert minutes to seconds."""
+ config = JwtConfig(
+ secret="test-secret-key-256-bits-long-xx",
+ algorithm="HS256",
+ access_token_expire_minutes=30,
+ refresh_token_expire_days=7,
+ )
+ service = TokenService(_config=config)
+
+ assert service.access_token_expire_seconds == 30 * 60
+
+ def test_refresh_token_expire_days(self):
+ """refresh_token_expire_days should return configured value."""
+ config = JwtConfig(
+ secret="test-secret-key-256-bits-long-xx",
+ algorithm="HS256",
+ access_token_expire_minutes=60,
+ refresh_token_expire_days=14,
+ )
+ service = TokenService(_config=config)
+
+ assert service.refresh_token_expire_days == 14
diff --git a/server/uv.lock b/server/uv.lock
index e888c64..cdf84bd 100644
--- a/server/uv.lock
+++ b/server/uv.lock
@@ -1588,6 +1588,7 @@ dependencies = [
{ name = "psycopg2-binary", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "pydantic", extra = ["email"], marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "pydantic-settings", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
+ { name = "pyjwt", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "rich", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "sentence-transformers", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
{ name = "sqlalchemy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" },
@@ -1623,6 +1624,7 @@ requires-dist = [
{ name = "psycopg2-binary", specifier = ">=2.9.11" },
{ name = "pydantic", extras = ["email"], specifier = ">=2.12.4" },
{ name = "pydantic-settings", specifier = ">=2.12.0" },
+ { name = "pyjwt", specifier = ">=2.11.0" },
{ name = "rich", specifier = ">=14.2.0" },
{ name = "sentence-transformers", specifier = ">=5.2.0" },
{ name = "sqlalchemy", specifier = ">=2.0.44" },
@@ -1948,6 +1950,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
]
+[[package]]
+name = "pyjwt"
+version = "2.11.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/b46fa56bf322901eee5b0454a34343cdbdae202cd421775a8ee4e42fd519/pyjwt-2.11.0.tar.gz", hash = "sha256:35f95c1f0fbe5d5ba6e43f00271c275f7a1a4db1dab27bf708073b75318ea623", size = 98019, upload-time = "2026-01-30T19:59:55.694Z" }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" },
+]
+
[[package]]
name = "pypika"
version = "0.50.0"
diff --git a/web/src/app/auth/callback/page.tsx b/web/src/app/auth/callback/page.tsx
new file mode 100644
index 0000000..700ed82
--- /dev/null
+++ b/web/src/app/auth/callback/page.tsx
@@ -0,0 +1,86 @@
+'use client';
+
+import { Suspense, useEffect, useRef } from 'react';
+import { useRouter, useSearchParams } from 'next/navigation';
+import Link from 'next/link';
+import { OSAClient, parseAuthCallback } from '@/lib/sdk';
+
+function AuthCallbackContent() {
+ const router = useRouter();
+ const searchParams = useSearchParams();
+ const processedRef = useRef(false);
+
+ // Check for error in URL search params
+ const urlError = searchParams.get('error');
+ const errorDescription = searchParams.get('error_description');
+ const error = urlError ? (errorDescription || urlError) : null;
+
+ useEffect(() => {
+ // Prevent double processing in strict mode
+ if (processedRef.current || error) return;
+
+ // Check for auth data in hash
+ const hash = window.location.hash;
+ if (!hash || !hash.includes('auth=')) {
+ return;
+ }
+
+ // Parse and store auth data
+ const params = parseAuthCallback(hash);
+ if (!params) {
+ return;
+ }
+
+ processedRef.current = true;
+
+ // Store in client
+ const client = new OSAClient({ baseUrl: '/api/v1' });
+ client.handleAuthCallback(hash);
+
+ // Redirect to home (or wherever user came from)
+ router.push('/');
+ }, [router, error]);
+
+ if (error) {
+ return (
+ {error} Completing sign in... Loading...Authentication Error
+