From 16d5975ba64a42f80ff62d28104450196c5e3517 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 5 Feb 2026 02:01:21 +0000 Subject: [PATCH 01/17] feat: implement ORCiD OAuth authentication Backend: - Add auth routes (login, callback, refresh, logout) with OAuth 2.0 flow - Add domain models for User, Identity, and RefreshToken - Add AuthService and TokenService with JWT token generation - Add ORCiD identity provider with sandbox support - Add signed state tokens for CSRF protection (HMAC-SHA256) - Add database tables and Alembic migration for auth entities - Add unit tests for auth service, token service, and state signing Frontend: - Add AuthProvider context with SDK integration - Add LoginButton and UserMenu components - Add auth SDK (client, storage, types) for token management - Add auth callback page for OAuth redirect handling Infrastructure: - Simplify Justfile log commands (just logs <service>) - Fix docker-compose.dev.yml for hot-reload development - Add curl to Dockerfile builder stage for healthchecks - Expose server port 8000 in dev mode for OAuth redirects --- Justfile | 28 +- deploy/docker-compose.dev.yml | 9 +- server/.env.example | 76 ++++ server/Dockerfile | 3 + server/migrations/versions/add_auth_tables.py | 89 ++++ server/osa/application/api/rest/app.py | 3 +- server/osa/application/api/v1/routes/auth.py | 362 ++++++++++++++++ server/osa/application/di.py | 4 + server/osa/config.py | 36 ++ server/osa/domain/auth/command/__init__.py | 19 + server/osa/domain/auth/command/login.py | 93 +++++ server/osa/domain/auth/deps.py | 67 +++ server/osa/domain/auth/event/__init__.py | 5 + server/osa/domain/auth/event/events.py | 19 + server/osa/domain/auth/model/__init__.py | 17 + server/osa/domain/auth/model/identity.py | 46 ++ server/osa/domain/auth/model/token.py | 68 +++ server/osa/domain/auth/model/user.py | 40 ++ server/osa/domain/auth/model/value.py | 91 ++++ server/osa/domain/auth/port/__init__.py | 12 + .../osa/domain/auth/port/identity_provider.py | 64 +++ 
server/osa/domain/auth/port/repository.py | 79 ++++ server/osa/domain/auth/service/__init__.py | 6 + server/osa/domain/auth/service/auth.py | 251 +++++++++++ server/osa/domain/auth/service/token.py | 112 +++++ .../domain/auth/{adapter => util}/__init__.py | 0 server/osa/domain/auth/util/di/__init__.py | 5 + server/osa/domain/auth/util/di/provider.py | 42 ++ server/osa/infrastructure/auth/__init__.py | 5 + server/osa/infrastructure/auth/di.py | 61 +++ server/osa/infrastructure/auth/orcid.py | 101 +++++ .../persistence/repository/auth.py | 218 ++++++++++ .../osa/infrastructure/persistence/tables.py | 53 +++ server/pyproject.toml | 1 + .../api/v1/routes/test_auth_state.py | 138 ++++++ server/tests/unit/domain/auth/__init__.py | 1 + .../unit/domain/auth/test_auth_service.py | 393 ++++++++++++++++++ .../unit/domain/auth/test_refresh_token.py | 227 ++++++++++ .../unit/domain/auth/test_token_service.py | 243 +++++++++++ server/uv.lock | 11 + web/src/app/auth/callback/page.tsx | 72 ++++ web/src/app/layout.tsx | 12 +- web/src/components/auth/AuthProvider.tsx | 88 ++++ .../components/auth/LoginButton.module.css | 32 ++ web/src/components/auth/LoginButton.tsx | 50 +++ web/src/components/auth/UserMenu.module.css | 120 ++++++ web/src/components/auth/UserMenu.tsx | 101 +++++ web/src/components/layout/AuthButtons.tsx | 19 + web/src/components/layout/Header.tsx | 2 + web/src/hooks/useAuth.ts | 35 ++ web/src/lib/sdk/auth.ts | 213 ++++++++++ web/src/lib/sdk/client.ts | 116 ++++++ web/src/lib/sdk/index.ts | 43 ++ web/src/lib/sdk/storage.ts | 77 ++++ web/src/lib/sdk/types.ts | 64 +++ 55 files changed, 4110 insertions(+), 32 deletions(-) create mode 100644 server/.env.example create mode 100644 server/migrations/versions/add_auth_tables.py create mode 100644 server/osa/application/api/v1/routes/auth.py create mode 100644 server/osa/domain/auth/command/login.py create mode 100644 server/osa/domain/auth/deps.py create mode 100644 server/osa/domain/auth/event/events.py create mode 100644 
server/osa/domain/auth/model/identity.py create mode 100644 server/osa/domain/auth/model/token.py create mode 100644 server/osa/domain/auth/model/user.py create mode 100644 server/osa/domain/auth/model/value.py create mode 100644 server/osa/domain/auth/port/identity_provider.py create mode 100644 server/osa/domain/auth/port/repository.py create mode 100644 server/osa/domain/auth/service/auth.py create mode 100644 server/osa/domain/auth/service/token.py rename server/osa/domain/auth/{adapter => util}/__init__.py (100%) create mode 100644 server/osa/domain/auth/util/di/__init__.py create mode 100644 server/osa/domain/auth/util/di/provider.py create mode 100644 server/osa/infrastructure/auth/__init__.py create mode 100644 server/osa/infrastructure/auth/di.py create mode 100644 server/osa/infrastructure/auth/orcid.py create mode 100644 server/osa/infrastructure/persistence/repository/auth.py create mode 100644 server/tests/unit/application/api/v1/routes/test_auth_state.py create mode 100644 server/tests/unit/domain/auth/__init__.py create mode 100644 server/tests/unit/domain/auth/test_auth_service.py create mode 100644 server/tests/unit/domain/auth/test_refresh_token.py create mode 100644 server/tests/unit/domain/auth/test_token_service.py create mode 100644 web/src/app/auth/callback/page.tsx create mode 100644 web/src/components/auth/AuthProvider.tsx create mode 100644 web/src/components/auth/LoginButton.module.css create mode 100644 web/src/components/auth/LoginButton.tsx create mode 100644 web/src/components/auth/UserMenu.module.css create mode 100644 web/src/components/auth/UserMenu.tsx create mode 100644 web/src/components/layout/AuthButtons.tsx create mode 100644 web/src/hooks/useAuth.ts create mode 100644 web/src/lib/sdk/auth.ts create mode 100644 web/src/lib/sdk/client.ts create mode 100644 web/src/lib/sdk/index.ts create mode 100644 web/src/lib/sdk/storage.ts create mode 100644 web/src/lib/sdk/types.ts diff --git a/Justfile b/Justfile index 0add92f..d750323 
100644 --- a/Justfile +++ b/Justfile @@ -20,30 +20,10 @@ up-attached: down: docker compose -f deploy/docker-compose.yml down -# View logs from all services -logs: - docker compose -f deploy/docker-compose.yml logs -f - -# View logs from a specific service -logs-service service: +# View logs for a service (e.g., just logs server, just logs db) +logs service: docker compose -f deploy/docker-compose.yml logs -f {{service}} -# View server logs -server-logs: - docker compose -f deploy/docker-compose.yml logs -f server - -# View last N lines of server logs (default: 100) -server-logs-tail lines="100": - docker compose -f deploy/docker-compose.yml logs --tail {{lines}} server - -# View server logs with timestamps -server-logs-time: - docker compose -f deploy/docker-compose.yml logs -f -t server - -# View server logs since a time (e.g., "10m", "1h", "2024-01-01") -server-logs-since since: - docker compose -f deploy/docker-compose.yml logs -f --since {{since}} server - # Shell into the server container server-shell: docker compose -f deploy/docker-compose.yml exec server bash @@ -106,10 +86,6 @@ db-up: db-down: docker compose -f deploy/docker-compose.yml stop db -# View database logs -db-logs: - docker compose -f deploy/docker-compose.yml logs -f db - # Connect to PostgreSQL db-connect: docker compose -f deploy/docker-compose.yml exec db psql -U postgres -d osa diff --git a/deploy/docker-compose.dev.yml b/deploy/docker-compose.dev.yml index 32fa427..94a9c2a 100644 --- a/deploy/docker-compose.dev.yml +++ b/deploy/docker-compose.dev.yml @@ -7,15 +7,19 @@ services: context: ../server dockerfile: Dockerfile target: builder + ports: + - "8000:8000" volumes: - ../server:/app - /app/.venv environment: OSA_DATABASE__URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-osa}@db:5432/${POSTGRES_DB:-osa} OSA_DATA_DIR: /data + OSA_CONFIG_FILE: /app/osa.yaml OSA_LOGGING__LEVEL: ${LOG_LEVEL:-DEBUG} WATCHFILES_FORCE_POLLING: "true" - command: uvicorn 
osa.application.api.rest.app:app --host 0.0.0.0 --port 8000 --reload + entrypoint: [] + command: ["sh", "-c", "/app/.venv/bin/alembic upgrade head && /app/.venv/bin/uvicorn osa.application.api.rest.app:app --host 0.0.0.0 --port 8000 --reload"] healthcheck: test: ["CMD", "curl", "--fail", "http://localhost:8000/api/v1/health"] interval: 10s @@ -36,4 +40,5 @@ services: API_URL: http://server:8000 ports: - "3000:3000" - command: pnpm dev + entrypoint: [] + command: ["pnpm", "dev"] diff --git a/server/.env.example b/server/.env.example new file mode 100644 index 0000000..3dd2841 --- /dev/null +++ b/server/.env.example @@ -0,0 +1,76 @@ +# OSA Server Configuration - Secrets and Environment-Specific Values +# Copy this file to .env and fill in your values +# +# Relationship with osa.yaml: +# - .env: SECRETS (credentials, JWT secrets) - never commit to git +# - osa.yaml: APPLICATION CONFIG (sources, indexes) - can be version-controlled +# +# Priority (highest to lowest): +# 1. Environment variables +# 2. .env file +# 3. 
OSA_CONFIG_FILE (osa.yaml) +# +# Auth secrets MUST go in .env (not osa.yaml) for security + +# ============================================================================= +# ORCiD OAuth Configuration +# ============================================================================= +# Register at: https://sandbox.orcid.org/developer-tools (sandbox) +# Register at: https://orcid.org/developer-tools (production) + +# ORCiD OAuth client ID (e.g., APP-XXXXXXXXXXXX) +OSA_AUTH__ORCID__CLIENT_ID= + +# ORCiD OAuth client secret +OSA_AUTH__ORCID__CLIENT_SECRET= + +# Use ORCiD sandbox for development (default: true) +# Set to false for production +OSA_AUTH__ORCID__SANDBOX=true + +# ============================================================================= +# JWT Configuration +# ============================================================================= + +# JWT signing secret (generate with: openssl rand -hex 32) +# IMPORTANT: Use a strong, unique secret in production! +OSA_AUTH__JWT__SECRET= + +# Access token expiry in minutes (default: 60) +# OSA_AUTH__JWT__ACCESS_TOKEN_EXPIRE_MINUTES=60 + +# Refresh token expiry in days (default: 7) +# OSA_AUTH__JWT__REFRESH_TOKEN_EXPIRE_DAYS=7 + +# ============================================================================= +# OAuth Callback URL +# ============================================================================= +# Must match the redirect URI registered in ORCiD developer tools + +# For development (default derives from request URL) +# OSA_AUTH__CALLBACK_URL=http://localhost:8000/api/v1/auth/callback + +# For production +# OSA_AUTH__CALLBACK_URL=https://your-domain.com/api/v1/auth/callback + +# ============================================================================= +# Frontend URL +# ============================================================================= +# URL of the frontend application for OAuth redirects + +# For development +OSA_FRONTEND__URL=http://localhost:3000 + +# For production +# 
OSA_FRONTEND__URL=https://your-domain.com + +# ============================================================================= +# Database Configuration +# ============================================================================= +# Default: SQLite in ~/.local/share/osa/osa.db + +# PostgreSQL example: +# OSA_DATABASE__URL=postgresql+asyncpg://user:password@localhost:5432/osa + +# Echo SQL queries (for debugging) +# OSA_DATABASE__ECHO=false diff --git a/server/Dockerfile b/server/Dockerfile index 6b0d381..ddd05b9 100644 --- a/server/Dockerfile +++ b/server/Dockerfile @@ -3,6 +3,9 @@ # Stage 1: Builder FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder +# Install curl for healthchecks in dev mode +RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/* + WORKDIR /app # Enable bytecode compilation for faster startup diff --git a/server/migrations/versions/add_auth_tables.py b/server/migrations/versions/add_auth_tables.py new file mode 100644 index 0000000..b04ac0b --- /dev/null +++ b/server/migrations/versions/add_auth_tables.py @@ -0,0 +1,89 @@ +"""add_auth_tables + +Add users, identities, and refresh_tokens tables for authentication. + +Revision ID: add_auth_tables +Revises: add_worker_columns +Create Date: 2026-02-04 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "add_auth_tables" +down_revision: Union[str, Sequence[str], None] = "add_worker_columns" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add authentication tables.""" + # USERS TABLE + op.create_table( + "users", + sa.Column("id", sa.String(), nullable=False), + sa.Column("display_name", sa.String(255), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + + # IDENTITIES TABLE + op.create_table( + "identities", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("provider", sa.String(50), nullable=False), + sa.Column("external_id", sa.String(255), nullable=False), + sa.Column("metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ondelete="CASCADE", + ), + sa.UniqueConstraint("provider", "external_id", name="uq_identity_provider_external"), + ) + op.create_index("ix_identities_user_id", "identities", ["user_id"]) + + # REFRESH TOKENS TABLE + op.create_table( + "refresh_tokens", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("token_hash", sa.String(64), nullable=False), + sa.Column("family_id", sa.String(), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("revoked_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ondelete="CASCADE", + ), + ) + op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"]) + 
op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"]) + op.create_index("ix_refresh_tokens_family_id", "refresh_tokens", ["family_id"]) + + +def downgrade() -> None: + """Remove authentication tables.""" + # REFRESH TOKENS + op.drop_index("ix_refresh_tokens_family_id", table_name="refresh_tokens") + op.drop_index("ix_refresh_tokens_token_hash", table_name="refresh_tokens") + op.drop_index("ix_refresh_tokens_user_id", table_name="refresh_tokens") + op.drop_table("refresh_tokens") + + # IDENTITIES + op.drop_index("ix_identities_user_id", table_name="identities") + op.drop_table("identities") + + # USERS + op.drop_table("users") diff --git a/server/osa/application/api/rest/app.py b/server/osa/application/api/rest/app.py index f997a6f..754c7f5 100644 --- a/server/osa/application/api/rest/app.py +++ b/server/osa/application/api/rest/app.py @@ -6,7 +6,7 @@ from fastapi.responses import JSONResponse from osa.application.api.v1.errors import map_osa_error -from osa.application.api.v1.routes import events, health, records, search, stats, validation +from osa.application.api.v1.routes import auth, events, health, records, search, stats, validation from osa.application.di import create_container from osa.config import Config, configure_logging from osa.domain.shared.error import OSAError @@ -58,6 +58,7 @@ def create_app() -> FastAPI: # Register v1 routes with /api/v1 prefix app_instance.include_router(health.router, prefix="/api/v1") + app_instance.include_router(auth.router, prefix="/api/v1") app_instance.include_router(events.router, prefix="/api/v1") app_instance.include_router(records.router, prefix="/api/v1") app_instance.include_router(search.router, prefix="/api/v1") diff --git a/server/osa/application/api/v1/routes/auth.py b/server/osa/application/api/v1/routes/auth.py new file mode 100644 index 0000000..e637cf9 --- /dev/null +++ b/server/osa/application/api/v1/routes/auth.py @@ -0,0 +1,362 @@ +"""Authentication routes for OAuth login 
flow.""" + +import hashlib +import hmac +import json +import logging +import secrets +import time +from base64 import urlsafe_b64decode, urlsafe_b64encode +from typing import Annotated +from urllib.parse import urlencode + +from dishka import FromDishka +from dishka.integrations.fastapi import DishkaRoute +from fastapi import APIRouter, HTTPException, Query, Request, Response +from fastapi.responses import RedirectResponse +from pydantic import BaseModel + +from osa.config import Config +from osa.domain.auth.port.identity_provider import IdentityProvider +from osa.domain.auth.service.auth import AuthService +from osa.domain.auth.service.token import TokenService +from osa.domain.shared.error import InvalidStateError + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/auth", tags=["Authentication"], route_class=DishkaRoute) + +# OAuth state validity period (5 minutes) +_STATE_EXPIRY_SECONDS = 300 + + +def _create_signed_state(secret: str, redirect_uri: str) -> str: + """Create a signed, self-verifying OAuth state token. + + The state contains: nonce, redirect_uri, expiry timestamp. + Signed with HMAC-SHA256 using the JWT secret. + """ + payload = { + "nonce": secrets.token_urlsafe(16), + "redirect_uri": redirect_uri, + "exp": int(time.time()) + _STATE_EXPIRY_SECONDS, + } + payload_bytes = json.dumps(payload, separators=(",", ":")).encode() + payload_b64 = urlsafe_b64encode(payload_bytes).rstrip(b"=").decode() + + signature = hmac.new(secret.encode(), payload_bytes, hashlib.sha256).digest() + signature_b64 = urlsafe_b64encode(signature).rstrip(b"=").decode() + + return f"{payload_b64}.{signature_b64}" + + +def _verify_signed_state(secret: str, state: str) -> str | None: + """Verify a signed state token and return the redirect_uri if valid. + + Returns None if the state is invalid or expired. 
+ """ + try: + parts = state.split(".") + if len(parts) != 2: + return None + + payload_b64, signature_b64 = parts + + # Restore base64 padding + payload_bytes = urlsafe_b64decode(payload_b64 + "==") + signature = urlsafe_b64decode(signature_b64 + "==") + + # Verify signature + expected_sig = hmac.new(secret.encode(), payload_bytes, hashlib.sha256).digest() + if not hmac.compare_digest(signature, expected_sig): + logger.warning("OAuth state signature verification failed") + return None + + # Parse and check expiry + payload = json.loads(payload_bytes) + if payload.get("exp", 0) < time.time(): + logger.warning("OAuth state expired") + return None + + return payload.get("redirect_uri") + + except Exception as e: + logger.warning("OAuth state verification error: %s", e) + return None + + +class RefreshTokenRequest(BaseModel): + """Request body for token refresh.""" + + refresh_token: str + + +class LogoutRequest(BaseModel): + """Request body for logout.""" + + refresh_token: str + + +class TokenResponse(BaseModel): + """Response containing tokens.""" + + access_token: str + refresh_token: str + token_type: str = "Bearer" + expires_in: int + + +class LogoutResponse(BaseModel): + """Response for logout.""" + + success: bool + + +class UserResponse(BaseModel): + """Response containing user info.""" + + id: str + display_name: str | None + orcid_id: str + + +@router.get("/login") +async def initiate_login( + request: Request, + config: FromDishka[Config], + identity_provider: FromDishka[IdentityProvider], + redirect_uri: Annotated[str | None, Query()] = None, + provider: Annotated[str, Query()] = "orcid", +) -> Response: + """Initiate OAuth login flow. + + Redirects to identity provider's authorization page. 
+ """ + if provider != "orcid": + raise HTTPException( + status_code=400, + detail={ + "code": "invalid_provider", + "message": f"Unsupported provider: {provider}", + }, + ) + + # Determine callback URL + callback_url = config.auth.callback_url + if not callback_url: + # Derive from request URL + callback_url = str(request.url_for("handle_oauth_callback")) + + # Create signed state token (includes redirect_uri, expiry, and nonce) + final_redirect = redirect_uri or config.frontend.url + state = _create_signed_state(config.auth.jwt.secret, final_redirect) + + # Generate authorization URL + authorization_url = identity_provider.get_authorization_url( + state=state, + redirect_uri=callback_url, + ) + + logger.info("OAuth login initiated, redirecting to IdP") + return RedirectResponse(url=authorization_url, status_code=302) + + +@router.get("/callback") +async def handle_oauth_callback( + request: Request, + config: FromDishka[Config], + auth_service: FromDishka[AuthService], + identity_provider: FromDishka[IdentityProvider], + token_service: FromDishka[TokenService], + code: Annotated[str | None, Query()] = None, + state: Annotated[str | None, Query()] = None, + error: Annotated[str | None, Query()] = None, + error_description: Annotated[str | None, Query()] = None, +) -> Response: + """Handle OAuth callback from identity provider. + + Exchanges authorization code for tokens and redirects to frontend. 
+ """ + frontend_url = config.frontend.url + + # Check for OAuth errors + if error: + logger.warning("OAuth error: %s - %s", error, error_description) + error_params = urlencode( + { + "error": error, + "error_description": error_description or "Authentication failed", + } + ) + return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") + + # Validate signed state token + if not state: + logger.warning("OAuth state missing") + error_params = urlencode( + { + "error": "oauth_state_missing", + "error_description": "Missing state parameter", + } + ) + return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") + + final_redirect = _verify_signed_state(config.auth.jwt.secret, state) + if final_redirect is None: + logger.warning("OAuth state invalid or expired") + error_params = urlencode( + { + "error": "oauth_state_invalid", + "error_description": "Invalid or expired state parameter", + } + ) + return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") + + if not code: + logger.warning("OAuth callback missing code") + error_params = urlencode( + { + "error": "missing_code", + "error_description": "Authorization code not provided", + } + ) + return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") + + try: + # Determine callback URL (must match what was used in authorization) + callback_url = config.auth.callback_url + if not callback_url: + callback_url = str(request.url_for("handle_oauth_callback")) + + # Complete OAuth flow + user, identity, access_token, refresh_token = await auth_service.complete_oauth( + provider=identity_provider, + code=code, + redirect_uri=callback_url, + ) + + # Build redirect URL with tokens in fragment + token_params = urlencode( + { + "access_token": access_token, + "refresh_token": refresh_token, + "token_type": "Bearer", + "expires_in": token_service.access_token_expire_seconds, + "user_id": str(user.id), + "display_name": user.display_name or "", + "orcid_id": 
identity.external_id, + } + ) + + redirect_url = f"{final_redirect}#auth={token_params}" + logger.info("OAuth complete, user authenticated: user_id=%s", user.id) + return RedirectResponse(url=redirect_url, status_code=302) + + except Exception as e: + logger.exception("OAuth callback failed: %s", e) + error_params = urlencode( + { + "error": "oauth_error", + "error_description": str(e), + } + ) + return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") + + +@router.post("/refresh", response_model=TokenResponse) +async def refresh_token( + body: RefreshTokenRequest, + auth_service: FromDishka[AuthService], + token_service: FromDishka[TokenService], +) -> TokenResponse: + """Refresh access token using refresh token.""" + try: + _user, access_token, new_refresh_token = await auth_service.refresh_tokens( + body.refresh_token + ) + return TokenResponse( + access_token=access_token, + refresh_token=new_refresh_token, + expires_in=token_service.access_token_expire_seconds, + ) + except InvalidStateError as e: + raise HTTPException( + status_code=401, + detail={ + "code": e.code, + "message": e.message, + }, + ) from e + + +@router.post("/logout", response_model=LogoutResponse) +async def logout( + body: LogoutRequest, + auth_service: FromDishka[AuthService], +) -> LogoutResponse: + """Logout and revoke refresh token.""" + success = await auth_service.logout(body.refresh_token) + return LogoutResponse(success=success) + + +@router.get("/me", response_model=UserResponse) +async def get_me( + request: Request, + config: FromDishka[Config], + auth_service: FromDishka[AuthService], +) -> UserResponse: + """Get current authenticated user information.""" + # Extract token from Authorization header + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + raise HTTPException( + status_code=401, + detail={"code": "missing_token", "message": "Authorization header required"}, + headers={"WWW-Authenticate": 
"Bearer"}, + ) + + token = auth_header[7:] # Remove "Bearer " prefix + + try: + import jwt + + payload = jwt.decode( + token, + config.auth.jwt.secret, + algorithms=[config.auth.jwt.algorithm], + audience="authenticated", + ) + user_id = payload["sub"] + orcid_id = payload["orcid_id"] + + # Get full user info from auth service + from osa.domain.auth.model.value import UserId + from uuid import UUID + + user = await auth_service.get_user_by_id(UserId(UUID(user_id))) + + if user is None: + raise HTTPException( + status_code=401, + detail={"code": "user_not_found", "message": "User not found"}, + ) + + return UserResponse( + id=str(user.id), + display_name=user.display_name, + orcid_id=orcid_id, + ) + + except jwt.ExpiredSignatureError: + raise HTTPException( + status_code=401, + detail={"code": "token_expired", "message": "Token has expired"}, + headers={"WWW-Authenticate": "Bearer"}, + ) + + except jwt.InvalidTokenError: + raise HTTPException( + status_code=401, + detail={"code": "invalid_token", "message": "Invalid token"}, + headers={"WWW-Authenticate": "Bearer"}, + ) diff --git a/server/osa/application/di.py b/server/osa/application/di.py index 9f70b33..54cb004 100644 --- a/server/osa/application/di.py +++ b/server/osa/application/di.py @@ -2,8 +2,10 @@ from osa.cli.util.paths import OSAPaths from osa.config import Config +from osa.domain.auth.util.di import AuthProvider from osa.domain.deposition.util.di import DepositionProvider from osa.domain.validation.util.di import ValidationProvider +from osa.infrastructure.auth import AuthInfraProvider from osa.infrastructure.event.di import EventProvider from osa.infrastructure.index.di import IndexProvider from osa.infrastructure.source.di import SourceProvider @@ -26,6 +28,8 @@ def create_container() -> AsyncContainer: EventProvider(), DepositionProvider(), ValidationProvider(), + AuthProvider(), + AuthInfraProvider(), context={Config: config, OSAPaths: paths}, scopes=Scope, # type: ignore[arg-type] # Custom scope 
class ) diff --git a/server/osa/config.py b/server/osa/config.py index feebefd..aca95c3 100644 --- a/server/osa/config.py +++ b/server/osa/config.py @@ -150,6 +150,41 @@ class WorkerConfig(BaseModel): batch_size: int = 100 # Maximum events to fetch per poll cycle +# ============================================================================= +# Authentication Configuration +# ============================================================================= + + +class OrcidConfig(BaseModel): + """ORCiD OAuth configuration.""" + + client_id: str = "" + client_secret: str = "" + sandbox: bool = True # Use sandbox.orcid.org by default + + @property + def base_url(self) -> str: + """Get base URL for ORCiD API based on sandbox setting.""" + return "https://sandbox.orcid.org" if self.sandbox else "https://orcid.org" + + +class JwtConfig(BaseModel): + """JWT configuration.""" + + secret: str = "" # Must be set in production + algorithm: str = "HS256" + access_token_expire_minutes: int = 60 # 1 hour + refresh_token_expire_days: int = 7 + + +class AuthConfig(BaseModel): + """Authentication configuration.""" + + orcid: OrcidConfig = OrcidConfig() + jwt: JwtConfig = JwtConfig() + callback_url: str = "" # Full callback URL (e.g., https://myarchive.org/api/v1/auth/callback) + + class Config(BaseSettings): # These are BaseModel, so env_nested_delimiter handles their env vars server: Server = Server() @@ -157,6 +192,7 @@ class Config(BaseSettings): database: DatabaseConfig = DatabaseConfig() logging: LoggingConfig = LoggingConfig() worker: WorkerConfig = WorkerConfig() # Background worker settings + auth: AuthConfig = AuthConfig() # Authentication settings indexes: list[IndexConfig] = [] # list of index configs sources: list[SourceConfig] = [] # list of source configs diff --git a/server/osa/domain/auth/command/__init__.py b/server/osa/domain/auth/command/__init__.py index e69de29..ef5bb8a 100644 --- a/server/osa/domain/auth/command/__init__.py +++ 
b/server/osa/domain/auth/command/__init__.py @@ -0,0 +1,19 @@ +"""Auth domain commands.""" + +from .login import ( + CompleteOAuth, + CompleteOAuthHandler, + CompleteOAuthResult, + InitiateLogin, + InitiateLoginHandler, + InitiateLoginResult, +) + +__all__ = [ + "CompleteOAuth", + "CompleteOAuthHandler", + "CompleteOAuthResult", + "InitiateLogin", + "InitiateLoginHandler", + "InitiateLoginResult", +] diff --git a/server/osa/domain/auth/command/login.py b/server/osa/domain/auth/command/login.py new file mode 100644 index 0000000..89e6c1d --- /dev/null +++ b/server/osa/domain/auth/command/login.py @@ -0,0 +1,93 @@ +"""Login commands for OAuth authentication flow.""" + +import secrets +from dataclasses import dataclass + +from osa.domain.auth.port.identity_provider import IdentityProvider +from osa.domain.auth.service.auth import AuthService +from osa.domain.shared.command import Command, CommandHandler, Result + + +class InitiateLogin(Command): + """Command to start OAuth login flow.""" + + redirect_uri: str # Where to redirect after login + provider: str = "orcid" + + +class InitiateLoginResult(Result): + """Result containing authorization URL and state.""" + + authorization_url: str + state: str # CSRF token - caller should store this for validation + + +@dataclass +class InitiateLoginHandler(CommandHandler[InitiateLogin, InitiateLoginResult]): + """Handler for InitiateLogin command.""" + + identity_provider: IdentityProvider + + async def run(self, cmd: InitiateLogin) -> InitiateLoginResult: + """Generate authorization URL for OAuth login.""" + # Generate CSRF state token + state = secrets.token_urlsafe(32) + + # Get authorization URL from identity provider + authorization_url = self.identity_provider.get_authorization_url( + state=state, + redirect_uri=cmd.redirect_uri, + ) + + return InitiateLoginResult( + authorization_url=authorization_url, + state=state, + ) + + +class CompleteOAuth(Command): + """Command to complete OAuth flow with authorization code.""" + + 
code: str + state: str + redirect_uri: str + + +class CompleteOAuthResult(Result): + """Result containing user info and tokens.""" + + user_id: str + display_name: str | None + orcid_id: str + access_token: str + refresh_token: str + expires_in: int # Seconds until access token expires + + +@dataclass +class CompleteOAuthHandler(CommandHandler[CompleteOAuth, CompleteOAuthResult]): + """Handler for CompleteOAuth command.""" + + auth_service: AuthService + identity_provider: IdentityProvider + token_service_expire_seconds: int + + async def run(self, cmd: CompleteOAuth) -> CompleteOAuthResult: + """Exchange authorization code for tokens and create/update user.""" + # Note: State validation should be done by the caller (route handler) + # before invoking this command + + user, identity, access_token, refresh_token = await self.auth_service.complete_oauth( + provider=self.identity_provider, + code=cmd.code, + redirect_uri=cmd.redirect_uri, + ) + + return CompleteOAuthResult( + user_id=str(user.id), + display_name=user.display_name, + orcid_id=identity.external_id, + access_token=access_token, + refresh_token=refresh_token, + expires_in=self.token_service_expire_seconds, + ) diff --git a/server/osa/domain/auth/deps.py b/server/osa/domain/auth/deps.py new file mode 100644 index 0000000..c5fa52f --- /dev/null +++ b/server/osa/domain/auth/deps.py @@ -0,0 +1,67 @@ +"""FastAPI dependencies for authentication.""" + +from typing import Annotated + +import jwt +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from osa.config import Config +from osa.domain.auth.model.value import UserId + +# HTTP Bearer token security scheme +security = HTTPBearer(auto_error=False) + + +class CurrentUser: + """Authenticated user from JWT token.""" + + def __init__(self, user_id: UserId, orcid_id: str): + self.user_id = user_id + self.orcid_id = orcid_id + + +async def get_current_user( + credentials: 
Annotated[HTTPAuthorizationCredentials | None, Depends(security)], + config: Config, +) -> CurrentUser: + """Extract and validate current user from JWT token. + + Usage in routes: + @router.get("/protected") + async def protected_endpoint( + current_user: Annotated[CurrentUser, Depends(get_current_user)], + ): + ... + """ + if credentials is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail={"code": "missing_token", "message": "Authorization header required"}, + headers={"WWW-Authenticate": "Bearer"}, + ) + + try: + payload = jwt.decode( + credentials.credentials, + config.auth.jwt.secret, + algorithms=[config.auth.jwt.algorithm], + audience="authenticated", + ) + user_id = UserId.model_validate(payload["sub"]) + orcid_id = payload["orcid_id"] + return CurrentUser(user_id=user_id, orcid_id=orcid_id) + + except jwt.ExpiredSignatureError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail={"code": "token_expired", "message": "Token has expired"}, + headers={"WWW-Authenticate": "Bearer"}, + ) from e + + except jwt.InvalidTokenError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail={"code": "invalid_token", "message": "Invalid token"}, + headers={"WWW-Authenticate": "Bearer"}, + ) from e diff --git a/server/osa/domain/auth/event/__init__.py b/server/osa/domain/auth/event/__init__.py index e69de29..b5e31c1 100644 --- a/server/osa/domain/auth/event/__init__.py +++ b/server/osa/domain/auth/event/__init__.py @@ -0,0 +1,5 @@ +"""Auth domain events.""" + +from .events import UserAuthenticated, UserLoggedOut + +__all__ = ["UserAuthenticated", "UserLoggedOut"] diff --git a/server/osa/domain/auth/event/events.py b/server/osa/domain/auth/event/events.py new file mode 100644 index 0000000..a5f1a2b --- /dev/null +++ b/server/osa/domain/auth/event/events.py @@ -0,0 +1,19 @@ +"""Domain events for the auth domain.""" + +from osa.domain.shared.event import Event, EventId + + +class 
UserAuthenticated(Event): + """Emitted when a user successfully authenticates.""" + + id: EventId + user_id: str + provider: str + orcid_id: str + + +class UserLoggedOut(Event): + """Emitted when a user logs out.""" + + id: EventId + user_id: str diff --git a/server/osa/domain/auth/model/__init__.py b/server/osa/domain/auth/model/__init__.py index e69de29..e59c5bf 100644 --- a/server/osa/domain/auth/model/__init__.py +++ b/server/osa/domain/auth/model/__init__.py @@ -0,0 +1,17 @@ +"""Auth domain models.""" + +from .identity import Identity +from .token import RefreshToken +from .user import User +from .value import IdentityId, OrcidId, RefreshTokenId, TokenFamilyId, UserId + +__all__ = [ + "Identity", + "IdentityId", + "OrcidId", + "RefreshToken", + "RefreshTokenId", + "TokenFamilyId", + "User", + "UserId", +] diff --git a/server/osa/domain/auth/model/identity.py b/server/osa/domain/auth/model/identity.py new file mode 100644 index 0000000..f86090b --- /dev/null +++ b/server/osa/domain/auth/model/identity.py @@ -0,0 +1,46 @@ +"""Identity entity for the auth domain.""" + +from datetime import UTC, datetime +from typing import Any + +from osa.domain.auth.model.value import IdentityId, UserId +from osa.domain.shared.model.entity import Entity + + +class Identity(Entity): + """A link between a User and an external identity provider. 
+ + Examples: + - ORCiD: provider="orcid", external_id="0000-0001-2345-6789" + - SAML: provider="saml:university.edu", external_id="jdoe@university.edu" + + Invariants: + - `(provider, external_id)` is globally unique + - `user_id` is immutable after creation + - `provider` and `external_id` are immutable after creation + """ + + id: IdentityId + user_id: UserId + provider: str + external_id: str + metadata: dict[str, Any] | None = None # Provider-specific data (name, email) + created_at: datetime + + @classmethod + def create( + cls, + user_id: UserId, + provider: str, + external_id: str, + metadata: dict[str, Any] | None = None, + ) -> "Identity": + """Create a new identity link.""" + return cls( + id=IdentityId.generate(), + user_id=user_id, + provider=provider, + external_id=external_id, + metadata=metadata, + created_at=datetime.now(UTC), + ) diff --git a/server/osa/domain/auth/model/token.py b/server/osa/domain/auth/model/token.py new file mode 100644 index 0000000..409214c --- /dev/null +++ b/server/osa/domain/auth/model/token.py @@ -0,0 +1,68 @@ +"""RefreshToken entity for the auth domain.""" + +from datetime import UTC, datetime, timedelta + +from osa.domain.auth.model.value import RefreshTokenId, TokenFamilyId, UserId +from osa.domain.shared.model.entity import Entity + + +class RefreshToken(Entity): + """An opaque refresh token for session management. + + Tokens belong to a "family" for theft detection. When a token is refreshed, + the new token inherits the family_id. If a revoked token is reused, the + entire family is revoked (indicating potential theft). 
+ + Invariants: + - `token_hash` is a SHA256 hash (64 hex characters) + - `expires_at` is always in the future at creation time + - Once `revoked_at` is set, it cannot be unset + """ + + id: RefreshTokenId + user_id: UserId + token_hash: str # SHA256 hash of the actual token value + family_id: TokenFamilyId + expires_at: datetime + created_at: datetime + revoked_at: datetime | None = None + + @property + def is_valid(self) -> bool: + """Token is valid if not revoked and not expired.""" + return self.revoked_at is None and self.expires_at > datetime.now(UTC) + + @property + def is_revoked(self) -> bool: + """Check if the token has been revoked.""" + return self.revoked_at is not None + + @property + def is_expired(self) -> bool: + """Check if the token has expired.""" + return self.expires_at <= datetime.now(UTC) + + def revoke(self) -> None: + """Mark this token as revoked.""" + if self.revoked_at is None: + self.revoked_at = datetime.now(UTC) + + @classmethod + def create( + cls, + user_id: UserId, + token_hash: str, + family_id: TokenFamilyId, + expires_in_days: int = 7, + ) -> "RefreshToken": + """Create a new refresh token.""" + now = datetime.now(UTC) + return cls( + id=RefreshTokenId.generate(), + user_id=user_id, + token_hash=token_hash, + family_id=family_id, + expires_at=now + timedelta(days=expires_in_days), + created_at=now, + revoked_at=None, + ) diff --git a/server/osa/domain/auth/model/user.py b/server/osa/domain/auth/model/user.py new file mode 100644 index 0000000..95c8ca0 --- /dev/null +++ b/server/osa/domain/auth/model/user.py @@ -0,0 +1,40 @@ +"""User aggregate for the auth domain.""" + +from datetime import UTC, datetime + +from osa.domain.auth.model.value import UserId +from osa.domain.shared.model.aggregate import Aggregate + + +class User(Aggregate): + """An authenticated user in the OSA system. + + Users are created on first authentication via any identity provider. 
+ A user may have multiple linked identities (e.g., ORCiD + institutional SAML). + + Invariants: + - `id` is immutable after creation + - `created_at` is immutable after creation + - `updated_at` is set on any modification + """ + + id: UserId + display_name: str | None + created_at: datetime + updated_at: datetime | None = None + + @classmethod + def create(cls, display_name: str | None = None) -> "User": + """Create a new user.""" + now = datetime.now(UTC) + return cls( + id=UserId.generate(), + display_name=display_name, + created_at=now, + updated_at=None, + ) + + def update_display_name(self, display_name: str | None) -> None: + """Update the user's display name.""" + self.display_name = display_name + self.updated_at = datetime.now(UTC) diff --git a/server/osa/domain/auth/model/value.py b/server/osa/domain/auth/model/value.py new file mode 100644 index 0000000..df1252b --- /dev/null +++ b/server/osa/domain/auth/model/value.py @@ -0,0 +1,91 @@ +"""Value objects for the auth domain.""" + +import re +from uuid import UUID, uuid4 + +from pydantic import RootModel, field_validator + + +class UserId(RootModel[UUID]): + """Unique identifier for a User.""" + + @classmethod + def generate(cls) -> "UserId": + return cls(uuid4()) + + def __str__(self) -> str: + return str(self.root) + + def __hash__(self) -> int: + return hash(self.root) + + +class IdentityId(RootModel[UUID]): + """Unique identifier for an Identity.""" + + @classmethod + def generate(cls) -> "IdentityId": + return cls(uuid4()) + + def __str__(self) -> str: + return str(self.root) + + def __hash__(self) -> int: + return hash(self.root) + + +class RefreshTokenId(RootModel[UUID]): + """Unique identifier for a RefreshToken.""" + + @classmethod + def generate(cls) -> "RefreshTokenId": + return cls(uuid4()) + + def __str__(self) -> str: + return str(self.root) + + def __hash__(self) -> int: + return hash(self.root) + + +class TokenFamilyId(RootModel[UUID]): + """Identifier for a token family. 
+ + All refresh tokens from a single login session share a family_id. + Used for theft detection: if a revoked token is reused, the entire + family is invalidated. + """ + + @classmethod + def generate(cls) -> "TokenFamilyId": + return cls(uuid4()) + + def __str__(self) -> str: + return str(self.root) + + def __hash__(self) -> int: + return hash(self.root) + + +ORCID_PATTERN = re.compile(r"^\d{4}-\d{4}-\d{4}-\d{3}[\dX]$") + + +class OrcidId(RootModel[str]): + """An ORCiD identifier (e.g., 0000-0001-2345-6789). + + ORCiD IDs are 16-digit numbers displayed as four groups of four, + with a checksum character (digit or X) at the end. + """ + + @field_validator("root") + @classmethod + def validate_orcid_format(cls, v: str) -> str: + if not ORCID_PATTERN.match(v): + raise ValueError(f"Invalid ORCiD format: {v}") + return v + + def __str__(self) -> str: + return self.root + + def __hash__(self) -> int: + return hash(self.root) diff --git a/server/osa/domain/auth/port/__init__.py b/server/osa/domain/auth/port/__init__.py index e69de29..9f67041 100644 --- a/server/osa/domain/auth/port/__init__.py +++ b/server/osa/domain/auth/port/__init__.py @@ -0,0 +1,12 @@ +"""Auth domain ports.""" + +from .identity_provider import IdentityInfo, IdentityProvider +from .repository import IdentityRepository, RefreshTokenRepository, UserRepository + +__all__ = [ + "IdentityInfo", + "IdentityProvider", + "IdentityRepository", + "RefreshTokenRepository", + "UserRepository", +] diff --git a/server/osa/domain/auth/port/identity_provider.py b/server/osa/domain/auth/port/identity_provider.py new file mode 100644 index 0000000..24c7305 --- /dev/null +++ b/server/osa/domain/auth/port/identity_provider.py @@ -0,0 +1,64 @@ +"""Identity provider port for the auth domain.""" + +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, Protocol + +from osa.domain.shared.port import Port + + +@dataclass(frozen=True) +class IdentityInfo: + """Information returned by an 
identity provider after successful auth.""" + + provider: str # e.g., "orcid", "google", "saml" + external_id: str # Provider-specific user ID + display_name: str | None + email: str | None # May not be available from all providers + raw_data: dict[str, Any] # Full provider response for extensibility + + +class IdentityProvider(Port, Protocol): + """Port for external identity provider integrations. + + Implementations are adapters in infrastructure/ (e.g., OrcidIdentityProvider). + """ + + @property + @abstractmethod + def provider_name(self) -> str: + """Unique identifier for this provider (e.g., 'orcid').""" + ... + + @abstractmethod + def get_authorization_url(self, state: str, redirect_uri: str) -> str: + """Generate URL to redirect user for authentication. + + Args: + state: CSRF protection token (random, stored in session) + redirect_uri: Where the IdP should redirect after auth + + Returns: + Full URL to redirect the user to + """ + ... + + @abstractmethod + async def exchange_code( + self, + code: str, + redirect_uri: str, + ) -> IdentityInfo: + """Exchange authorization code for identity information. + + Args: + code: Authorization code from IdP callback + redirect_uri: Must match the redirect_uri used in authorization URL + + Returns: + IdentityInfo with user details from the provider + + Raises: + ExternalServiceError: If the IdP request fails + """ + ... 
diff --git a/server/osa/domain/auth/port/repository.py b/server/osa/domain/auth/port/repository.py new file mode 100644 index 0000000..7f70a70 --- /dev/null +++ b/server/osa/domain/auth/port/repository.py @@ -0,0 +1,79 @@ +"""Repository ports for the auth domain.""" + +from abc import abstractmethod +from typing import Protocol + +from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.token import RefreshToken +from osa.domain.auth.model.user import User +from osa.domain.auth.model.value import ( + IdentityId, + RefreshTokenId, + TokenFamilyId, + UserId, +) +from osa.domain.shared.port import Port + + +class UserRepository(Port, Protocol): + """Repository for User aggregate persistence.""" + + @abstractmethod + async def get(self, user_id: UserId) -> User | None: + """Get a user by ID.""" + ... + + @abstractmethod + async def save(self, user: User) -> None: + """Save a user (create or update).""" + ... + + +class IdentityRepository(Port, Protocol): + """Repository for Identity entity persistence.""" + + @abstractmethod + async def get(self, identity_id: IdentityId) -> Identity | None: + """Get an identity by ID.""" + ... + + @abstractmethod + async def get_by_provider_and_external_id( + self, provider: str, external_id: str + ) -> Identity | None: + """Get an identity by provider and external ID.""" + ... + + @abstractmethod + async def get_by_user_id(self, user_id: UserId) -> list[Identity]: + """Get all identities for a user.""" + ... + + @abstractmethod + async def save(self, identity: Identity) -> None: + """Save an identity.""" + ... + + +class RefreshTokenRepository(Port, Protocol): + """Repository for RefreshToken entity persistence.""" + + @abstractmethod + async def get(self, token_id: RefreshTokenId) -> RefreshToken | None: + """Get a refresh token by ID.""" + ... + + @abstractmethod + async def get_by_token_hash(self, token_hash: str) -> RefreshToken | None: + """Get a refresh token by its hash.""" + ... 
+ + @abstractmethod + async def save(self, token: RefreshToken) -> None: + """Save a refresh token.""" + ... + + @abstractmethod + async def revoke_family(self, family_id: TokenFamilyId) -> int: + """Revoke all tokens in a family. Returns count of revoked tokens.""" + ... diff --git a/server/osa/domain/auth/service/__init__.py b/server/osa/domain/auth/service/__init__.py index e69de29..0e01a6e 100644 --- a/server/osa/domain/auth/service/__init__.py +++ b/server/osa/domain/auth/service/__init__.py @@ -0,0 +1,6 @@ +"""Auth domain services.""" + +from .auth import AuthService +from .token import TokenService + +__all__ = ["AuthService", "TokenService"] diff --git a/server/osa/domain/auth/service/auth.py b/server/osa/domain/auth/service/auth.py new file mode 100644 index 0000000..620ee44 --- /dev/null +++ b/server/osa/domain/auth/service/auth.py @@ -0,0 +1,251 @@ +"""Auth service for orchestrating authentication flows.""" + +import logging + +from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.token import RefreshToken +from osa.domain.auth.model.user import User +from osa.domain.auth.model.value import TokenFamilyId, UserId +from osa.domain.auth.port.identity_provider import IdentityInfo, IdentityProvider +from osa.domain.auth.port.repository import ( + IdentityRepository, + RefreshTokenRepository, + UserRepository, +) +from osa.domain.auth.service.token import TokenService +from osa.domain.shared.outbox import Outbox +from osa.domain.shared.service import Service + +logger = logging.getLogger(__name__) + + +class AuthService(Service): + """Orchestrates authentication flows. 
+ + - initiate_login: Generate authorization URL + - complete_oauth: Exchange code for tokens, create/update user + - refresh_tokens: Issue new tokens from refresh token + - logout: Revoke refresh token family + """ + + _user_repo: UserRepository + _identity_repo: IdentityRepository + _refresh_token_repo: RefreshTokenRepository + _token_service: TokenService + _outbox: Outbox + + async def initiate_login( + self, + provider: IdentityProvider, + state: str, + redirect_uri: str, + ) -> str: + """Generate the authorization URL for OAuth login. + + Args: + provider: The identity provider to use + state: CSRF protection token (caller should store this) + redirect_uri: Where the IdP should redirect after auth + + Returns: + Authorization URL to redirect the user to + """ + return provider.get_authorization_url(state, redirect_uri) + + async def complete_oauth( + self, + provider: IdentityProvider, + code: str, + redirect_uri: str, + ) -> tuple[User, Identity, str, str]: + """Complete OAuth flow and issue tokens. + + Args: + provider: The identity provider + code: Authorization code from callback + redirect_uri: Must match the one used in authorization + + Returns: + Tuple of (user, identity, access_token, refresh_token) + """ + # Exchange code for identity info + identity_info = await provider.exchange_code(code, redirect_uri) + + # Find or create user and identity + user, identity = await self._find_or_create_user(identity_info) + + # Create tokens + access_token, refresh_token = await self._create_tokens(user, identity) + + logger.info( + "User authenticated: user_id=%s, provider=%s, external_id=%s", + user.id, + identity.provider, + identity.external_id, + ) + + return user, identity, access_token, refresh_token + + async def refresh_tokens( + self, + refresh_token_raw: str, + ) -> tuple[User, str, str]: + """Refresh access token using refresh token. + + Implements token rotation: old refresh token is revoked, + new one issued in same family. 
+ + Args: + refresh_token_raw: The raw refresh token from client + + Returns: + Tuple of (user, new_access_token, new_refresh_token) + + Raises: + InvalidStateError: If refresh token is invalid, expired, or revoked + """ + from osa.domain.shared.error import InvalidStateError + + token_hash = self._token_service.hash_token(refresh_token_raw) + stored_token = await self._refresh_token_repo.get_by_token_hash(token_hash) + + if stored_token is None: + raise InvalidStateError("Invalid refresh token", code="invalid_refresh_token") + + if stored_token.is_revoked: + # Potential theft detected - revoke entire family + await self._refresh_token_repo.revoke_family(stored_token.family_id) + logger.warning( + "Refresh token reuse detected, family revoked: family_id=%s", + stored_token.family_id, + ) + raise InvalidStateError( + "Token family revoked - please login again", + code="token_family_revoked", + ) + + if stored_token.is_expired: + raise InvalidStateError("Refresh token expired", code="refresh_token_expired") + + # Revoke old token + stored_token.revoke() + await self._refresh_token_repo.save(stored_token) + + # Get user and their ORCiD identity + user = await self._user_repo.get(stored_token.user_id) + if user is None: + raise InvalidStateError("User not found", code="user_not_found") + + identities = await self._identity_repo.get_by_user_id(user.id) + orcid_identity = next((i for i in identities if i.provider == "orcid"), None) + + if orcid_identity is None: + raise InvalidStateError("User has no ORCiD identity", code="no_orcid_identity") + + # Issue new tokens in same family + raw_token, token_hash = self._token_service.create_refresh_token() + new_refresh_token = RefreshToken.create( + user_id=user.id, + token_hash=token_hash, + family_id=stored_token.family_id, + expires_in_days=self._token_service.refresh_token_expire_days, + ) + await self._refresh_token_repo.save(new_refresh_token) + + access_token = self._token_service.create_access_token( + user_id=user.id, + 
orcid_id=orcid_identity.external_id, + ) + + logger.info("Tokens refreshed: user_id=%s", user.id) + + return user, access_token, raw_token + + async def logout(self, refresh_token_raw: str) -> bool: + """Logout by revoking refresh token family. + + Args: + refresh_token_raw: The raw refresh token + + Returns: + True if tokens were revoked + """ + token_hash = self._token_service.hash_token(refresh_token_raw) + stored_token = await self._refresh_token_repo.get_by_token_hash(token_hash) + + if stored_token is None: + # Token not found, but logout succeeds anyway + return True + + revoked_count = await self._refresh_token_repo.revoke_family(stored_token.family_id) + logger.info( + "User logged out: user_id=%s, revoked_tokens=%d", + stored_token.user_id, + revoked_count, + ) + + return True + + async def get_user_by_id(self, user_id: UserId) -> User | None: + """Get a user by their ID.""" + return await self._user_repo.get(user_id) + + async def get_orcid_identity(self, user_id: UserId) -> Identity | None: + """Get the ORCiD identity for a user.""" + identities = await self._identity_repo.get_by_user_id(user_id) + return next((i for i in identities if i.provider == "orcid"), None) + + async def _find_or_create_user(self, identity_info: IdentityInfo) -> tuple[User, Identity]: + """Find existing user by identity or create new one.""" + # Check if identity already exists + existing_identity = await self._identity_repo.get_by_provider_and_external_id( + identity_info.provider, identity_info.external_id + ) + + if existing_identity: + # User exists, return them + user = await self._user_repo.get(existing_identity.user_id) + if user is None: + # Orphaned identity - shouldn't happen with CASCADE + raise RuntimeError(f"Identity exists without user: {existing_identity.id}") + return user, existing_identity + + # Create new user and identity + user = User.create(display_name=identity_info.display_name) + await self._user_repo.save(user) + + identity = Identity.create( + 
user_id=user.id, + provider=identity_info.provider, + external_id=identity_info.external_id, + metadata=identity_info.raw_data, + ) + await self._identity_repo.save(identity) + + logger.info( + "New user created: user_id=%s, provider=%s", + user.id, + identity_info.provider, + ) + + return user, identity + + async def _create_tokens(self, user: User, identity: Identity) -> tuple[str, str]: + """Create access and refresh tokens for a user.""" + # Create refresh token + raw_token, token_hash = self._token_service.create_refresh_token() + refresh_token = RefreshToken.create( + user_id=user.id, + token_hash=token_hash, + family_id=TokenFamilyId.generate(), + expires_in_days=self._token_service.refresh_token_expire_days, + ) + await self._refresh_token_repo.save(refresh_token) + + # Create access token + access_token = self._token_service.create_access_token( + user_id=user.id, + orcid_id=identity.external_id, + ) + + return access_token, raw_token diff --git a/server/osa/domain/auth/service/token.py b/server/osa/domain/auth/service/token.py new file mode 100644 index 0000000..b61bbd8 --- /dev/null +++ b/server/osa/domain/auth/service/token.py @@ -0,0 +1,112 @@ +"""Token service for JWT creation and validation.""" + +import hashlib +import secrets +from datetime import UTC, datetime, timedelta +from typing import Any + +import jwt + +from osa.config import JwtConfig +from osa.domain.auth.model.value import UserId +from osa.domain.shared.service import Service + + +class TokenService(Service): + """Service for JWT access token and refresh token operations. + + - Access tokens are JWTs (HS256) with user claims + - Refresh tokens are opaque random strings, stored as hashes in the database + """ + + _config: JwtConfig + + def create_access_token( + self, + user_id: UserId, + orcid_id: str, + additional_claims: dict[str, Any] | None = None, + ) -> str: + """Create a JWT access token. 
+ + Args: + user_id: The user's internal ID + orcid_id: The user's ORCiD ID + additional_claims: Optional extra claims to include + + Returns: + Encoded JWT string + """ + now = datetime.now(UTC) + expires_at = now + timedelta(minutes=self._config.access_token_expire_minutes) + + payload = { + "sub": str(user_id), + "orcid_id": orcid_id, + "aud": "authenticated", + "iat": int(now.timestamp()), + "exp": int(expires_at.timestamp()), + "jti": secrets.token_hex(16), + } + + if additional_claims: + payload.update(additional_claims) + + return jwt.encode( + payload, + self._config.secret, + algorithm=self._config.algorithm, + ) + + def validate_access_token(self, token: str) -> dict[str, Any]: + """Validate and decode a JWT access token. + + Args: + token: The JWT string to validate + + Returns: + Decoded payload dict + + Raises: + jwt.InvalidTokenError: If token is invalid or expired + """ + return jwt.decode( + token, + self._config.secret, + algorithms=[self._config.algorithm], + audience="authenticated", + ) + + def create_refresh_token(self) -> tuple[str, str]: + """Create a new refresh token. + + Returns: + Tuple of (raw_token, token_hash) + - raw_token: Send to client + - token_hash: Store in database + """ + raw_token = secrets.token_urlsafe(32) + token_hash = self.hash_token(raw_token) + return raw_token, token_hash + + @staticmethod + def hash_token(raw_token: str) -> str: + """Create SHA256 hash of a token. 
+ + Args: + raw_token: The raw token string + + Returns: + Hex-encoded SHA256 hash (64 characters) + """ + return hashlib.sha256(raw_token.encode()).hexdigest() + + @property + def access_token_expire_seconds(self) -> int: + """Get access token expiry in seconds.""" + return self._config.access_token_expire_minutes * 60 + + @property + def refresh_token_expire_days(self) -> int: + """Get refresh token expiry in days.""" + return self._config.refresh_token_expire_days diff --git a/server/osa/domain/auth/adapter/__init__.py b/server/osa/domain/auth/util/__init__.py similarity index 100% rename from server/osa/domain/auth/adapter/__init__.py rename to server/osa/domain/auth/util/__init__.py diff --git a/server/osa/domain/auth/util/di/__init__.py b/server/osa/domain/auth/util/di/__init__.py new file mode 100644 index 0000000..22bb012 --- /dev/null +++ b/server/osa/domain/auth/util/di/__init__.py @@ -0,0 +1,5 @@ +"""DI providers for auth domain.""" + +from .provider import AuthProvider + +__all__ = ["AuthProvider"] diff --git a/server/osa/domain/auth/util/di/provider.py b/server/osa/domain/auth/util/di/provider.py new file mode 100644 index 0000000..d564f86 --- /dev/null +++ b/server/osa/domain/auth/util/di/provider.py @@ -0,0 +1,42 @@ +"""DI provider for auth domain.""" + +from dishka import provide + +from osa.config import Config +from osa.domain.auth.port.repository import ( + IdentityRepository, + RefreshTokenRepository, + UserRepository, +) +from osa.domain.auth.service.auth import AuthService +from osa.domain.auth.service.token import TokenService +from osa.domain.shared.outbox import Outbox +from osa.util.di.base import Provider +from osa.util.di.scope import Scope + + +class AuthProvider(Provider): + """DI provider for auth domain services and handlers.""" + + @provide(scope=Scope.UOW) + def get_token_service(self, config: Config) -> TokenService: + """Provide TokenService.""" + return TokenService(_config=config.auth.jwt) + + @provide(scope=Scope.UOW) + def 
get_auth_service( + self, + user_repo: UserRepository, + identity_repo: IdentityRepository, + refresh_token_repo: RefreshTokenRepository, + token_service: TokenService, + outbox: Outbox, + ) -> AuthService: + """Provide AuthService.""" + return AuthService( + _user_repo=user_repo, + _identity_repo=identity_repo, + _refresh_token_repo=refresh_token_repo, + _token_service=token_service, + _outbox=outbox, + ) diff --git a/server/osa/infrastructure/auth/__init__.py b/server/osa/infrastructure/auth/__init__.py new file mode 100644 index 0000000..fc36309 --- /dev/null +++ b/server/osa/infrastructure/auth/__init__.py @@ -0,0 +1,5 @@ +"""Auth infrastructure adapters.""" + +from .di import AuthInfraProvider + +__all__ = ["AuthInfraProvider"] diff --git a/server/osa/infrastructure/auth/di.py b/server/osa/infrastructure/auth/di.py new file mode 100644 index 0000000..2f08d77 --- /dev/null +++ b/server/osa/infrastructure/auth/di.py @@ -0,0 +1,61 @@ +"""DI provider for auth infrastructure.""" + +import httpx +from dishka import provide + +from osa.config import Config +from osa.domain.auth.port.identity_provider import IdentityProvider +from osa.domain.auth.port.repository import ( + IdentityRepository, + RefreshTokenRepository, + UserRepository, +) +from osa.infrastructure.auth.orcid import OrcidIdentityProvider +from osa.infrastructure.persistence.repository.auth import ( + PostgresIdentityRepository, + PostgresRefreshTokenRepository, + PostgresUserRepository, +) +from osa.util.di.base import Provider +from osa.util.di.scope import Scope + +# HTTP client timeout configuration +_HTTP_TIMEOUT = httpx.Timeout( + connect=5.0, # Connection timeout + read=10.0, # Read timeout + write=5.0, # Write timeout + pool=5.0, # Pool timeout +) + + +class AuthInfraProvider(Provider): + """DI provider for auth infrastructure adapters.""" + + # Repository adapters + user_repo = provide( + PostgresUserRepository, + scope=Scope.UOW, + provides=UserRepository, + ) + identity_repo = provide( + 
PostgresIdentityRepository, + scope=Scope.UOW, + provides=IdentityRepository, + ) + refresh_token_repo = provide( + PostgresRefreshTokenRepository, + scope=Scope.UOW, + provides=RefreshTokenRepository, + ) + + @provide(scope=Scope.APP) + def get_auth_http_client(self) -> httpx.AsyncClient: + """Shared HTTP client for auth operations (connection pooling).""" + return httpx.AsyncClient(timeout=_HTTP_TIMEOUT) + + @provide(scope=Scope.UOW) + def get_orcid_provider( + self, config: Config, http_client: httpx.AsyncClient + ) -> IdentityProvider: + """Provide OrcidIdentityProvider as the default IdentityProvider.""" + return OrcidIdentityProvider(config=config.auth.orcid, http_client=http_client) diff --git a/server/osa/infrastructure/auth/orcid.py b/server/osa/infrastructure/auth/orcid.py new file mode 100644 index 0000000..a5db378 --- /dev/null +++ b/server/osa/infrastructure/auth/orcid.py @@ -0,0 +1,101 @@ +"""ORCiD identity provider adapter.""" + +import logging +from urllib.parse import urlencode + +import httpx + +from osa.config import OrcidConfig +from osa.domain.auth.port.identity_provider import IdentityInfo, IdentityProvider +from osa.domain.shared.error import ExternalServiceError + +logger = logging.getLogger(__name__) + + +class OrcidIdentityProvider(IdentityProvider): + """IdentityProvider implementation for ORCiD OAuth.""" + + def __init__(self, config: OrcidConfig, http_client: httpx.AsyncClient) -> None: + self._config = config + self._http = http_client + + @property + def provider_name(self) -> str: + return "orcid" + + def get_authorization_url(self, state: str, redirect_uri: str) -> str: + """Generate ORCiD authorization URL.""" + params = { + "client_id": self._config.client_id, + "response_type": "code", + "scope": "/authenticate", + "redirect_uri": redirect_uri, + "state": state, + } + return f"{self._config.base_url}/oauth/authorize?{urlencode(params)}" + + async def exchange_code( + self, + code: str, + redirect_uri: str, + ) -> IdentityInfo: + 
"""Exchange authorization code for identity information.""" + token_url = f"{self._config.base_url}/oauth/token" + + data = { + "client_id": self._config.client_id, + "client_secret": self._config.client_secret, + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + } + + try: + response = await self._http.post( + token_url, + data=data, + headers={"Accept": "application/json"}, + ) + + if response.status_code != 200: + logger.error( + "ORCiD token exchange failed: status=%d, body=%s", + response.status_code, + response.text, + ) + raise ExternalServiceError( + f"ORCiD token exchange failed: {response.status_code}", + code="idp_unavailable", + ) + + token_data = response.json() + + except httpx.RequestError as e: + logger.exception("ORCiD request failed: %s", e) + raise ExternalServiceError( + "Failed to connect to ORCiD", + code="idp_unavailable", + ) from e + + # ORCiD returns user info directly in token response + # { + # "access_token": "...", + # "token_type": "bearer", + # "scope": "/authenticate", + # "name": "Jane Doe", + # "orcid": "0000-0001-2345-6789" + # } + orcid_id = token_data.get("orcid") + if not orcid_id: + raise ExternalServiceError( + "ORCiD response missing orcid field", + code="oauth_error", + ) + + return IdentityInfo( + provider="orcid", + external_id=orcid_id, + display_name=token_data.get("name"), + email=None, # ORCiD doesn't return email in basic auth + raw_data=token_data, + ) diff --git a/server/osa/infrastructure/persistence/repository/auth.py b/server/osa/infrastructure/persistence/repository/auth.py new file mode 100644 index 0000000..56f37eb --- /dev/null +++ b/server/osa/infrastructure/persistence/repository/auth.py @@ -0,0 +1,218 @@ +"""PostgreSQL repository implementations for auth domain.""" + +from datetime import UTC, datetime +from uuid import UUID + +from sqlalchemy import insert, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.auth.model.identity import 
Identity +from osa.domain.auth.model.token import RefreshToken +from osa.domain.auth.model.user import User +from osa.domain.auth.model.value import ( + IdentityId, + RefreshTokenId, + TokenFamilyId, + UserId, +) +from osa.domain.auth.port.repository import ( + IdentityRepository, + RefreshTokenRepository, + UserRepository, +) +from osa.infrastructure.persistence.tables import ( + identities_table, + refresh_tokens_table, + users_table, +) + + +def _row_to_user(row: dict) -> User: + """Convert a database row to a User model.""" + return User( + id=UserId(UUID(row["id"])), + display_name=row["display_name"], + created_at=row["created_at"], + updated_at=row["updated_at"], + ) + + +def _user_to_dict(user: User) -> dict: + """Convert a User model to a database row dict.""" + return { + "id": str(user.id), + "display_name": user.display_name, + "created_at": user.created_at, + "updated_at": user.updated_at, + } + + +def _row_to_identity(row: dict) -> Identity: + """Convert a database row to an Identity model.""" + return Identity( + id=IdentityId(UUID(row["id"])), + user_id=UserId(UUID(row["user_id"])), + provider=row["provider"], + external_id=row["external_id"], + metadata=row["metadata"], + created_at=row["created_at"], + ) + + +def _identity_to_dict(identity: Identity) -> dict: + """Convert an Identity model to a database row dict.""" + return { + "id": str(identity.id), + "user_id": str(identity.user_id), + "provider": identity.provider, + "external_id": identity.external_id, + "metadata": identity.metadata, + "created_at": identity.created_at, + } + + +def _row_to_refresh_token(row: dict) -> RefreshToken: + """Convert a database row to a RefreshToken model.""" + return RefreshToken( + id=RefreshTokenId(UUID(row["id"])), + user_id=UserId(UUID(row["user_id"])), + token_hash=row["token_hash"], + family_id=TokenFamilyId(UUID(row["family_id"])), + expires_at=row["expires_at"], + created_at=row["created_at"], + revoked_at=row["revoked_at"], + ) + + +def 
_refresh_token_to_dict(token: RefreshToken) -> dict: + """Convert a RefreshToken model to a database row dict.""" + return { + "id": str(token.id), + "user_id": str(token.user_id), + "token_hash": token.token_hash, + "family_id": str(token.family_id), + "expires_at": token.expires_at, + "created_at": token.created_at, + "revoked_at": token.revoked_at, + } + + +class PostgresUserRepository(UserRepository): + """PostgreSQL implementation of UserRepository.""" + + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def get(self, user_id: UserId) -> User | None: + stmt = select(users_table).where(users_table.c.id == str(user_id)) + result = await self.session.execute(stmt) + row = result.mappings().first() + return _row_to_user(dict(row)) if row else None + + async def save(self, user: User) -> None: + user_dict = _user_to_dict(user) + existing = await self.get(user.id) + + if existing: + stmt = update(users_table).where(users_table.c.id == str(user.id)).values(**user_dict) + else: + stmt = insert(users_table).values(**user_dict) + + await self.session.execute(stmt) + await self.session.flush() + + +class PostgresIdentityRepository(IdentityRepository): + """PostgreSQL implementation of IdentityRepository.""" + + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def get(self, identity_id: IdentityId) -> Identity | None: + stmt = select(identities_table).where(identities_table.c.id == str(identity_id)) + result = await self.session.execute(stmt) + row = result.mappings().first() + return _row_to_identity(dict(row)) if row else None + + async def get_by_provider_and_external_id( + self, provider: str, external_id: str + ) -> Identity | None: + stmt = select(identities_table).where( + identities_table.c.provider == provider, + identities_table.c.external_id == external_id, + ) + result = await self.session.execute(stmt) + row = result.mappings().first() + return _row_to_identity(dict(row)) if row else 
None + + async def get_by_user_id(self, user_id: UserId) -> list[Identity]: + stmt = select(identities_table).where(identities_table.c.user_id == str(user_id)) + result = await self.session.execute(stmt) + rows = result.mappings().all() + return [_row_to_identity(dict(row)) for row in rows] + + async def save(self, identity: Identity) -> None: + identity_dict = _identity_to_dict(identity) + existing = await self.get(identity.id) + + if existing: + stmt = ( + update(identities_table) + .where(identities_table.c.id == str(identity.id)) + .values(**identity_dict) + ) + else: + stmt = insert(identities_table).values(**identity_dict) + + await self.session.execute(stmt) + await self.session.flush() + + +class PostgresRefreshTokenRepository(RefreshTokenRepository): + """PostgreSQL implementation of RefreshTokenRepository.""" + + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def get(self, token_id: RefreshTokenId) -> RefreshToken | None: + stmt = select(refresh_tokens_table).where(refresh_tokens_table.c.id == str(token_id)) + result = await self.session.execute(stmt) + row = result.mappings().first() + return _row_to_refresh_token(dict(row)) if row else None + + async def get_by_token_hash(self, token_hash: str) -> RefreshToken | None: + stmt = select(refresh_tokens_table).where(refresh_tokens_table.c.token_hash == token_hash) + result = await self.session.execute(stmt) + row = result.mappings().first() + return _row_to_refresh_token(dict(row)) if row else None + + async def save(self, token: RefreshToken) -> None: + token_dict = _refresh_token_to_dict(token) + existing = await self.get(token.id) + + if existing: + stmt = ( + update(refresh_tokens_table) + .where(refresh_tokens_table.c.id == str(token.id)) + .values(**token_dict) + ) + else: + stmt = insert(refresh_tokens_table).values(**token_dict) + + await self.session.execute(stmt) + await self.session.flush() + + async def revoke_family(self, family_id: TokenFamilyId) -> int: + 
"""Revoke all tokens in a family. Returns count of revoked tokens.""" + now = datetime.now(UTC) + stmt = ( + update(refresh_tokens_table) + .where( + refresh_tokens_table.c.family_id == str(family_id), + refresh_tokens_table.c.revoked_at.is_(None), + ) + .values(revoked_at=now) + ) + result = await self.session.execute(stmt) + await self.session.flush() + return result.rowcount diff --git a/server/osa/infrastructure/persistence/tables.py b/server/osa/infrastructure/persistence/tables.py index 6b66044..f22ee64 100644 --- a/server/osa/infrastructure/persistence/tables.py +++ b/server/osa/infrastructure/persistence/tables.py @@ -3,12 +3,14 @@ from sqlalchemy import ( Column, DateTime, + ForeignKey, Index, Integer, MetaData, String, Table, Text, + UniqueConstraint, text, ) from sqlalchemy.types import JSON @@ -122,3 +124,54 @@ events_table.c.created_at, postgresql_where=text("delivery_status = 'failed'"), ) + + +# ============================================================================ +# USERS TABLE (Authentication) +# ============================================================================ +users_table = Table( + "users", + metadata, + Column("id", String, primary_key=True), # UUID as string + Column("display_name", String(255), nullable=True), + Column("created_at", DateTime(timezone=True), nullable=False), + Column("updated_at", DateTime(timezone=True), nullable=True), +) + + +# ============================================================================ +# IDENTITIES TABLE (Authentication) +# ============================================================================ +identities_table = Table( + "identities", + metadata, + Column("id", String, primary_key=True), # UUID as string + Column("user_id", String, ForeignKey("users.id", ondelete="CASCADE"), nullable=False), + Column("provider", String(50), nullable=False), # "orcid", "google", etc. + Column("external_id", String(255), nullable=False), # ORCiD ID, Google ID, etc. 
+ Column("metadata", JSON, nullable=True), # Provider-specific data (name, email) + Column("created_at", DateTime(timezone=True), nullable=False), + UniqueConstraint("provider", "external_id", name="uq_identity_provider_external"), +) + +Index("ix_identities_user_id", identities_table.c.user_id) + + +# ============================================================================ +# REFRESH TOKENS TABLE (Authentication) +# ============================================================================ +refresh_tokens_table = Table( + "refresh_tokens", + metadata, + Column("id", String, primary_key=True), # UUID as string + Column("user_id", String, ForeignKey("users.id", ondelete="CASCADE"), nullable=False), + Column("token_hash", String(64), nullable=False), # SHA256 hash + Column("family_id", String, nullable=False), # UUID - for theft detection + Column("expires_at", DateTime(timezone=True), nullable=False), + Column("created_at", DateTime(timezone=True), nullable=False), + Column("revoked_at", DateTime(timezone=True), nullable=True), +) + +Index("ix_refresh_tokens_user_id", refresh_tokens_table.c.user_id) +Index("ix_refresh_tokens_token_hash", refresh_tokens_table.c.token_hash) +Index("ix_refresh_tokens_family_id", refresh_tokens_table.c.family_id) diff --git a/server/pyproject.toml b/server/pyproject.toml index be6bb29..09609ca 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "rich>=14.2.0", "asyncpg>=0.31.0", "psycopg2-binary>=2.9.11", + "pyjwt>=2.11.0", ] [project.scripts] diff --git a/server/tests/unit/application/api/v1/routes/test_auth_state.py b/server/tests/unit/application/api/v1/routes/test_auth_state.py new file mode 100644 index 0000000..6915692 --- /dev/null +++ b/server/tests/unit/application/api/v1/routes/test_auth_state.py @@ -0,0 +1,138 @@ +"""Unit tests for OAuth state token signing/verification.""" + +import time + + +from osa.application.api.v1.routes.auth import ( + _STATE_EXPIRY_SECONDS, + 
_create_signed_state, + _verify_signed_state, +) + + +class TestSignedStateCreation: + """Tests for _create_signed_state.""" + + def test_creates_state_with_redirect_uri(self): + """Should create a signed state containing the redirect URI.""" + secret = "test-secret-key-for-signing" + redirect_uri = "https://example.com/callback" + + state = _create_signed_state(secret, redirect_uri) + + # State should be format: payload.signature + assert "." in state + parts = state.split(".") + assert len(parts) == 2 + + def test_different_nonces_produce_different_states(self): + """Each state should have a unique nonce.""" + secret = "test-secret-key" + redirect_uri = "https://example.com" + + state1 = _create_signed_state(secret, redirect_uri) + state2 = _create_signed_state(secret, redirect_uri) + + assert state1 != state2 + + def test_state_is_url_safe(self): + """State should only contain URL-safe characters.""" + secret = "test-secret" + redirect_uri = "https://example.com/path?query=value" + + state = _create_signed_state(secret, redirect_uri) + + # URL-safe base64 uses only these characters + allowed = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.") + assert all(c in allowed for c in state) + + +class TestSignedStateVerification: + """Tests for _verify_signed_state.""" + + def test_verifies_valid_state(self): + """Should return redirect_uri for valid state.""" + secret = "test-secret-key" + redirect_uri = "https://example.com/after-login" + + state = _create_signed_state(secret, redirect_uri) + result = _verify_signed_state(secret, state) + + assert result == redirect_uri + + def test_rejects_tampered_payload(self): + """Should reject state with tampered payload.""" + secret = "test-secret-key" + state = _create_signed_state(secret, "https://example.com") + + # Tamper with the payload (change a character) + parts = state.split(".") + tampered_payload = "x" + parts[0][1:] + tampered_state = f"{tampered_payload}.{parts[1]}" + + result = 
_verify_signed_state(secret, tampered_state) + assert result is None + + def test_rejects_tampered_signature(self): + """Should reject state with tampered signature.""" + secret = "test-secret-key" + state = _create_signed_state(secret, "https://example.com") + + # Tamper with the signature + parts = state.split(".") + tampered_sig = "x" + parts[1][1:] + tampered_state = f"{parts[0]}.{tampered_sig}" + + result = _verify_signed_state(secret, tampered_state) + assert result is None + + def test_rejects_wrong_secret(self): + """Should reject state signed with different secret.""" + state = _create_signed_state("secret-one", "https://example.com") + + result = _verify_signed_state("secret-two", state) + assert result is None + + def test_rejects_expired_state(self, monkeypatch): + """Should reject expired state.""" + secret = "test-secret-key" + state = _create_signed_state(secret, "https://example.com") + + # Fast-forward time past expiry + future_time = time.time() + _STATE_EXPIRY_SECONDS + 1 + monkeypatch.setattr(time, "time", lambda: future_time) + + result = _verify_signed_state(secret, state) + assert result is None + + def test_rejects_malformed_state(self): + """Should reject malformed state strings.""" + secret = "test-secret" + + # No dot separator + assert _verify_signed_state(secret, "nodot") is None + + # Empty parts + assert _verify_signed_state(secret, ".") is None + assert _verify_signed_state(secret, "payload.") is None + assert _verify_signed_state(secret, ".signature") is None + + # Too many parts + assert _verify_signed_state(secret, "a.b.c") is None + + # Invalid base64 + assert _verify_signed_state(secret, "!!!.???") is None + + def test_rejects_empty_state(self): + """Should reject empty state.""" + assert _verify_signed_state("secret", "") is None + + def test_handles_special_characters_in_redirect_uri(self): + """Should handle redirect URIs with special characters.""" + secret = "test-secret" + redirect_uri = 
"https://example.com/path?foo=bar&baz=qux#fragment" + + state = _create_signed_state(secret, redirect_uri) + result = _verify_signed_state(secret, state) + + assert result == redirect_uri diff --git a/server/tests/unit/domain/auth/__init__.py b/server/tests/unit/domain/auth/__init__.py new file mode 100644 index 0000000..0392972 --- /dev/null +++ b/server/tests/unit/domain/auth/__init__.py @@ -0,0 +1 @@ +"""Unit tests for auth domain.""" diff --git a/server/tests/unit/domain/auth/test_auth_service.py b/server/tests/unit/domain/auth/test_auth_service.py new file mode 100644 index 0000000..bef7f98 --- /dev/null +++ b/server/tests/unit/domain/auth/test_auth_service.py @@ -0,0 +1,393 @@ +"""Unit tests for AuthService.""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 + +import pytest + +from osa.config import JwtConfig +from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.token import RefreshToken +from osa.domain.auth.model.user import User +from osa.domain.auth.model.value import IdentityId, RefreshTokenId, TokenFamilyId, UserId +from osa.domain.auth.port.identity_provider import IdentityInfo +from osa.domain.auth.service.auth import AuthService +from osa.domain.auth.service.token import TokenService +from osa.domain.shared.error import InvalidStateError + + +def make_auth_service( + user_repo: AsyncMock | None = None, + identity_repo: AsyncMock | None = None, + refresh_token_repo: AsyncMock | None = None, + token_service: TokenService | None = None, + outbox: AsyncMock | None = None, +) -> AuthService: + """Create an AuthService with mocked dependencies.""" + if user_repo is None: + user_repo = AsyncMock() + if identity_repo is None: + identity_repo = AsyncMock() + if refresh_token_repo is None: + refresh_token_repo = AsyncMock() + if token_service is None: + config = JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + 
access_token_expire_minutes=60, + refresh_token_expire_days=7, + ) + token_service = TokenService(_config=config) + if outbox is None: + outbox = AsyncMock() + + return AuthService( + _user_repo=user_repo, + _identity_repo=identity_repo, + _refresh_token_repo=refresh_token_repo, + _token_service=token_service, + _outbox=outbox, + ) + + +def make_identity_provider(identity_info: IdentityInfo | None = None) -> MagicMock: + """Create a mock identity provider.""" + provider = MagicMock() + provider.provider_name = "orcid" + + if identity_info is None: + identity_info = IdentityInfo( + provider="orcid", + external_id="0000-0001-2345-6789", + display_name="Jane Doe", + email=None, + raw_data={"name": "Jane Doe", "orcid": "0000-0001-2345-6789"}, + ) + + provider.exchange_code = AsyncMock(return_value=identity_info) + provider.get_authorization_url = MagicMock(return_value="https://orcid.org/oauth/authorize?...") + return provider + + +class TestAuthServiceInitiateLogin: + """Tests for AuthService.initiate_login.""" + + @pytest.mark.asyncio + async def test_initiate_login_returns_authorization_url(self): + """initiate_login should return the provider's authorization URL.""" + service = make_auth_service() + provider = make_identity_provider() + provider.get_authorization_url.return_value = "https://orcid.org/oauth?state=abc" + + url = await service.initiate_login( + provider=provider, + state="test-state", + redirect_uri="http://localhost/callback", + ) + + assert url == "https://orcid.org/oauth?state=abc" + provider.get_authorization_url.assert_called_once_with( + "test-state", "http://localhost/callback" + ) + + +class TestAuthServiceCompleteOAuth: + """Tests for AuthService.complete_oauth.""" + + @pytest.mark.asyncio + async def test_complete_oauth_creates_new_user(self): + """complete_oauth should create user and identity for new user.""" + user_repo = AsyncMock() + user_repo.get.return_value = None # No existing user + + identity_repo = AsyncMock() + 
identity_repo.get_by_provider_and_external_id.return_value = None # No existing identity + + refresh_token_repo = AsyncMock() + + service = make_auth_service( + user_repo=user_repo, + identity_repo=identity_repo, + refresh_token_repo=refresh_token_repo, + ) + provider = make_identity_provider() + + user, identity, access_token, refresh_token = await service.complete_oauth( + provider=provider, + code="auth-code", + redirect_uri="http://localhost/callback", + ) + + # Should create user and identity + user_repo.save.assert_called_once() + identity_repo.save.assert_called_once() + refresh_token_repo.save.assert_called_once() + + # Should return valid data + assert user.display_name == "Jane Doe" + assert identity.provider == "orcid" + assert identity.external_id == "0000-0001-2345-6789" + assert isinstance(access_token, str) + assert isinstance(refresh_token, str) + + @pytest.mark.asyncio + async def test_complete_oauth_returns_existing_user(self): + """complete_oauth should return existing user if identity exists.""" + existing_user = User( + id=UserId(uuid4()), + display_name="Existing User", + created_at=datetime.now(UTC), + updated_at=None, + ) + existing_identity = Identity( + id=IdentityId(uuid4()), + user_id=existing_user.id, + provider="orcid", + external_id="0000-0001-2345-6789", + metadata=None, + created_at=datetime.now(UTC), + ) + + user_repo = AsyncMock() + user_repo.get.return_value = existing_user + + identity_repo = AsyncMock() + identity_repo.get_by_provider_and_external_id.return_value = existing_identity + + refresh_token_repo = AsyncMock() + + service = make_auth_service( + user_repo=user_repo, + identity_repo=identity_repo, + refresh_token_repo=refresh_token_repo, + ) + provider = make_identity_provider() + + user, identity, _, _ = await service.complete_oauth( + provider=provider, + code="auth-code", + redirect_uri="http://localhost/callback", + ) + + # Should NOT create new user/identity + user_repo.save.assert_not_called() + 
identity_repo.save.assert_not_called() + + # Should return existing user + assert user.id == existing_user.id + assert identity.id == existing_identity.id + + +class TestAuthServiceRefreshTokens: + """Tests for AuthService.refresh_tokens.""" + + @pytest.mark.asyncio + async def test_refresh_tokens_issues_new_tokens(self): + """refresh_tokens should issue new access and refresh tokens.""" + user = User( + id=UserId(uuid4()), + display_name="Test User", + created_at=datetime.now(UTC), + updated_at=None, + ) + identity = Identity( + id=IdentityId(uuid4()), + user_id=user.id, + provider="orcid", + external_id="0000-0001-2345-6789", + metadata=None, + created_at=datetime.now(UTC), + ) + old_token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=user.id, + token_hash="old-hash", + family_id=TokenFamilyId(uuid4()), + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC), + revoked_at=None, + ) + + user_repo = AsyncMock() + user_repo.get.return_value = user + + identity_repo = AsyncMock() + identity_repo.get_by_user_id.return_value = [identity] + + refresh_token_repo = AsyncMock() + refresh_token_repo.get_by_token_hash.return_value = old_token + + service = make_auth_service( + user_repo=user_repo, + identity_repo=identity_repo, + refresh_token_repo=refresh_token_repo, + ) + + returned_user, access_token, new_refresh_token = await service.refresh_tokens( + "raw-refresh-token" + ) + + # Should save new refresh token + assert refresh_token_repo.save.call_count == 2 # Once for revoking old, once for new + + # Should return valid data + assert returned_user.id == user.id + assert isinstance(access_token, str) + assert isinstance(new_refresh_token, str) + + @pytest.mark.asyncio + async def test_refresh_tokens_revokes_old_token(self): + """refresh_tokens should revoke the old refresh token.""" + user = User( + id=UserId(uuid4()), + display_name="Test User", + created_at=datetime.now(UTC), + updated_at=None, + ) + identity = Identity( + 
id=IdentityId(uuid4()), + user_id=user.id, + provider="orcid", + external_id="0000-0001-2345-6789", + metadata=None, + created_at=datetime.now(UTC), + ) + old_token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=user.id, + token_hash="old-hash", + family_id=TokenFamilyId(uuid4()), + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC), + revoked_at=None, + ) + + user_repo = AsyncMock() + user_repo.get.return_value = user + + identity_repo = AsyncMock() + identity_repo.get_by_user_id.return_value = [identity] + + refresh_token_repo = AsyncMock() + refresh_token_repo.get_by_token_hash.return_value = old_token + + service = make_auth_service( + user_repo=user_repo, + identity_repo=identity_repo, + refresh_token_repo=refresh_token_repo, + ) + + await service.refresh_tokens("raw-refresh-token") + + # The old token should be revoked + assert old_token.is_revoked is True + + @pytest.mark.asyncio + async def test_refresh_tokens_rejects_invalid_token(self): + """refresh_tokens should raise for unknown refresh token.""" + refresh_token_repo = AsyncMock() + refresh_token_repo.get_by_token_hash.return_value = None + + service = make_auth_service(refresh_token_repo=refresh_token_repo) + + with pytest.raises(InvalidStateError) as exc_info: + await service.refresh_tokens("invalid-token") + + assert exc_info.value.code == "invalid_refresh_token" + + @pytest.mark.asyncio + async def test_refresh_tokens_detects_reuse_and_revokes_family(self): + """refresh_tokens should revoke entire family if revoked token is reused.""" + user_id = UserId(uuid4()) + family_id = TokenFamilyId(uuid4()) + + # Token that was already revoked (potential theft) + revoked_token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=user_id, + token_hash="revoked-hash", + family_id=family_id, + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC), + revoked_at=datetime.now(UTC) - timedelta(hours=1), # Already revoked + ) + + refresh_token_repo 
= AsyncMock() + refresh_token_repo.get_by_token_hash.return_value = revoked_token + refresh_token_repo.revoke_family.return_value = 3 # 3 tokens revoked + + service = make_auth_service(refresh_token_repo=refresh_token_repo) + + with pytest.raises(InvalidStateError) as exc_info: + await service.refresh_tokens("stolen-token") + + assert exc_info.value.code == "token_family_revoked" + + # Should revoke entire family + refresh_token_repo.revoke_family.assert_called_once_with(family_id) + + @pytest.mark.asyncio + async def test_refresh_tokens_rejects_expired_token(self): + """refresh_tokens should raise for expired refresh token.""" + expired_token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=UserId(uuid4()), + token_hash="expired-hash", + family_id=TokenFamilyId(uuid4()), + expires_at=datetime.now(UTC) - timedelta(hours=1), # Expired + created_at=datetime.now(UTC) - timedelta(days=8), + revoked_at=None, + ) + + refresh_token_repo = AsyncMock() + refresh_token_repo.get_by_token_hash.return_value = expired_token + + service = make_auth_service(refresh_token_repo=refresh_token_repo) + + with pytest.raises(InvalidStateError) as exc_info: + await service.refresh_tokens("expired-token") + + assert exc_info.value.code == "refresh_token_expired" + + +class TestAuthServiceLogout: + """Tests for AuthService.logout.""" + + @pytest.mark.asyncio + async def test_logout_revokes_token_family(self): + """logout should revoke the entire token family.""" + family_id = TokenFamilyId(uuid4()) + token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=UserId(uuid4()), + token_hash="token-hash", + family_id=family_id, + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC), + revoked_at=None, + ) + + refresh_token_repo = AsyncMock() + refresh_token_repo.get_by_token_hash.return_value = token + refresh_token_repo.revoke_family.return_value = 1 + + service = make_auth_service(refresh_token_repo=refresh_token_repo) + + result = await 
service.logout("raw-refresh-token") + + assert result is True + refresh_token_repo.revoke_family.assert_called_once_with(family_id) + + @pytest.mark.asyncio + async def test_logout_succeeds_for_unknown_token(self): + """logout should succeed even if token is not found.""" + refresh_token_repo = AsyncMock() + refresh_token_repo.get_by_token_hash.return_value = None + + service = make_auth_service(refresh_token_repo=refresh_token_repo) + + result = await service.logout("unknown-token") + + assert result is True + refresh_token_repo.revoke_family.assert_not_called() diff --git a/server/tests/unit/domain/auth/test_refresh_token.py b/server/tests/unit/domain/auth/test_refresh_token.py new file mode 100644 index 0000000..6179fbf --- /dev/null +++ b/server/tests/unit/domain/auth/test_refresh_token.py @@ -0,0 +1,227 @@ +"""Unit tests for RefreshToken entity.""" + +from datetime import UTC, datetime, timedelta +from uuid import uuid4 + + +from osa.domain.auth.model.token import RefreshToken +from osa.domain.auth.model.value import RefreshTokenId, TokenFamilyId, UserId + + +class TestRefreshTokenCreate: + """Tests for RefreshToken.create factory method.""" + + def test_create_sets_all_fields(self): + """create should set all required fields.""" + user_id = UserId(uuid4()) + token_hash = "a" * 64 + family_id = TokenFamilyId(uuid4()) + + token = RefreshToken.create( + user_id=user_id, + token_hash=token_hash, + family_id=family_id, + expires_in_days=7, + ) + + assert token.user_id == user_id + assert token.token_hash == token_hash + assert token.family_id == family_id + assert token.revoked_at is None + + def test_create_generates_id(self): + """create should generate a unique ID.""" + user_id = UserId(uuid4()) + + token1 = RefreshToken.create(user_id, "a" * 64, TokenFamilyId(uuid4())) + token2 = RefreshToken.create(user_id, "b" * 64, TokenFamilyId(uuid4())) + + assert token1.id != token2.id + + def test_create_sets_expiry_in_future(self): + """create should set expires_at in 
the future.""" + user_id = UserId(uuid4()) + now = datetime.now(UTC) + + token = RefreshToken.create( + user_id=user_id, + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + expires_in_days=7, + ) + + assert token.expires_at > now + # Should be roughly 7 days from now (allowing small margin) + expected = now + timedelta(days=7) + assert abs((token.expires_at - expected).total_seconds()) < 1 + + def test_create_sets_created_at(self): + """create should set created_at to current time.""" + now = datetime.now(UTC) + + token = RefreshToken.create( + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + ) + + assert abs((token.created_at - now).total_seconds()) < 1 + + +class TestRefreshTokenIsValid: + """Tests for RefreshToken.is_valid property.""" + + def make_token( + self, + expires_at: datetime | None = None, + revoked_at: datetime | None = None, + ) -> RefreshToken: + """Create a token with specified expiry and revocation.""" + if expires_at is None: + expires_at = datetime.now(UTC) + timedelta(days=7) + + return RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + expires_at=expires_at, + created_at=datetime.now(UTC), + revoked_at=revoked_at, + ) + + def test_is_valid_true_for_fresh_token(self): + """is_valid should be True for non-expired, non-revoked token.""" + token = self.make_token() + + assert token.is_valid is True + + def test_is_valid_false_when_expired(self): + """is_valid should be False when token is expired.""" + expired_at = datetime.now(UTC) - timedelta(hours=1) + token = self.make_token(expires_at=expired_at) + + assert token.is_valid is False + + def test_is_valid_false_when_revoked(self): + """is_valid should be False when token is revoked.""" + token = self.make_token(revoked_at=datetime.now(UTC)) + + assert token.is_valid is False + + def test_is_valid_false_when_both_expired_and_revoked(self): + """is_valid should be False 
when both expired and revoked.""" + token = self.make_token( + expires_at=datetime.now(UTC) - timedelta(hours=1), + revoked_at=datetime.now(UTC) - timedelta(hours=2), + ) + + assert token.is_valid is False + + +class TestRefreshTokenIsRevoked: + """Tests for RefreshToken.is_revoked property.""" + + def test_is_revoked_false_initially(self): + """is_revoked should be False when revoked_at is None.""" + token = RefreshToken.create( + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + ) + + assert token.is_revoked is False + + def test_is_revoked_true_when_set(self): + """is_revoked should be True when revoked_at is set.""" + token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC), + revoked_at=datetime.now(UTC), + ) + + assert token.is_revoked is True + + +class TestRefreshTokenIsExpired: + """Tests for RefreshToken.is_expired property.""" + + def test_is_expired_false_for_future_expiry(self): + """is_expired should be False when expires_at is in the future.""" + token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC), + revoked_at=None, + ) + + assert token.is_expired is False + + def test_is_expired_true_for_past_expiry(self): + """is_expired should be True when expires_at is in the past.""" + token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + expires_at=datetime.now(UTC) - timedelta(hours=1), + created_at=datetime.now(UTC) - timedelta(days=8), + revoked_at=None, + ) + + assert token.is_expired is True + + +class TestRefreshTokenRevoke: + """Tests for RefreshToken.revoke method.""" + + def 
test_revoke_sets_revoked_at(self): + """revoke should set revoked_at to current time.""" + token = RefreshToken.create( + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + ) + assert token.revoked_at is None + + now = datetime.now(UTC) + token.revoke() + + assert token.revoked_at is not None + assert abs((token.revoked_at - now).total_seconds()) < 1 + + def test_revoke_idempotent(self): + """revoke should not change revoked_at if already revoked.""" + token = RefreshToken.create( + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + ) + + token.revoke() + first_revoked_at = token.revoked_at + + # Second revoke should not change the timestamp + token.revoke() + + assert token.revoked_at == first_revoked_at + + def test_revoke_makes_token_invalid(self): + """revoke should make is_valid return False.""" + token = RefreshToken.create( + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + ) + assert token.is_valid is True + + token.revoke() + + assert token.is_valid is False diff --git a/server/tests/unit/domain/auth/test_token_service.py b/server/tests/unit/domain/auth/test_token_service.py new file mode 100644 index 0000000..587e008 --- /dev/null +++ b/server/tests/unit/domain/auth/test_token_service.py @@ -0,0 +1,243 @@ +"""Unit tests for TokenService JWT creation and validation.""" + +import time +from uuid import uuid4 + +import jwt +import pytest + +from osa.config import JwtConfig +from osa.domain.auth.model.value import UserId +from osa.domain.auth.service.token import TokenService + + +class TestTokenServiceAccessToken: + """Tests for JWT access token creation and validation.""" + + def make_service(self, secret: str = "test-secret-key-256-bits-long-xx") -> TokenService: + """Create a TokenService with test config.""" + config = JwtConfig( + secret=secret, + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=7, + ) + return 
TokenService(_config=config) + + def test_create_access_token_returns_valid_jwt(self): + """create_access_token should return a decodable JWT.""" + service = self.make_service() + user_id = UserId(uuid4()) + orcid_id = "0000-0001-2345-6789" + + token = service.create_access_token(user_id, orcid_id) + + # Should be decodable + payload = jwt.decode( + token, + "test-secret-key-256-bits-long-xx", + algorithms=["HS256"], + audience="authenticated", + ) + assert payload["sub"] == str(user_id) + assert payload["orcid_id"] == orcid_id + assert payload["aud"] == "authenticated" + + def test_create_access_token_includes_expiry(self): + """create_access_token should set exp claim.""" + service = self.make_service() + user_id = UserId(uuid4()) + + token = service.create_access_token(user_id, "0000-0001-2345-6789") + + payload = jwt.decode( + token, + "test-secret-key-256-bits-long-xx", + algorithms=["HS256"], + audience="authenticated", + ) + assert "exp" in payload + assert "iat" in payload + # Expiry should be ~60 minutes from now + assert payload["exp"] > payload["iat"] + assert payload["exp"] - payload["iat"] == 60 * 60 # 60 minutes in seconds + + def test_create_access_token_includes_jti(self): + """create_access_token should include unique jti claim.""" + service = self.make_service() + user_id = UserId(uuid4()) + + token1 = service.create_access_token(user_id, "0000-0001-2345-6789") + token2 = service.create_access_token(user_id, "0000-0001-2345-6789") + + payload1 = jwt.decode( + token1, + "test-secret-key-256-bits-long-xx", + algorithms=["HS256"], + audience="authenticated", + ) + payload2 = jwt.decode( + token2, + "test-secret-key-256-bits-long-xx", + algorithms=["HS256"], + audience="authenticated", + ) + + assert "jti" in payload1 + assert "jti" in payload2 + assert payload1["jti"] != payload2["jti"] + + def test_create_access_token_with_additional_claims(self): + """create_access_token should include additional claims if provided.""" + service = 
self.make_service() + user_id = UserId(uuid4()) + + token = service.create_access_token( + user_id, + "0000-0001-2345-6789", + additional_claims={"custom": "value"}, + ) + + payload = jwt.decode( + token, + "test-secret-key-256-bits-long-xx", + algorithms=["HS256"], + audience="authenticated", + ) + assert payload["custom"] == "value" + + def test_validate_access_token_returns_payload(self): + """validate_access_token should return decoded payload.""" + service = self.make_service() + user_id = UserId(uuid4()) + orcid_id = "0000-0001-2345-6789" + + token = service.create_access_token(user_id, orcid_id) + payload = service.validate_access_token(token) + + assert payload["sub"] == str(user_id) + assert payload["orcid_id"] == orcid_id + + def test_validate_access_token_rejects_invalid_token(self): + """validate_access_token should raise for invalid tokens.""" + service = self.make_service() + + with pytest.raises(jwt.InvalidTokenError): + service.validate_access_token("invalid-token") + + def test_validate_access_token_rejects_wrong_secret(self): + """validate_access_token should reject tokens signed with wrong secret.""" + service1 = self.make_service(secret="secret-one-that-is-long-enough") + service2 = self.make_service(secret="secret-two-that-is-long-enough") + + token = service1.create_access_token(UserId(uuid4()), "0000-0001-2345-6789") + + with pytest.raises(jwt.InvalidTokenError): + service2.validate_access_token(token) + + def test_validate_access_token_rejects_expired_token(self): + """validate_access_token should reject expired tokens.""" + config = JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + access_token_expire_minutes=0, # Immediate expiry + refresh_token_expire_days=7, + ) + service = TokenService(_config=config) + + # Create token that's already expired + token = service.create_access_token(UserId(uuid4()), "0000-0001-2345-6789") + + # Small delay to ensure expiry + time.sleep(0.1) + + with 
pytest.raises(jwt.ExpiredSignatureError): + service.validate_access_token(token) + + +class TestTokenServiceRefreshToken: + """Tests for opaque refresh token creation.""" + + def make_service(self) -> TokenService: + """Create a TokenService with test config.""" + config = JwtConfig( + secret="test-secret", + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=7, + ) + return TokenService(_config=config) + + def test_create_refresh_token_returns_tuple(self): + """create_refresh_token should return (raw_token, token_hash).""" + service = self.make_service() + + raw_token, token_hash = service.create_refresh_token() + + assert isinstance(raw_token, str) + assert isinstance(token_hash, str) + assert len(raw_token) > 0 + assert len(token_hash) == 64 # SHA256 hex = 64 chars + + def test_create_refresh_token_unique_each_time(self): + """create_refresh_token should generate unique tokens.""" + service = self.make_service() + + raw1, hash1 = service.create_refresh_token() + raw2, hash2 = service.create_refresh_token() + + assert raw1 != raw2 + assert hash1 != hash2 + + def test_hash_token_consistent(self): + """hash_token should produce consistent hashes.""" + raw_token = "test-token-value" + + hash1 = TokenService.hash_token(raw_token) + hash2 = TokenService.hash_token(raw_token) + + assert hash1 == hash2 + assert len(hash1) == 64 + + def test_hash_token_different_for_different_tokens(self): + """hash_token should produce different hashes for different tokens.""" + hash1 = TokenService.hash_token("token-one") + hash2 = TokenService.hash_token("token-two") + + assert hash1 != hash2 + + def test_created_refresh_token_hash_matches(self): + """The hash from create_refresh_token should match hash_token(raw).""" + service = self.make_service() + + raw_token, token_hash = service.create_refresh_token() + + assert TokenService.hash_token(raw_token) == token_hash + + +class TestTokenServiceProperties: + """Tests for TokenService property accessors.""" 
+ + def test_access_token_expire_seconds(self): + """access_token_expire_seconds should convert minutes to seconds.""" + config = JwtConfig( + secret="test", + algorithm="HS256", + access_token_expire_minutes=30, + refresh_token_expire_days=7, + ) + service = TokenService(_config=config) + + assert service.access_token_expire_seconds == 30 * 60 + + def test_refresh_token_expire_days(self): + """refresh_token_expire_days should return configured value.""" + config = JwtConfig( + secret="test", + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=14, + ) + service = TokenService(_config=config) + + assert service.refresh_token_expire_days == 14 diff --git a/server/uv.lock b/server/uv.lock index e888c64..cdf84bd 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -1588,6 +1588,7 @@ dependencies = [ { name = "psycopg2-binary", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pydantic", extra = ["email"], marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pydantic-settings", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pyjwt", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "rich", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "sentence-transformers", 
marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "sqlalchemy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -1623,6 +1624,7 @@ requires-dist = [ { name = "psycopg2-binary", specifier = ">=2.9.11" }, { name = "pydantic", extras = ["email"], specifier = ">=2.12.4" }, { name = "pydantic-settings", specifier = ">=2.12.0" }, + { name = "pyjwt", specifier = ">=2.11.0" }, { name = "rich", specifier = ">=14.2.0" }, { name = "sentence-transformers", specifier = ">=5.2.0" }, { name = "sqlalchemy", specifier = ">=2.0.44" }, @@ -1948,6 +1950,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyjwt" +version = "2.11.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/b46fa56bf322901eee5b0454a34343cdbdae202cd421775a8ee4e42fd519/pyjwt-2.11.0.tar.gz", hash = "sha256:35f95c1f0fbe5d5ba6e43f00271c275f7a1a4db1dab27bf708073b75318ea623", size = 98019, upload-time = "2026-01-30T19:59:55.694Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" }, +] + [[package]] name = "pypika" version = "0.50.0" diff --git a/web/src/app/auth/callback/page.tsx b/web/src/app/auth/callback/page.tsx new file 
mode 100644 index 0000000..0d1b05a --- /dev/null +++ b/web/src/app/auth/callback/page.tsx @@ -0,0 +1,72 @@ +'use client'; + +import { useEffect, useState } from 'react'; +import { useRouter } from 'next/navigation'; +import { OSAClient, parseAuthCallback } from '@/lib/sdk'; + +export default function AuthCallbackPage() { + const router = useRouter(); + const [error, setError] = useState(null); + + useEffect(() => { + // Check for error in URL search params + const searchParams = new URLSearchParams(window.location.search); + const urlError = searchParams.get('error'); + const errorDescription = searchParams.get('error_description'); + + if (urlError) { + setError(errorDescription || urlError); + return; + } + + // Check for auth data in hash + const hash = window.location.hash; + if (!hash || !hash.includes('auth=')) { + setError('No authentication data received'); + return; + } + + // Parse and store auth data + const params = parseAuthCallback(hash); + if (!params) { + setError('Failed to parse authentication data'); + return; + } + + // Store in client + const client = new OSAClient({ baseUrl: '/api/v1' }); + client.handleAuthCallback(hash); + + // Redirect to home (or wherever user came from) + router.push('/'); + }, [router]); + + if (error) { + return ( +
+

Authentication Error

+

{error}

+ + Return to Home + +
+ ); + } + + return ( +
+

Completing sign in...

+
+ ); +} diff --git a/web/src/app/layout.tsx b/web/src/app/layout.tsx index dc5ed1a..15f9a7d 100644 --- a/web/src/app/layout.tsx +++ b/web/src/app/layout.tsx @@ -1,6 +1,7 @@ import type { Metadata } from 'next'; import { Geist, Geist_Mono } from 'next/font/google'; import './globals.css'; +import { AuthProvider } from '@/components/auth/AuthProvider'; import { Header } from '@/components/layout/Header'; import { Footer } from '@/components/layout/Footer'; @@ -22,6 +23,9 @@ export const metadata: Metadata = { keywords: ['biology', 'genomics', 'GEO', 'research', 'scientific data', 'semantic search'], }; +// API URL: use env var for dev (different port), relative path for prod (reverse proxy) +const apiBaseUrl = process.env.NEXT_PUBLIC_API_URL || '/api/v1'; + export default function RootLayout({ children, }: Readonly<{ @@ -30,9 +34,11 @@ export default function RootLayout({ return ( -
- {children} -