diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cc180ea..414bc0d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,3 +35,20 @@ repos: types: [python] pass_filenames: false stages: [pre-commit] + + - id: web-lint + name: web lint + entry: pnpm --dir web lint + language: system + files: ^web/ + types_or: [javascript, jsx, ts, tsx, css] + pass_filenames: false + + - id: web-build + name: web build + entry: pnpm --dir web build + language: system + files: ^web/ + types_or: [javascript, jsx, ts, tsx, css, json] + pass_filenames: false + stages: [pre-commit] diff --git a/Justfile b/Justfile index 0add92f..d750323 100644 --- a/Justfile +++ b/Justfile @@ -20,30 +20,10 @@ up-attached: down: docker compose -f deploy/docker-compose.yml down -# View logs from all services -logs: - docker compose -f deploy/docker-compose.yml logs -f - -# View logs from a specific service -logs-service service: +# View logs for a service (e.g., just logs server, just logs db) +logs service: docker compose -f deploy/docker-compose.yml logs -f {{service}} -# View server logs -server-logs: - docker compose -f deploy/docker-compose.yml logs -f server - -# View last N lines of server logs (default: 100) -server-logs-tail lines="100": - docker compose -f deploy/docker-compose.yml logs --tail {{lines}} server - -# View server logs with timestamps -server-logs-time: - docker compose -f deploy/docker-compose.yml logs -f -t server - -# View server logs since a time (e.g., "10m", "1h", "2024-01-01") -server-logs-since since: - docker compose -f deploy/docker-compose.yml logs -f --since {{since}} server - # Shell into the server container server-shell: docker compose -f deploy/docker-compose.yml exec server bash @@ -106,10 +86,6 @@ db-up: db-down: docker compose -f deploy/docker-compose.yml stop db -# View database logs -db-logs: - docker compose -f deploy/docker-compose.yml logs -f db - # Connect to PostgreSQL db-connect: docker compose -f deploy/docker-compose.yml exec db psql -U postgres -d osa diff --git a/deploy/docker-compose.dev.yml b/deploy/docker-compose.dev.yml index 32fa427..94a9c2a 100644 --- a/deploy/docker-compose.dev.yml +++ b/deploy/docker-compose.dev.yml @@ -7,15 +7,19 @@ services: context: ../server dockerfile: Dockerfile target: builder + ports: + - "8000:8000" volumes: - ../server:/app - /app/.venv environment: OSA_DATABASE__URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-osa}@db:5432/${POSTGRES_DB:-osa} OSA_DATA_DIR: /data + OSA_CONFIG_FILE: /app/osa.yaml OSA_LOGGING__LEVEL: ${LOG_LEVEL:-DEBUG} WATCHFILES_FORCE_POLLING: "true" - command: uvicorn osa.application.api.rest.app:app --host 0.0.0.0 --port 8000 --reload + entrypoint: [] + command: ["sh", "-c", "/app/.venv/bin/alembic upgrade head && /app/.venv/bin/uvicorn osa.application.api.rest.app:app --host 0.0.0.0 --port 8000 --reload"] healthcheck: test: ["CMD", "curl", "--fail", "http://localhost:8000/api/v1/health"] interval: 10s @@ -36,4 +40,5 @@ services: API_URL: http://server:8000 ports: - "3000:3000" - command: pnpm dev + entrypoint: [] + command: ["pnpm", "dev"] diff --git a/server/.env.example b/server/.env.example new file mode 100644 index 0000000..35bfe0a --- /dev/null +++ b/server/.env.example @@ -0,0 +1,76 @@ +# OSA Server Configuration - Secrets and Environment-Specific Values +# Copy this file to .env and fill in your values +# +# Relationship with osa.yaml: +# - .env: SECRETS (credentials, JWT secrets) - never commit to git +# - osa.yaml: APPLICATION CONFIG (sources, indexes) - can be version-controlled +# +# Priority (highest to lowest): +# 1. Environment variables +# 2. .env file +# 3. OSA_CONFIG_FILE (osa.yaml) +# +# Auth secrets MUST go in .env (not osa.yaml) for security + +# ============================================================================= +# ORCiD OAuth Configuration +# ============================================================================= +# Register at: https://sandbox.orcid.org/developer-tools (sandbox) +# Register at: https://orcid.org/developer-tools (production) + +# ORCiD OAuth client ID (e.g., APP-XXXXXXXXXXXX) +OSA_AUTH__ORCID__CLIENT_ID= + +# ORCiD OAuth client secret +OSA_AUTH__ORCID__CLIENT_SECRET= + +# Use ORCiD sandbox for development (default: true) +# Set to false for production +OSA_AUTH__ORCID__SANDBOX=true + +# ============================================================================= +# JWT Configuration +# ============================================================================= + +# JWT signing secret - minimum 32 characters required +# Generate with: openssl rand -hex 32 +OSA_AUTH__JWT__SECRET= + +# Access token expiry in minutes (default: 60) +# OSA_AUTH__JWT__ACCESS_TOKEN_EXPIRE_MINUTES=60 + +# Refresh token expiry in days (default: 7) +# OSA_AUTH__JWT__REFRESH_TOKEN_EXPIRE_DAYS=7 + +# ============================================================================= +# OAuth Callback URL +# ============================================================================= +# Must match the redirect URI registered in ORCiD developer tools + +# For development (default derives from request URL) +# OSA_AUTH__CALLBACK_URL=http://localhost:8000/api/v1/auth/callback + +# For production +# OSA_AUTH__CALLBACK_URL=https://your-domain.com/api/v1/auth/callback + +# ============================================================================= +# Frontend URL +# ============================================================================= +# URL of the frontend application for OAuth redirects + +# For development +OSA_FRONTEND__URL=http://localhost:3000 + +# For production +# OSA_FRONTEND__URL=https://your-domain.com + +# ============================================================================= +# Database Configuration +# ============================================================================= +# Default: SQLite in ~/.local/share/osa/osa.db + +# PostgreSQL example: +# OSA_DATABASE__URL=postgresql+asyncpg://user:password@localhost:5432/osa + +# Echo SQL queries (for debugging) +# OSA_DATABASE__ECHO=false diff --git a/server/Dockerfile b/server/Dockerfile index 6b0d381..ddd05b9 100644 --- a/server/Dockerfile +++ b/server/Dockerfile @@ -3,6 +3,9 @@ # Stage 1: Builder FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS builder +# Install curl for healthchecks in dev mode +RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/* + WORKDIR /app # Enable bytecode compilation for faster startup diff --git a/server/migrations/versions/add_auth_tables.py b/server/migrations/versions/add_auth_tables.py new file mode 100644 index 0000000..b04ac0b --- /dev/null +++ b/server/migrations/versions/add_auth_tables.py @@ -0,0 +1,89 @@ +"""add_auth_tables + +Add users, identities, and refresh_tokens tables for authentication. + +Revision ID: add_auth_tables +Revises: add_worker_columns +Create Date: 2026-02-04 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "add_auth_tables" +down_revision: Union[str, Sequence[str], None] = "add_worker_columns" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add authentication tables.""" + # USERS TABLE + op.create_table( + "users", + sa.Column("id", sa.String(), nullable=False), + sa.Column("display_name", sa.String(255), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + + # IDENTITIES TABLE + op.create_table( + "identities", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("provider", sa.String(50), nullable=False), + sa.Column("external_id", sa.String(255), nullable=False), + sa.Column("metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ondelete="CASCADE", + ), + sa.UniqueConstraint("provider", "external_id", name="uq_identity_provider_external"), + ) + op.create_index("ix_identities_user_id", "identities", ["user_id"]) + + # REFRESH TOKENS TABLE + op.create_table( + "refresh_tokens", + sa.Column("id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.Column("token_hash", sa.String(64), nullable=False), + sa.Column("family_id", sa.String(), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("revoked_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ondelete="CASCADE", + ), + ) + op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"]) + op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"]) + op.create_index("ix_refresh_tokens_family_id", "refresh_tokens", ["family_id"]) + + +def downgrade() -> None: + """Remove authentication tables.""" + # REFRESH TOKENS + op.drop_index("ix_refresh_tokens_family_id", table_name="refresh_tokens") + op.drop_index("ix_refresh_tokens_token_hash", table_name="refresh_tokens") + op.drop_index("ix_refresh_tokens_user_id", table_name="refresh_tokens") + op.drop_table("refresh_tokens") + + # IDENTITIES + op.drop_index("ix_identities_user_id", table_name="identities") + op.drop_table("identities") + + # USERS + op.drop_table("users") diff --git a/server/osa/application/api/rest/app.py b/server/osa/application/api/rest/app.py index f997a6f..735e836 100644 --- a/server/osa/application/api/rest/app.py +++ b/server/osa/application/api/rest/app.py @@ -6,7 +6,7 @@ from fastapi.responses import JSONResponse from osa.application.api.v1.errors import map_osa_error -from osa.application.api.v1.routes import events, health, records, search, stats, validation +from osa.application.api.v1.routes import auth, events, health, records, search, stats, validation from osa.application.di import create_container from osa.config import Config, configure_logging from osa.domain.shared.error import OSAError @@ -32,7 +32,8 @@ async def lifespan(app: FastAPI): def create_app() -> FastAPI: """Create FastAPI application.""" - config = Config() + # Pydantic Settings populates from env vars at runtime + config = Config() # type: ignore[call-arg] # Configure logging early configure_logging(config.logging) @@ -58,6 +59,7 @@ def create_app() -> FastAPI: # Register v1 routes with /api/v1 prefix app_instance.include_router(health.router, prefix="/api/v1") + app_instance.include_router(auth.router, prefix="/api/v1") app_instance.include_router(events.router, prefix="/api/v1") app_instance.include_router(records.router, prefix="/api/v1") app_instance.include_router(search.router, prefix="/api/v1") diff --git a/server/osa/application/api/v1/routes/auth.py b/server/osa/application/api/v1/routes/auth.py new file mode 100644 index 0000000..de22b0d --- /dev/null +++ b/server/osa/application/api/v1/routes/auth.py @@ -0,0 +1,277 @@ +"""Authentication routes for OAuth login flow.""" + +import logging +from typing import Annotated +from urllib.parse import urlencode + +from dishka import FromDishka +from dishka.integrations.fastapi import DishkaRoute +from fastapi import APIRouter, HTTPException, Query, Request, Response +from fastapi.responses import RedirectResponse +from pydantic import BaseModel + +from osa.config import Config +from osa.domain.auth.command.login import ( + CompleteOAuth, + CompleteOAuthHandler, + InitiateLogin, + InitiateLoginHandler, +) +from osa.domain.auth.command.token import ( + Logout, + LogoutHandler, + RefreshTokens, + RefreshTokensHandler, +) +from osa.domain.auth.model.value import CurrentUser +from osa.domain.auth.port.provider_registry import ProviderRegistry +from osa.domain.auth.service.auth import AuthService +from osa.domain.auth.service.token import TokenService +from osa.domain.shared.error import InvalidStateError + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/auth", tags=["Authentication"], route_class=DishkaRoute) + + +class RefreshTokenRequest(BaseModel): + """Request body for token refresh.""" + + refresh_token: str + + +class LogoutRequest(BaseModel): + """Request body for logout.""" + + refresh_token: str + + +class TokenResponse(BaseModel): + """Response containing tokens.""" + + access_token: str + refresh_token: str + token_type: str = "Bearer" + expires_in: int + + +class LogoutResponse(BaseModel): + """Response for logout.""" + + success: bool + + +class UserResponse(BaseModel): + """Response containing user info.""" + + id: str + display_name: str | None + provider: str + external_id: str + + +@router.get("/login") +async def initiate_login( + request: Request, + config: FromDishka[Config], + handler: FromDishka[InitiateLoginHandler], + registry: FromDishka[ProviderRegistry], + provider: Annotated[str, Query()], + redirect_uri: Annotated[str | None, Query()] = None, +) -> Response: + """Initiate OAuth login flow. + + Redirects to identity provider's authorization page. + """ + # Validate provider is configured + if not registry.is_available(provider): + available = registry.available_providers() + raise HTTPException( + status_code=400, + detail={ + "code": "unknown_provider", + "message": f"Unknown provider: {provider}. Available: {', '.join(available) or 'none'}", + }, + ) + + # Determine callback URL + callback_url = config.auth.callback_url + if not callback_url: + callback_url = str(request.url_for("handle_oauth_callback")) + + # Determine final redirect URI + final_redirect = redirect_uri or config.frontend.url + + result = await handler.run( + InitiateLogin( + callback_url=callback_url, + final_redirect_uri=final_redirect, + provider=provider, + ) + ) + + logger.info("OAuth login initiated for provider=%s, redirecting to IdP", provider) + return RedirectResponse(url=result.authorization_url, status_code=302) + + +@router.get("/callback") +async def handle_oauth_callback( + request: Request, + config: FromDishka[Config], + handler: FromDishka[CompleteOAuthHandler], + token_service: FromDishka[TokenService], + code: Annotated[str | None, Query()] = None, + state: Annotated[str | None, Query()] = None, + error: Annotated[str | None, Query()] = None, + error_description: Annotated[str | None, Query()] = None, +) -> Response: + """Handle OAuth callback from identity provider. + + Exchanges authorization code for tokens and redirects to frontend. + """ + frontend_url = config.frontend.url + + # Check for OAuth errors + if error: + logger.warning("OAuth error: %s - %s", error, error_description) + error_params = urlencode( + { + "error": error, + "error_description": error_description or "Authentication failed", + } + ) + return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") + + # Validate signed state token + if not state: + logger.warning("OAuth state missing") + error_params = urlencode( + { + "error": "oauth_state_missing", + "error_description": "Missing state parameter", + } + ) + return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") + + state_data = token_service.verify_oauth_state(state) + if state_data is None: + logger.warning("OAuth state invalid or expired") + error_params = urlencode( + { + "error": "oauth_state_invalid", + "error_description": "Invalid or expired state parameter", + } + ) + return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") + + final_redirect, provider = state_data + + if not code: + logger.warning("OAuth callback missing code") + error_params = urlencode( + { + "error": "missing_code", + "error_description": "Authorization code not provided", + } + ) + return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") + + try: + # Determine callback URL (must match what was used in authorization) + callback_url = config.auth.callback_url + if not callback_url: + callback_url = str(request.url_for("handle_oauth_callback")) + + # Complete OAuth flow via handler + result = await handler.run( + CompleteOAuth( + code=code, + callback_url=callback_url, + provider=provider, + ) + ) + + # Build redirect URL with tokens in fragment + token_params = urlencode( + { + "access_token": result.access_token, + "refresh_token": result.refresh_token, + "token_type": "Bearer", + "expires_in": result.expires_in, + "user_id": result.user_id, + "display_name": result.display_name or "", + "provider": result.provider, + "external_id": result.external_id, + } + ) + + redirect_url = f"{final_redirect}#auth={token_params}" + logger.info( + "OAuth complete, user authenticated: user_id=%s, provider=%s", result.user_id, provider + ) + return RedirectResponse(url=redirect_url, status_code=302) + + except Exception as e: + logger.exception("OAuth callback failed: %s", e) + error_params = urlencode( + { + "error": "oauth_error", + "error_description": "Authentication failed. Please try again.", + } + ) + return RedirectResponse(url=f"{frontend_url}/auth/error?{error_params}") + + +@router.post("/refresh", response_model=TokenResponse) +async def refresh_token( + body: RefreshTokenRequest, + handler: FromDishka[RefreshTokensHandler], +) -> TokenResponse: + """Refresh access token using refresh token.""" + try: + result = await handler.run(RefreshTokens(refresh_token=body.refresh_token)) + return TokenResponse( + access_token=result.access_token, + refresh_token=result.refresh_token, + expires_in=result.expires_in, + ) + except InvalidStateError as e: + raise HTTPException( + status_code=401, + detail={ + "code": e.code, + "message": e.message, + }, + ) from e + + +@router.post("/logout", response_model=LogoutResponse) +async def logout( + body: LogoutRequest, + handler: FromDishka[LogoutHandler], +) -> LogoutResponse: + """Logout and revoke refresh token.""" + result = await handler.run(Logout(refresh_token=body.refresh_token)) + return LogoutResponse(success=result.success) + + +@router.get("/me", response_model=UserResponse) +async def get_me( + current_user: FromDishka[CurrentUser], + auth_service: FromDishka[AuthService], +) -> UserResponse: + """Get current authenticated user information.""" + user = await auth_service.get_user_by_id(current_user.user_id) + + if user is None: + raise HTTPException( + status_code=401, + detail={"code": "user_not_found", "message": "User not found"}, + ) + + return UserResponse( + id=str(user.id), + display_name=user.display_name, + provider=current_user.identity.provider, + external_id=current_user.identity.external_id, + ) diff --git a/server/osa/application/di.py b/server/osa/application/di.py index 9f70b33..a7b0c7c 100644 --- a/server/osa/application/di.py +++ b/server/osa/application/di.py @@ -2,8 +2,10 @@ from osa.cli.util.paths import OSAPaths from osa.config import Config +from osa.domain.auth.util.di import AuthProvider from osa.domain.deposition.util.di import DepositionProvider from osa.domain.validation.util.di import ValidationProvider +from osa.infrastructure.auth import AuthInfraProvider from osa.infrastructure.event.di import EventProvider from osa.infrastructure.index.di import IndexProvider from osa.infrastructure.source.di import SourceProvider @@ -13,7 +15,8 @@ def create_container() -> AsyncContainer: - config = Config() + # Pydantic Settings populates from env vars at runtime + config = Config() # type: ignore[call-arg] # OSAPaths reads OSA_DATA_DIR from environment automatically paths = OSAPaths() @@ -26,6 +29,8 @@ def create_container() -> AsyncContainer: EventProvider(), DepositionProvider(), ValidationProvider(), + AuthProvider(), + AuthInfraProvider(), context={Config: config, OSAPaths: paths}, scopes=Scope, # type: ignore[arg-type] # Custom scope class ) diff --git a/server/osa/cli/util/daemon.py b/server/osa/cli/util/daemon.py index e7d99e2..1ff80fe 100644 --- a/server/osa/cli/util/daemon.py +++ b/server/osa/cli/util/daemon.py @@ -123,7 +123,8 @@ def start( os.environ["OSA_CONFIG_FILE"] = config_file try: - app_config = Config() + # Pydantic Settings populates from env vars at runtime + app_config = Config() # type: ignore[call-arg] except ValidationError as e: details = [] for err in e.errors(): diff --git a/server/osa/config.py b/server/osa/config.py index feebefd..feab7ea 100644 --- a/server/osa/config.py +++ b/server/osa/config.py @@ -150,6 +150,51 @@ class WorkerConfig(BaseModel): batch_size: int = 100 # Maximum events to fetch per poll cycle +# ============================================================================= +# Authentication Configuration +# ============================================================================= + + +class OrcidConfig(BaseModel): + """ORCiD OAuth configuration.""" + + client_id: str = "" + client_secret: str = "" + sandbox: bool = True # Use sandbox.orcid.org by default + + @property + def base_url(self) -> str: + """Get base URL for ORCiD API based on sandbox setting.""" + return "https://sandbox.orcid.org" if self.sandbox else "https://orcid.org" + + +class JwtConfig(BaseModel): + """JWT configuration.""" + + secret: str # Required - set via OSA_AUTH__JWT__SECRET + algorithm: str = "HS256" + access_token_expire_minutes: int = 60 # 1 hour + refresh_token_expire_days: int = 7 + + @model_validator(mode="after") + def validate_secret_length(self) -> Self: + """Ensure JWT secret has sufficient length.""" + if len(self.secret) < 32: + raise ValueError( + "JWT secret must be at least 32 characters for security. " + "Generate with: openssl rand -hex 32" + ) + return self + + +class AuthConfig(BaseModel): + """Authentication configuration.""" + + orcid: OrcidConfig = OrcidConfig() + jwt: JwtConfig # Required - no default, must be configured via env vars + callback_url: str = "" # Full callback URL (e.g., https://myarchive.org/api/v1/auth/callback) + + class Config(BaseSettings): # These are BaseModel, so env_nested_delimiter handles their env vars server: Server = Server() @@ -157,6 +202,7 @@ class Config(BaseSettings): database: DatabaseConfig = DatabaseConfig() logging: LoggingConfig = LoggingConfig() worker: WorkerConfig = WorkerConfig() # Background worker settings + auth: AuthConfig # Required - set via OSA_AUTH__JWT__SECRET env var indexes: list[IndexConfig] = [] # list of index configs sources: list[SourceConfig] = [] # list of source configs diff --git a/server/osa/domain/auth/command/__init__.py b/server/osa/domain/auth/command/__init__.py index e69de29..ddfc4f4 100644 --- a/server/osa/domain/auth/command/__init__.py +++ b/server/osa/domain/auth/command/__init__.py @@ -0,0 +1,33 @@ +"""Auth domain commands.""" + +from .login import ( + CompleteOAuth, + CompleteOAuthHandler, + CompleteOAuthResult, + InitiateLogin, + InitiateLoginHandler, + InitiateLoginResult, +) +from .token import ( + Logout, + LogoutHandler, + LogoutResult, + RefreshTokens, + RefreshTokensHandler, + RefreshTokensResult, +) + +__all__ = [ + "CompleteOAuth", + "CompleteOAuthHandler", + "CompleteOAuthResult", + "InitiateLogin", + "InitiateLoginHandler", + "InitiateLoginResult", + "Logout", + "LogoutHandler", + "LogoutResult", + "RefreshTokens", + "RefreshTokensHandler", + "RefreshTokensResult", +] diff --git a/server/osa/domain/auth/command/login.py b/server/osa/domain/auth/command/login.py new file mode 100644 index 0000000..a54d423 --- /dev/null +++ b/server/osa/domain/auth/command/login.py @@ -0,0 +1,122 @@ +"""Login commands for OAuth authentication flow.""" + +from dataclasses import dataclass +from uuid import uuid4 + +from osa.domain.auth.event import UserAuthenticated +from osa.domain.auth.port.provider_registry import ProviderRegistry +from osa.domain.auth.service.auth import AuthService +from osa.domain.auth.service.token import TokenService +from osa.domain.shared.command import Command, CommandHandler, Result +from osa.domain.shared.error import NotFoundError +from osa.domain.shared.event import EventId +from osa.domain.shared.outbox import Outbox + + +class InitiateLogin(Command): + """Command to start OAuth login flow.""" + + callback_url: str # OAuth callback URL (where IdP redirects after auth) + final_redirect_uri: str # Where to redirect user after OAuth completes + provider: str + + +class InitiateLoginResult(Result): + """Result containing authorization URL.""" + + authorization_url: str + + +@dataclass +class InitiateLoginHandler(CommandHandler[InitiateLogin, InitiateLoginResult]): + """Handler for InitiateLogin command.""" + + provider_registry: ProviderRegistry + token_service: TokenService + + async def run(self, cmd: InitiateLogin) -> InitiateLoginResult: + """Generate authorization URL for OAuth login.""" + # Look up the identity provider + identity_provider = self.provider_registry.get(cmd.provider) + if identity_provider is None: + raise NotFoundError( + f"Unknown identity provider: {cmd.provider}", + code="unknown_provider", + ) + + # Create signed state token (includes redirect_uri, provider, expiry, and nonce) + state = self.token_service.create_oauth_state(cmd.final_redirect_uri, cmd.provider) + + # Get authorization URL from identity provider + authorization_url = identity_provider.get_authorization_url( + state=state, + redirect_uri=cmd.callback_url, + ) + + return InitiateLoginResult(authorization_url=authorization_url) + + +class CompleteOAuth(Command): + """Command to complete OAuth flow with authorization code.""" + + code: str + callback_url: str # Must match the one used in authorization + provider: str # The identity provider name (from verified state) + + +class CompleteOAuthResult(Result): + """Result containing user info and tokens.""" + + user_id: str + display_name: str | None + provider: str + external_id: str + access_token: str + refresh_token: str + expires_in: int # Seconds until access token expires + + +@dataclass +class CompleteOAuthHandler(CommandHandler[CompleteOAuth, CompleteOAuthResult]): + """Handler for CompleteOAuth command.""" + + auth_service: AuthService + provider_registry: ProviderRegistry + token_service: TokenService + outbox: Outbox + + async def run(self, cmd: CompleteOAuth) -> CompleteOAuthResult: + """Exchange authorization code for tokens and create/update user.""" + # Look up the identity provider + identity_provider = self.provider_registry.get(cmd.provider) + if identity_provider is None: + raise NotFoundError( + f"Unknown identity provider: {cmd.provider}", + code="unknown_provider", + ) + + user, identity, access_token, refresh_token = await self.auth_service.complete_oauth( + provider=identity_provider, + code=cmd.code, + redirect_uri=cmd.callback_url, + ) + + # Emit UserAuthenticated event + await self.outbox.append( + UserAuthenticated( + id=EventId(uuid4()), + user_id=str(user.id), + provider=identity.provider, + external_id=identity.external_id, + ) + ) + + return CompleteOAuthResult( + user_id=str(user.id), + display_name=user.display_name, + provider=identity.provider, + external_id=identity.external_id, + access_token=access_token, + refresh_token=refresh_token, + expires_in=self.token_service.access_token_expire_seconds, + ) diff --git a/server/osa/domain/auth/command/token.py b/server/osa/domain/auth/command/token.py new file mode 100644 index 0000000..876c800 --- /dev/null +++ b/server/osa/domain/auth/command/token.py @@ -0,0 +1,84 @@ +"""Token commands for refresh and logout operations.""" + +from dataclasses import dataclass +from uuid import uuid4 + +from osa.domain.auth.event import UserLoggedOut +from osa.domain.auth.service.auth import AuthService +from osa.domain.auth.service.token import TokenService +from osa.domain.shared.command import Command, CommandHandler, Result +from osa.domain.shared.event import EventId +from osa.domain.shared.outbox import Outbox + + +class RefreshTokens(Command): + """Command to refresh access token using refresh token.""" + + refresh_token: str + + +class RefreshTokensResult(Result): + """Result containing new tokens.""" + + access_token: str + refresh_token: str + expires_in: int + + +@dataclass +class RefreshTokensHandler(CommandHandler[RefreshTokens, RefreshTokensResult]): + """Handler for RefreshTokens command.""" + + auth_service: AuthService + token_service: TokenService + + async def run(self, cmd: RefreshTokens) -> RefreshTokensResult: + """Refresh tokens using refresh token rotation.""" + _user, access_token, new_refresh_token = await self.auth_service.refresh_tokens( + cmd.refresh_token + ) + + return RefreshTokensResult( + access_token=access_token, + refresh_token=new_refresh_token, + expires_in=self.token_service.access_token_expire_seconds, + ) + + +class Logout(Command): + """Command to logout and revoke refresh token family.""" + + refresh_token: str + + +class LogoutResult(Result): + """Result for logout operation.""" + + success: bool + + +@dataclass +class LogoutHandler(CommandHandler[Logout, LogoutResult]): + """Handler for Logout command.""" + + auth_service: AuthService + outbox: Outbox + + async def run(self, cmd: Logout) -> LogoutResult: + """Logout by revoking refresh token family.""" + # Get user_id before revoking (for event emission) + user_id = await self.auth_service.get_user_id_from_refresh_token(cmd.refresh_token) + + # Revoke tokens + success = await self.auth_service.logout(cmd.refresh_token) + + # Emit UserLoggedOut event if we had a valid user + if user_id is not None: + await self.outbox.append( + UserLoggedOut( + id=EventId(uuid4()), + user_id=str(user_id), + ) + ) + + return LogoutResult(success=success) diff --git a/server/osa/domain/auth/event/__init__.py b/server/osa/domain/auth/event/__init__.py index e69de29..b5e31c1 100644 --- a/server/osa/domain/auth/event/__init__.py +++ b/server/osa/domain/auth/event/__init__.py @@ -0,0 +1,5 @@ +"""Auth domain events.""" + +from .events import UserAuthenticated, UserLoggedOut + +__all__ = ["UserAuthenticated", "UserLoggedOut"] diff --git a/server/osa/domain/auth/event/events.py b/server/osa/domain/auth/event/events.py new file mode 100644 index 0000000..63504fe --- /dev/null +++ b/server/osa/domain/auth/event/events.py @@ -0,0 +1,19 @@ +"""Domain events for the auth domain.""" + +from osa.domain.shared.event import Event, EventId + + +class UserAuthenticated(Event): + """Emitted when a user successfully authenticates.""" + + id: EventId + user_id: str + provider: str + external_id: str + + +class UserLoggedOut(Event): + """Emitted when a user logs out.""" + + id: EventId + user_id: str diff --git a/server/osa/domain/auth/model/__init__.py b/server/osa/domain/auth/model/__init__.py index e69de29..e59c5bf 100644 --- a/server/osa/domain/auth/model/__init__.py +++ b/server/osa/domain/auth/model/__init__.py @@ -0,0 +1,17 @@ +"""Auth domain models.""" + +from .identity import Identity +from .token import RefreshToken +from .user import User +from .value import IdentityId, OrcidId, RefreshTokenId, TokenFamilyId, UserId + +__all__ = [ + "Identity", + "IdentityId", + "OrcidId", + "RefreshToken", + "RefreshTokenId", + "TokenFamilyId", + "User", + "UserId", +] diff --git a/server/osa/domain/auth/model/identity.py b/server/osa/domain/auth/model/identity.py new file mode 100644 index 0000000..f86090b --- /dev/null +++ b/server/osa/domain/auth/model/identity.py @@ -0,0 +1,46 @@ +"""Identity entity for the auth domain.""" + +from datetime import UTC, datetime +from typing import Any + +from osa.domain.auth.model.value import IdentityId, UserId +from osa.domain.shared.model.entity import Entity + + +class Identity(Entity): + """A link between a User and an external identity provider. + + Examples: + - ORCiD: provider="orcid", external_id="0000-0001-2345-6789" + - SAML: provider="saml:university.edu", external_id="jdoe@university.edu" + + Invariants: + - `(provider, external_id)` is globally unique + - `user_id` is immutable after creation + - `provider` and `external_id` are immutable after creation + """ + + id: IdentityId + user_id: UserId + provider: str + external_id: str + metadata: dict[str, Any] | None = None # Provider-specific data (name, email) + created_at: datetime + + @classmethod + def create( + cls, + user_id: UserId, + provider: str, + external_id: str, + metadata: dict[str, Any] | None = None, + ) -> "Identity": + """Create a new identity link.""" + return cls( + id=IdentityId.generate(), + user_id=user_id, + provider=provider, + external_id=external_id, + metadata=metadata, + created_at=datetime.now(UTC), + ) diff --git a/server/osa/domain/auth/model/token.py b/server/osa/domain/auth/model/token.py new file mode 100644 index 0000000..409214c --- /dev/null +++ b/server/osa/domain/auth/model/token.py @@ -0,0 +1,68 @@ +"""RefreshToken entity for the auth domain.""" + +from datetime import UTC, datetime, timedelta + +from osa.domain.auth.model.value import RefreshTokenId, TokenFamilyId, UserId +from osa.domain.shared.model.entity import Entity + + +class RefreshToken(Entity): + """An opaque refresh token for session management. + + Tokens belong to a "family" for theft detection. When a token is refreshed, + the new token inherits the family_id. If a revoked token is reused, the + entire family is revoked (indicating potential theft). + + Invariants: + - `token_hash` is a SHA256 hash (64 hex characters) + - `expires_at` is always in the future at creation time + - Once `revoked_at` is set, it cannot be unset + """ + + id: RefreshTokenId + user_id: UserId + token_hash: str # SHA256 hash of the actual token value + family_id: TokenFamilyId + expires_at: datetime + created_at: datetime + revoked_at: datetime | None = None + + @property + def is_valid(self) -> bool: + """Token is valid if not revoked and not expired.""" + return self.revoked_at is None and self.expires_at > datetime.now(UTC) + + @property + def is_revoked(self) -> bool: + """Check if the token has been revoked.""" + return self.revoked_at is not None + + @property + def is_expired(self) -> bool: + """Check if the token has expired.""" + return self.expires_at <= datetime.now(UTC) + + def revoke(self) -> None: + """Mark this token as revoked.""" + if self.revoked_at is None: + self.revoked_at = datetime.now(UTC) + + @classmethod + def create( + cls, + user_id: UserId, + token_hash: str, + family_id: TokenFamilyId, + expires_in_days: int = 7, + ) -> "RefreshToken": + """Create a new refresh token.""" + now = datetime.now(UTC) + return cls( + id=RefreshTokenId.generate(), + user_id=user_id, + token_hash=token_hash, + family_id=family_id, + expires_at=now + timedelta(days=expires_in_days), + created_at=now, + revoked_at=None, + ) diff --git a/server/osa/domain/auth/model/user.py b/server/osa/domain/auth/model/user.py new file mode 100644 index 0000000..95c8ca0 --- /dev/null +++ b/server/osa/domain/auth/model/user.py @@ -0,0 +1,40 @@ +"""User aggregate for the auth domain.""" + +from datetime import UTC, datetime + +from osa.domain.auth.model.value import UserId +from osa.domain.shared.model.aggregate import Aggregate + + +class User(Aggregate): + """An authenticated user in the OSA system. + + Users are created on first authentication via any identity provider. + A user may have multiple linked identities (e.g., ORCiD + institutional SAML). + + Invariants: + - `id` is immutable after creation + - `created_at` is immutable after creation + - `updated_at` is set on any modification + """ + + id: UserId + display_name: str | None + created_at: datetime + updated_at: datetime | None = None + + @classmethod + def create(cls, display_name: str | None = None) -> "User": + """Create a new user.""" + now = datetime.now(UTC) + return cls( + id=UserId.generate(), + display_name=display_name, + created_at=now, + updated_at=None, + ) + + def update_display_name(self, display_name: str | None) -> None: + """Update the user's display name.""" + self.display_name = display_name + self.updated_at = datetime.now(UTC) diff --git a/server/osa/domain/auth/model/value.py b/server/osa/domain/auth/model/value.py new file mode 100644 index 0000000..5c01782 --- /dev/null +++ b/server/osa/domain/auth/model/value.py @@ -0,0 +1,111 @@ +"""Value objects for the auth domain.""" + +import re +from dataclasses import dataclass +from uuid import UUID, uuid4 + +from pydantic import RootModel, field_validator + + +class UserId(RootModel[UUID]): + """Unique identifier for a User.""" + + @classmethod + def generate(cls) -> "UserId": + return cls(uuid4()) + + def __str__(self) -> str: + return str(self.root) + + def __hash__(self) -> int: + return hash(self.root) + + +class IdentityId(RootModel[UUID]): + """Unique identifier for an Identity.""" + + @classmethod + def generate(cls) -> "IdentityId": + return cls(uuid4()) + + def __str__(self) -> str: + return str(self.root) + + def __hash__(self) -> int: + return hash(self.root) + + +class RefreshTokenId(RootModel[UUID]): + """Unique identifier for a RefreshToken.""" + + @classmethod + def generate(cls) -> "RefreshTokenId": + return cls(uuid4()) + + def __str__(self) -> str: + return str(self.root) + + def __hash__(self) -> int: + return hash(self.root) + + +class TokenFamilyId(RootModel[UUID]): + """Identifier for a token family. + + All refresh tokens from a single login session share a family_id. + Used for theft detection: if a revoked token is reused, the entire + family is invalidated. + """ + + @classmethod + def generate(cls) -> "TokenFamilyId": + return cls(uuid4()) + + def __str__(self) -> str: + return str(self.root) + + def __hash__(self) -> int: + return hash(self.root) + + +ORCID_PATTERN = re.compile(r"^\d{4}-\d{4}-\d{4}-\d{3}[\dX]$") + + +@dataclass(frozen=True) +class ProviderIdentity: + """An external identity from an identity provider. + + Encapsulates provider + external_id together since they're always used as a pair. + """ + + provider: str # e.g., "orcid", "google" + external_id: str # Provider-specific user ID + + +@dataclass(frozen=True) +class CurrentUser: + """Authenticated user context extracted from JWT token.""" + + user_id: "UserId" + identity: ProviderIdentity + + +class OrcidId(RootModel[str]): + """An ORCiD identifier (e.g., 0000-0001-2345-6789). + + ORCiD IDs are 16-digit numbers displayed as four groups of four, + with a checksum character (digit or X) at the end. + """ + + @field_validator("root") + @classmethod + def validate_orcid_format(cls, v: str) -> str: + if not ORCID_PATTERN.match(v): + raise ValueError(f"Invalid ORCiD format: {v}") + return v + + def __str__(self) -> str: + return self.root + + def __hash__(self) -> int: + return hash(self.root) diff --git a/server/osa/domain/auth/port/__init__.py b/server/osa/domain/auth/port/__init__.py index e69de29..9f67041 100644 --- a/server/osa/domain/auth/port/__init__.py +++ b/server/osa/domain/auth/port/__init__.py @@ -0,0 +1,12 @@ +"""Auth domain ports.""" + +from .identity_provider import IdentityInfo, IdentityProvider +from .repository import IdentityRepository, RefreshTokenRepository, UserRepository + +__all__ = [ + "IdentityInfo", + "IdentityProvider", + "IdentityRepository", + "RefreshTokenRepository", + "UserRepository", +] diff --git a/server/osa/domain/auth/port/identity_provider.py b/server/osa/domain/auth/port/identity_provider.py new file mode 100644 index 0000000..24c7305 --- /dev/null +++ b/server/osa/domain/auth/port/identity_provider.py @@ -0,0 +1,64 @@ +"""Identity provider port for the auth domain.""" + +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, Protocol + +from osa.domain.shared.port import Port + + +@dataclass(frozen=True) +class IdentityInfo: + """Information returned by an identity provider after successful auth.""" + + provider: str # e.g., "orcid", "google", "saml" + external_id: str # Provider-specific user ID + display_name: str | None + email: str | None # May not be available from all providers + raw_data: dict[str, Any] # Full provider response for extensibility + + +class IdentityProvider(Port, Protocol): + """Port for external identity provider integrations. + + Implementations are adapters in infrastructure/ (e.g., OrcidIdentityProvider). + """ + + @property + @abstractmethod + def provider_name(self) -> str: + """Unique identifier for this provider (e.g., 'orcid').""" + ... + + @abstractmethod + def get_authorization_url(self, state: str, redirect_uri: str) -> str: + """Generate URL to redirect user for authentication. + + Args: + state: CSRF protection token (random, stored in session) + redirect_uri: Where the IdP should redirect after auth + + Returns: + Full URL to redirect the user to + """ + ... + + @abstractmethod + async def exchange_code( + self, + code: str, + redirect_uri: str, + ) -> IdentityInfo: + """Exchange authorization code for identity information. + + Args: + code: Authorization code from IdP callback + redirect_uri: Must match the redirect_uri used in authorization URL + + Returns: + IdentityInfo with user details from the provider + + Raises: + ExternalServiceError: If the IdP request fails + """ + ... diff --git a/server/osa/domain/auth/port/provider_registry.py b/server/osa/domain/auth/port/provider_registry.py new file mode 100644 index 0000000..4b718c4 --- /dev/null +++ b/server/osa/domain/auth/port/provider_registry.py @@ -0,0 +1,47 @@ +"""Provider registry port for the auth domain.""" + +from abc import abstractmethod +from typing import Protocol + +from osa.domain.auth.port.identity_provider import IdentityProvider +from osa.domain.shared.port import Port + + +class ProviderRegistry(Port, Protocol): + """Registry of available identity providers. + + Allows looking up identity providers by name and checking + which providers are configured/available. + """ + + @abstractmethod + def get(self, provider: str) -> IdentityProvider | None: + """Get an identity provider by name. + + Args: + provider: The provider name (e.g., "orcid", "google") + + Returns: + The identity provider if available, None otherwise + """ + ... + + @abstractmethod + def available_providers(self) -> list[str]: + """Get list of available provider names. + + Returns: + List of provider names that can be used for authentication + """ + ... + + def is_available(self, provider: str) -> bool: + """Check if a provider is available. + + Args: + provider: The provider name to check + + Returns: + True if the provider is available + """ + return provider in self.available_providers() diff --git a/server/osa/domain/auth/port/repository.py b/server/osa/domain/auth/port/repository.py new file mode 100644 index 0000000..a2ca4a0 --- /dev/null +++ b/server/osa/domain/auth/port/repository.py @@ -0,0 +1,88 @@ +"""Repository ports for the auth domain.""" + +from abc import abstractmethod +from typing import Protocol + +from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.token import RefreshToken +from osa.domain.auth.model.user import User +from osa.domain.auth.model.value import ( + IdentityId, + RefreshTokenId, + TokenFamilyId, + UserId, +) +from osa.domain.shared.port import Port + + +class UserRepository(Port, Protocol): + """Repository for User aggregate persistence.""" + + @abstractmethod + async def get(self, user_id: UserId) -> User | None: + """Get a user by ID.""" + ... + + @abstractmethod + async def save(self, user: User) -> None: + """Save a user (create or update).""" + ... + + +class IdentityRepository(Port, Protocol): + """Repository for Identity entity persistence.""" + + @abstractmethod + async def get(self, identity_id: IdentityId) -> Identity | None: + """Get an identity by ID.""" + ... + + @abstractmethod + async def get_by_provider_and_external_id( + self, provider: str, external_id: str + ) -> Identity | None: + """Get an identity by provider and external ID.""" + ... + + @abstractmethod + async def get_by_user_id(self, user_id: UserId) -> list[Identity]: + """Get all identities for a user.""" + ... + + @abstractmethod + async def save(self, identity: Identity) -> None: + """Save an identity.""" + ... + + +class RefreshTokenRepository(Port, Protocol): + """Repository for RefreshToken entity persistence.""" + + @abstractmethod + async def get(self, token_id: RefreshTokenId) -> RefreshToken | None: + """Get a refresh token by ID.""" + ... + + @abstractmethod + async def get_by_token_hash( + self, token_hash: str, *, for_update: bool = False + ) -> RefreshToken | None: + """Get a refresh token by its hash. + + Args: + token_hash: The hash of the token to find. + for_update: If True, acquire a row-level lock (SELECT FOR UPDATE) + to prevent concurrent modifications. Use this when the token + will be modified after retrieval (e.g., during refresh). + """ + ... + + @abstractmethod + async def save(self, token: RefreshToken) -> None: + """Save a refresh token.""" + ... + + @abstractmethod + async def revoke_family(self, family_id: TokenFamilyId) -> int: + """Revoke all tokens in a family. Returns count of revoked tokens.""" + ... diff --git a/server/osa/domain/auth/service/__init__.py b/server/osa/domain/auth/service/__init__.py index e69de29..0e01a6e 100644 --- a/server/osa/domain/auth/service/__init__.py +++ b/server/osa/domain/auth/service/__init__.py @@ -0,0 +1,6 @@ +"""Auth domain services.""" + +from .auth import AuthService +from .token import TokenService + +__all__ = ["AuthService", "TokenService"] diff --git a/server/osa/domain/auth/service/auth.py b/server/osa/domain/auth/service/auth.py new file mode 100644 index 0000000..cf1c25b --- /dev/null +++ b/server/osa/domain/auth/service/auth.py @@ -0,0 +1,276 @@ +"""Auth service for orchestrating authentication flows.""" + +import logging + +from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.token import RefreshToken +from osa.domain.auth.model.user import User +from osa.domain.auth.model.value import ProviderIdentity, TokenFamilyId, UserId +from osa.domain.auth.port.identity_provider import IdentityInfo, IdentityProvider +from osa.domain.auth.port.repository import ( + IdentityRepository, + RefreshTokenRepository, + UserRepository, +) +from osa.domain.auth.service.token import TokenService +from osa.domain.shared.outbox import Outbox +from osa.domain.shared.service import Service + +logger = logging.getLogger(__name__) + + +class AuthService(Service): + """Orchestrates authentication flows. + + - initiate_login: Generate authorization URL + - complete_oauth: Exchange code for tokens, create/update user + - refresh_tokens: Issue new tokens from refresh token + - logout: Revoke refresh token family + """ + + _user_repo: UserRepository + _identity_repo: IdentityRepository + _refresh_token_repo: RefreshTokenRepository + _token_service: TokenService + _outbox: Outbox + + async def initiate_login( + self, + provider: IdentityProvider, + state: str, + redirect_uri: str, + ) -> str: + """Generate the authorization URL for OAuth login. + + Args: + provider: The identity provider to use + state: CSRF protection token (caller should store this) + redirect_uri: Where the IdP should redirect after auth + + Returns: + Authorization URL to redirect the user to + """ + return provider.get_authorization_url(state, redirect_uri) + + async def complete_oauth( + self, + provider: IdentityProvider, + code: str, + redirect_uri: str, + ) -> tuple[User, Identity, str, str]: + """Complete OAuth flow and issue tokens. + + Args: + provider: The identity provider + code: Authorization code from callback + redirect_uri: Must match the one used in authorization + + Returns: + Tuple of (user, identity, access_token, refresh_token) + """ + # Exchange code for identity info + identity_info = await provider.exchange_code(code, redirect_uri) + + # Find or create user and identity + user, identity = await self._find_or_create_user(identity_info) + + # Create tokens + access_token, refresh_token = await self._create_tokens(user, identity) + + logger.info( + "User authenticated: user_id=%s, provider=%s, external_id=%s", + user.id, + identity.provider, + identity.external_id, + ) + + return user, identity, access_token, refresh_token + + async def refresh_tokens( + self, + refresh_token_raw: str, + ) -> tuple[User, str, str]: + """Refresh access token using refresh token. + + Implements token rotation: old refresh token is revoked, + new one issued in same family. + + Args: + refresh_token_raw: The raw refresh token from client + + Returns: + Tuple of (user, new_access_token, new_refresh_token) + + Raises: + InvalidStateError: If refresh token is invalid, expired, or revoked + """ + from osa.domain.shared.error import InvalidStateError + + token_hash = self._token_service.hash_token(refresh_token_raw) + # Lock the row to prevent concurrent refresh attempts (race condition) + stored_token = await self._refresh_token_repo.get_by_token_hash(token_hash, for_update=True) + + if stored_token is None: + raise InvalidStateError("Invalid refresh token", code="invalid_refresh_token") + + if stored_token.is_revoked: + # Potential theft detected - revoke entire family + await self._refresh_token_repo.revoke_family(stored_token.family_id) + logger.warning( + "Refresh token reuse detected, family revoked: family_id=%s", + stored_token.family_id, + ) + raise InvalidStateError( + "Token family revoked - please login again", + code="token_family_revoked", + ) + + if stored_token.is_expired: + raise InvalidStateError("Refresh token expired", code="refresh_token_expired") + + # Revoke old token + stored_token.revoke() + await self._refresh_token_repo.save(stored_token) + + # Get user and their primary identity + user = await self._user_repo.get(stored_token.user_id) + if user is None: + raise InvalidStateError("User not found", code="user_not_found") + + primary_identity = await self.get_primary_identity(user.id) + + if primary_identity is None: + raise InvalidStateError("User has no identity", code="no_identity") + + # Issue new tokens in same family + raw_token, token_hash = self._token_service.create_refresh_token() + new_refresh_token = RefreshToken.create( + user_id=user.id, + token_hash=token_hash, + family_id=stored_token.family_id, + expires_in_days=self._token_service.refresh_token_expire_days, + ) + await self._refresh_token_repo.save(new_refresh_token) + + access_token = self._token_service.create_access_token( + user_id=user.id, + identity=primary_identity, + ) + + logger.info("Tokens refreshed: user_id=%s", user.id) + + return user, access_token, raw_token + + async def logout(self, refresh_token_raw: str) -> bool: + """Logout by revoking refresh token family. + + Args: + refresh_token_raw: The raw refresh token + + Returns: + True if tokens were revoked + """ + token_hash = self._token_service.hash_token(refresh_token_raw) + stored_token = await self._refresh_token_repo.get_by_token_hash(token_hash) + + if stored_token is None: + # Token not found, but logout succeeds anyway + return True + + revoked_count = await self._refresh_token_repo.revoke_family(stored_token.family_id) + logger.info( + "User logged out: user_id=%s, revoked_tokens=%d", + stored_token.user_id, + revoked_count, + ) + + return True + + async def get_user_by_id(self, user_id: UserId) -> User | None: + """Get a user by their ID.""" + return await self._user_repo.get(user_id) + + async def get_primary_identity(self, user_id: UserId) -> ProviderIdentity | None: + """Get the primary identity for a user. + + Returns the first identity found for the user. In the future, + this could be extended to support multiple identities with a + designated primary. + """ + identities = await self._identity_repo.get_by_user_id(user_id) + if not identities: + return None + first = identities[0] + return ProviderIdentity(provider=first.provider, external_id=first.external_id) + + async def get_user_id_from_refresh_token(self, raw_token: str) -> UserId | None: + """Get the user ID associated with a refresh token. + + Args: + raw_token: The raw refresh token string + + Returns: + The user ID if token exists, None otherwise + """ + token_hash = self._token_service.hash_token(raw_token) + stored = await self._refresh_token_repo.get_by_token_hash(token_hash) + return stored.user_id if stored else None + + async def _find_or_create_user(self, identity_info: IdentityInfo) -> tuple[User, Identity]: + """Find existing user by identity or create new one.""" + # Check if identity already exists + existing_identity = await self._identity_repo.get_by_provider_and_external_id( + identity_info.provider, identity_info.external_id + ) + + if existing_identity: + # User exists, return them + user = await self._user_repo.get(existing_identity.user_id) + if user is None: + # Orphaned identity - shouldn't happen with CASCADE + raise RuntimeError(f"Identity exists without user: {existing_identity.id}") + return user, existing_identity + + # Create new user and identity + user = User.create(display_name=identity_info.display_name) + await self._user_repo.save(user) + + identity = Identity.create( + user_id=user.id, + provider=identity_info.provider, + external_id=identity_info.external_id, + metadata=identity_info.raw_data, + ) + await self._identity_repo.save(identity) + + logger.info( + "New user created: user_id=%s, provider=%s", + user.id, + identity_info.provider, + ) + + return user, identity + + async def _create_tokens(self, user: User, identity: Identity) -> tuple[str, str]: + """Create access and refresh tokens for a user.""" + # Create refresh token + raw_token, token_hash = self._token_service.create_refresh_token() + refresh_token = RefreshToken.create( + user_id=user.id, + token_hash=token_hash, + family_id=TokenFamilyId.generate(), + expires_in_days=self._token_service.refresh_token_expire_days, + ) + await self._refresh_token_repo.save(refresh_token) + + # Create access token + provider_identity = ProviderIdentity( + provider=identity.provider, + external_id=identity.external_id, + ) + access_token = self._token_service.create_access_token( + user_id=user.id, + identity=provider_identity, + ) + + return access_token, raw_token diff --git a/server/osa/domain/auth/service/token.py b/server/osa/domain/auth/service/token.py new file mode 100644 index 0000000..f8948bd --- /dev/null +++ b/server/osa/domain/auth/service/token.py @@ -0,0 +1,197 @@ +"""Token service for JWT creation and validation.""" + +import hashlib +import hmac +import json +import logging +import secrets +import time +from base64 import urlsafe_b64decode, urlsafe_b64encode +from datetime import UTC, datetime, timedelta +from typing import Any + +import jwt + +from osa.config import JwtConfig +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.shared.service import Service + +logger = logging.getLogger(__name__) + +# OAuth state validity period (5 minutes) +STATE_EXPIRY_SECONDS = 300 + + +class TokenService(Service): + """Service for JWT access token and refresh token operations. + + - Access tokens are JWTs (HS256) with user claims + - Refresh tokens are opaque random strings, stored as hashes in the database + - OAuth state tokens are signed payloads for CSRF protection + """ + + _config: JwtConfig + + def create_access_token( + self, + user_id: UserId, + identity: ProviderIdentity, + additional_claims: dict[str, Any] | None = None, + ) -> str: + """Create a JWT access token. + + Args: + user_id: The user's internal ID + identity: The user's external identity (provider + external_id) + additional_claims: Optional extra claims to include + + Returns: + Encoded JWT string + """ + now = datetime.now(UTC) + expires_at = now + timedelta(minutes=self._config.access_token_expire_minutes) + + payload = { + "sub": str(user_id), + "provider": identity.provider, + "external_id": identity.external_id, + "aud": "authenticated", + "iat": int(now.timestamp()), + "exp": int(expires_at.timestamp()), + "jti": secrets.token_hex(16), + } + + if additional_claims: + payload.update(additional_claims) + + return jwt.encode( + payload, + self._config.secret, + algorithm=self._config.algorithm, + ) + + def validate_access_token(self, token: str) -> dict[str, Any]: + """Validate and decode a JWT access token. + + Args: + token: The JWT string to validate + + Returns: + Decoded payload dict + + Raises: + jwt.InvalidTokenError: If token is invalid or expired + """ + return jwt.decode( + token, + self._config.secret, + algorithms=[self._config.algorithm], + audience="authenticated", + ) + + def create_refresh_token(self) -> tuple[str, str]: + """Create a new refresh token. + + Returns: + Tuple of (raw_token, token_hash) + - raw_token: Send to client + - token_hash: Store in database + """ + raw_token = secrets.token_urlsafe(32) + token_hash = self.hash_token(raw_token) + return raw_token, token_hash + + @staticmethod + def hash_token(raw_token: str) -> str: + """Create SHA256 hash of a token. + + Args: + raw_token: The raw token string + + Returns: + Hex-encoded SHA256 hash (64 characters) + """ + return hashlib.sha256(raw_token.encode()).hexdigest() + + @property + def access_token_expire_seconds(self) -> int: + """Get access token expiry in seconds.""" + return self._config.access_token_expire_minutes * 60 + + @property + def refresh_token_expire_days(self) -> int: + """Get refresh token expiry in days.""" + return self._config.refresh_token_expire_days + + def create_oauth_state(self, redirect_uri: str, provider: str) -> str: + """Create a signed, self-verifying OAuth state token. + + The state contains: nonce, redirect_uri, provider, expiry timestamp. + Signed with HMAC-SHA256 using the JWT secret. + + Args: + redirect_uri: The URI to redirect to after OAuth completes + provider: The identity provider name (e.g., "orcid") + + Returns: + URL-safe signed state token in format: payload.signature + """ + payload = { + "nonce": secrets.token_urlsafe(16), + "redirect_uri": redirect_uri, + "provider": provider, + "exp": int(time.time()) + STATE_EXPIRY_SECONDS, + } + payload_bytes = json.dumps(payload, separators=(",", ":")).encode() + payload_b64 = urlsafe_b64encode(payload_bytes).rstrip(b"=").decode() + + signature = hmac.new(self._config.secret.encode(), payload_bytes, hashlib.sha256).digest() + signature_b64 = urlsafe_b64encode(signature).rstrip(b"=").decode() + + return f"{payload_b64}.{signature_b64}" + + def verify_oauth_state(self, state: str) -> tuple[str, str] | None: + """Verify a signed state token and return the redirect_uri and provider if valid. + + Args: + state: The signed state token to verify + + Returns: + Tuple of (redirect_uri, provider) if valid, None if invalid or expired + """ + try: + parts = state.split(".") + if len(parts) != 2: + return None + + payload_b64, signature_b64 = parts + + # Restore base64 padding + payload_bytes = urlsafe_b64decode(payload_b64 + "==") + signature = urlsafe_b64decode(signature_b64 + "==") + + # Verify signature + expected_sig = hmac.new( + self._config.secret.encode(), payload_bytes, hashlib.sha256 + ).digest() + if not hmac.compare_digest(signature, expected_sig): + logger.warning("OAuth state signature verification failed") + return None + + # Parse and check expiry + payload = json.loads(payload_bytes) + if payload.get("exp", 0) < time.time(): + logger.warning("OAuth state expired") + return None + + redirect_uri = payload.get("redirect_uri") + provider = payload.get("provider") + if not redirect_uri or not provider: + logger.warning("OAuth state missing redirect_uri or provider") + return None + + return redirect_uri, provider + + except Exception as e: + logger.warning("OAuth state verification error: %s", e) + return None diff --git a/server/osa/domain/auth/adapter/__init__.py b/server/osa/domain/auth/util/__init__.py similarity index 100% rename from server/osa/domain/auth/adapter/__init__.py rename to server/osa/domain/auth/util/__init__.py diff --git a/server/osa/domain/auth/util/di/__init__.py b/server/osa/domain/auth/util/di/__init__.py new file mode 100644 index 0000000..22bb012 --- /dev/null +++ b/server/osa/domain/auth/util/di/__init__.py @@ -0,0 +1,5 @@ +"""DI providers for auth domain.""" + +from .provider import AuthProvider + +__all__ = ["AuthProvider"] diff --git a/server/osa/domain/auth/util/di/provider.py b/server/osa/domain/auth/util/di/provider.py new file mode 100644 index 0000000..8cce9b2 --- /dev/null +++ b/server/osa/domain/auth/util/di/provider.py @@ -0,0 +1,104 @@ +"""DI provider for auth domain.""" + +from uuid import UUID + +import jwt +from dishka import from_context, provide +from fastapi import HTTPException +from starlette.requests import Request + +from osa.config import Config +from osa.domain.auth.command.login import ( + CompleteOAuthHandler, + InitiateLoginHandler, +) +from osa.domain.auth.command.token import LogoutHandler, RefreshTokensHandler +from osa.domain.auth.model.value import CurrentUser, ProviderIdentity, UserId +from osa.domain.auth.port.repository import ( + IdentityRepository, + RefreshTokenRepository, + UserRepository, +) +from osa.domain.auth.service.auth import AuthService +from osa.domain.auth.service.token import TokenService +from osa.domain.shared.outbox import Outbox +from osa.util.di.base import Provider +from osa.util.di.scope import Scope + + +class AuthProvider(Provider): + """DI provider for auth domain services and handlers.""" + + request = from_context(provides=Request, scope=Scope.UOW) + + # Command Handlers + initiate_login_handler = provide(InitiateLoginHandler, scope=Scope.UOW) + complete_oauth_handler = provide(CompleteOAuthHandler, scope=Scope.UOW) + refresh_tokens_handler = provide(RefreshTokensHandler, scope=Scope.UOW) + logout_handler = provide(LogoutHandler, scope=Scope.UOW) + + @provide(scope=Scope.UOW) + def get_token_service(self, config: Config) -> TokenService: + """Provide TokenService.""" + return TokenService(_config=config.auth.jwt) + + @provide(scope=Scope.UOW) + def get_auth_service( + self, + user_repo: UserRepository, + identity_repo: IdentityRepository, + refresh_token_repo: RefreshTokenRepository, + token_service: TokenService, + outbox: Outbox, + ) -> AuthService: + """Provide AuthService.""" + return AuthService( + _user_repo=user_repo, + _identity_repo=identity_repo, + _refresh_token_repo=refresh_token_repo, + _token_service=token_service, + _outbox=outbox, + ) + + @provide(scope=Scope.UOW) + def get_current_user( + self, + request: Request, + token_service: TokenService, + ) -> CurrentUser: + """Extract and validate CurrentUser from JWT in Authorization header. + + Raises: + HTTPException: If token is missing, expired, or invalid + """ + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + raise HTTPException( + status_code=401, + detail={"code": "missing_token", "message": "Authorization header required"}, + headers={"WWW-Authenticate": "Bearer"}, + ) + + token = auth_header[7:] # Remove "Bearer " prefix + + try: + payload = token_service.validate_access_token(token) + return CurrentUser( + user_id=UserId(UUID(payload["sub"])), + identity=ProviderIdentity( + provider=payload["provider"], + external_id=payload["external_id"], + ), + ) + except jwt.ExpiredSignatureError as e: + raise HTTPException( + status_code=401, + detail={"code": "token_expired", "message": "Token has expired"}, + headers={"WWW-Authenticate": "Bearer"}, + ) from e + except jwt.InvalidTokenError as e: + raise HTTPException( + status_code=401, + detail={"code": "invalid_token", "message": "Invalid token"}, + headers={"WWW-Authenticate": "Bearer"}, + ) from e diff --git a/server/osa/infrastructure/auth/__init__.py b/server/osa/infrastructure/auth/__init__.py new file mode 100644 index 0000000..fc36309 --- /dev/null +++ b/server/osa/infrastructure/auth/__init__.py @@ -0,0 +1,5 @@ +"""Auth infrastructure adapters.""" + +from .di import AuthInfraProvider + +__all__ = ["AuthInfraProvider"] diff --git a/server/osa/infrastructure/auth/di.py b/server/osa/infrastructure/auth/di.py new file mode 100644 index 0000000..40a354e --- /dev/null +++ b/server/osa/infrastructure/auth/di.py @@ -0,0 +1,71 @@ +"""DI provider for auth infrastructure.""" + +import httpx +from dishka import provide + +from osa.config import Config +from osa.domain.auth.port.identity_provider import IdentityProvider +from osa.domain.auth.port.provider_registry import ProviderRegistry +from osa.domain.auth.port.repository import ( + IdentityRepository, + RefreshTokenRepository, + UserRepository, +) +from osa.infrastructure.auth.orcid import OrcidIdentityProvider +from osa.infrastructure.auth.provider_registry import InMemoryProviderRegistry +from osa.infrastructure.persistence.repository.auth import ( + PostgresIdentityRepository, + PostgresRefreshTokenRepository, + PostgresUserRepository, +) +from osa.util.di.base import Provider +from osa.util.di.scope import Scope + +# HTTP client timeout configuration +_HTTP_TIMEOUT = httpx.Timeout( + connect=5.0, # Connection timeout + read=10.0, # Read timeout + write=5.0, # Write timeout + pool=5.0, # Pool timeout +) + + +class AuthInfraProvider(Provider): + """DI provider for auth infrastructure adapters.""" + + # Repository adapters + user_repo = provide( + PostgresUserRepository, + scope=Scope.UOW, + provides=UserRepository, + ) + identity_repo = provide( + PostgresIdentityRepository, + scope=Scope.UOW, + provides=IdentityRepository, + ) + refresh_token_repo = provide( + PostgresRefreshTokenRepository, + scope=Scope.UOW, + provides=RefreshTokenRepository, + ) + + @provide(scope=Scope.APP) + def get_auth_http_client(self) -> httpx.AsyncClient: + """Shared HTTP client for auth operations (connection pooling).""" + return httpx.AsyncClient(timeout=_HTTP_TIMEOUT) + + @provide(scope=Scope.APP) + def get_provider_registry( + self, config: Config, http_client: httpx.AsyncClient + ) -> ProviderRegistry: + """Provide ProviderRegistry with configured identity providers.""" + providers: dict[str, IdentityProvider] = {} + + # Register ORCID if configured + if config.auth.orcid.client_id: + providers["orcid"] = OrcidIdentityProvider( + config=config.auth.orcid, http_client=http_client + ) + + return InMemoryProviderRegistry(providers) diff --git a/server/osa/infrastructure/auth/orcid.py b/server/osa/infrastructure/auth/orcid.py new file mode 100644 index 0000000..a5db378 --- /dev/null +++ b/server/osa/infrastructure/auth/orcid.py @@ -0,0 +1,101 @@ +"""ORCiD identity provider adapter.""" + +import logging +from urllib.parse import urlencode + +import httpx + +from osa.config import OrcidConfig +from osa.domain.auth.port.identity_provider import IdentityInfo, IdentityProvider +from osa.domain.shared.error import ExternalServiceError + +logger = logging.getLogger(__name__) + + +class OrcidIdentityProvider(IdentityProvider): + """IdentityProvider implementation for ORCiD OAuth.""" + + def __init__(self, config: OrcidConfig, http_client: httpx.AsyncClient) -> None: + self._config = config + self._http = http_client + + @property + def provider_name(self) -> str: + return "orcid" + + def get_authorization_url(self, state: str, redirect_uri: str) -> str: + """Generate ORCiD authorization URL.""" + params = { + "client_id": self._config.client_id, + "response_type": "code", + "scope": "/authenticate", + "redirect_uri": redirect_uri, + "state": state, + } + return f"{self._config.base_url}/oauth/authorize?{urlencode(params)}" + + async def exchange_code( + self, + code: str, + redirect_uri: str, + ) -> IdentityInfo: + """Exchange authorization code for identity information.""" + token_url = f"{self._config.base_url}/oauth/token" + + data = { + "client_id": self._config.client_id, + "client_secret": self._config.client_secret, + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + } + + try: + response = await self._http.post( + token_url, + data=data, + headers={"Accept": "application/json"}, + ) + + if response.status_code != 200: + logger.error( + "ORCiD token exchange failed: status=%d, body=%s", + response.status_code, + response.text, + ) + raise ExternalServiceError( + f"ORCiD token exchange failed: {response.status_code}", + code="idp_unavailable", + ) + + token_data = response.json() + + except httpx.RequestError as e: + logger.exception("ORCiD request failed: %s", e) + raise ExternalServiceError( + "Failed to connect to ORCiD", + code="idp_unavailable", + ) from e + + # ORCiD returns user info directly in token response + # { + # "access_token": "...", + # "token_type": "bearer", + # "scope": "/authenticate", + # "name": "Jane Doe", + # "orcid": "0000-0001-2345-6789" + # } + orcid_id = token_data.get("orcid") + if not orcid_id: + raise ExternalServiceError( + "ORCiD response missing orcid field", + code="oauth_error", + ) + + return IdentityInfo( + provider="orcid", + external_id=orcid_id, + display_name=token_data.get("name"), + email=None, # ORCiD doesn't return email in basic auth + raw_data=token_data, + ) diff --git a/server/osa/infrastructure/auth/provider_registry.py b/server/osa/infrastructure/auth/provider_registry.py new file mode 100644 index 0000000..0dde031 --- /dev/null +++ b/server/osa/infrastructure/auth/provider_registry.py @@ -0,0 +1,37 @@ +"""Provider registry implementation.""" + +from osa.domain.auth.port.identity_provider import IdentityProvider +from osa.domain.auth.port.provider_registry import ProviderRegistry + + +class InMemoryProviderRegistry(ProviderRegistry): + """In-memory provider registry. + + Stores a mapping of provider names to their implementations. + Providers are registered at application startup via DI. + """ + + def __init__(self, providers: dict[str, IdentityProvider] | None = None) -> None: + """Initialize registry with optional initial providers. + + Args: + providers: Optional dict mapping provider names to implementations + """ + self._providers: dict[str, IdentityProvider] = providers or {} + + def get(self, provider: str) -> IdentityProvider | None: + """Get an identity provider by name.""" + return self._providers.get(provider) + + def available_providers(self) -> list[str]: + """Get list of available provider names.""" + return list(self._providers.keys()) + + def register(self, name: str, provider: IdentityProvider) -> None: + """Register a provider. + + Args: + name: The provider name + provider: The provider implementation + """ + self._providers[name] = provider diff --git a/server/osa/infrastructure/persistence/repository/auth.py b/server/osa/infrastructure/persistence/repository/auth.py new file mode 100644 index 0000000..ff46522 --- /dev/null +++ b/server/osa/infrastructure/persistence/repository/auth.py @@ -0,0 +1,222 @@ +"""PostgreSQL repository implementations for auth domain.""" + +from datetime import UTC, datetime +from uuid import UUID + +from sqlalchemy import insert, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.token import RefreshToken +from osa.domain.auth.model.user import User +from osa.domain.auth.model.value import ( + IdentityId, + RefreshTokenId, + TokenFamilyId, + UserId, +) +from osa.domain.auth.port.repository import ( + IdentityRepository, + RefreshTokenRepository, + UserRepository, +) +from osa.infrastructure.persistence.tables import ( + identities_table, + refresh_tokens_table, + users_table, +) + + +def _row_to_user(row: dict) -> User: + """Convert a database row to a User model.""" + return User( + id=UserId(UUID(row["id"])), + display_name=row["display_name"], + created_at=row["created_at"], + updated_at=row["updated_at"], + ) + + +def _user_to_dict(user: User) -> dict: + """Convert a User model to a database row dict.""" + return { + "id": str(user.id), + "display_name": user.display_name, + "created_at": user.created_at, + "updated_at": user.updated_at, + } + + +def _row_to_identity(row: dict) -> Identity: + """Convert a database row to an Identity model.""" + return Identity( + id=IdentityId(UUID(row["id"])), + user_id=UserId(UUID(row["user_id"])), + provider=row["provider"], + external_id=row["external_id"], + metadata=row["metadata"], + created_at=row["created_at"], + ) + + +def _identity_to_dict(identity: Identity) -> dict: + """Convert an Identity model to a database row dict.""" + return { + "id": str(identity.id), + "user_id": str(identity.user_id), + "provider": identity.provider, + "external_id": identity.external_id, + "metadata": identity.metadata, + "created_at": identity.created_at, + } + + +def _row_to_refresh_token(row: dict) -> RefreshToken: + """Convert a database row to a RefreshToken model.""" + return RefreshToken( + id=RefreshTokenId(UUID(row["id"])), + user_id=UserId(UUID(row["user_id"])), + token_hash=row["token_hash"], + family_id=TokenFamilyId(UUID(row["family_id"])), + expires_at=row["expires_at"], + created_at=row["created_at"], + revoked_at=row["revoked_at"], + ) + + +def _refresh_token_to_dict(token: RefreshToken) -> dict: + """Convert a RefreshToken model to a database row dict.""" + return { + "id": str(token.id), + "user_id": str(token.user_id), + "token_hash": token.token_hash, + "family_id": str(token.family_id), + "expires_at": token.expires_at, + "created_at": token.created_at, + "revoked_at": token.revoked_at, + } + + +class PostgresUserRepository(UserRepository): + """PostgreSQL implementation of UserRepository.""" + + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def get(self, user_id: UserId) -> User | None: + stmt = select(users_table).where(users_table.c.id == str(user_id)) + result = await self.session.execute(stmt) + row = result.mappings().first() + return _row_to_user(dict(row)) if row else None + + async def save(self, user: User) -> None: + user_dict = _user_to_dict(user) + existing = await self.get(user.id) + + if existing: + stmt = update(users_table).where(users_table.c.id == str(user.id)).values(**user_dict) + else: + stmt = insert(users_table).values(**user_dict) + + await self.session.execute(stmt) + await self.session.flush() + + +class PostgresIdentityRepository(IdentityRepository): + """PostgreSQL implementation of IdentityRepository.""" + + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def get(self, identity_id: IdentityId) -> Identity | None: + stmt = select(identities_table).where(identities_table.c.id == str(identity_id)) + result = await self.session.execute(stmt) + row = result.mappings().first() + return _row_to_identity(dict(row)) if row else None + + async def get_by_provider_and_external_id( + self, provider: str, external_id: str + ) -> Identity | None: + stmt = select(identities_table).where( + identities_table.c.provider == provider, + identities_table.c.external_id == external_id, + ) + result = await self.session.execute(stmt) + row = result.mappings().first() + return _row_to_identity(dict(row)) if row else None + + async def get_by_user_id(self, user_id: UserId) -> list[Identity]: + stmt = select(identities_table).where(identities_table.c.user_id == str(user_id)) + result = await self.session.execute(stmt) + rows = result.mappings().all() + return [_row_to_identity(dict(row)) for row in rows] + + async def save(self, identity: Identity) -> None: + identity_dict = _identity_to_dict(identity) + existing = await self.get(identity.id) + + if existing: + stmt = ( + update(identities_table) + .where(identities_table.c.id == str(identity.id)) + .values(**identity_dict) + ) + else: + stmt = insert(identities_table).values(**identity_dict) + + await self.session.execute(stmt) + await self.session.flush() + + +class PostgresRefreshTokenRepository(RefreshTokenRepository): + """PostgreSQL implementation of RefreshTokenRepository.""" + + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def get(self, token_id: RefreshTokenId) -> RefreshToken | None: + stmt = select(refresh_tokens_table).where(refresh_tokens_table.c.id == str(token_id)) + result = await self.session.execute(stmt) + row = result.mappings().first() + return _row_to_refresh_token(dict(row)) if row else None + + async def get_by_token_hash( + self, token_hash: str, *, for_update: bool = False + ) -> RefreshToken | None: + stmt = select(refresh_tokens_table).where(refresh_tokens_table.c.token_hash == token_hash) + if for_update: + stmt = stmt.with_for_update() + result = await self.session.execute(stmt) + row = result.mappings().first() + return _row_to_refresh_token(dict(row)) if row else None + + async def save(self, token: RefreshToken) -> None: + token_dict = _refresh_token_to_dict(token) + existing = await self.get(token.id) + + if existing: + stmt = ( + update(refresh_tokens_table) + .where(refresh_tokens_table.c.id == str(token.id)) + .values(**token_dict) + ) + else: + stmt = insert(refresh_tokens_table).values(**token_dict) + + await self.session.execute(stmt) + await self.session.flush() + + async def revoke_family(self, family_id: TokenFamilyId) -> int: + """Revoke all tokens in a family. Returns count of revoked tokens.""" + now = datetime.now(UTC) + stmt = ( + update(refresh_tokens_table) + .where( + refresh_tokens_table.c.family_id == str(family_id), + refresh_tokens_table.c.revoked_at.is_(None), + ) + .values(revoked_at=now) + ) + result = await self.session.execute(stmt) + await self.session.flush() + return result.rowcount diff --git a/server/osa/infrastructure/persistence/tables.py b/server/osa/infrastructure/persistence/tables.py index 6b66044..f22ee64 100644 --- a/server/osa/infrastructure/persistence/tables.py +++ b/server/osa/infrastructure/persistence/tables.py @@ -3,12 +3,14 @@ from sqlalchemy import ( Column, DateTime, + ForeignKey, Index, Integer, MetaData, String, Table, Text, + UniqueConstraint, text, ) from sqlalchemy.types import JSON @@ -122,3 +124,54 @@ events_table.c.created_at, postgresql_where=text("delivery_status = 'failed'"), ) + + +# ============================================================================ +# USERS TABLE (Authentication) +# ============================================================================ +users_table = Table( + "users", + metadata, + Column("id", String, primary_key=True), # UUID as string + Column("display_name", String(255), nullable=True), + Column("created_at", DateTime(timezone=True), nullable=False), + Column("updated_at", DateTime(timezone=True), nullable=True), +) + + +# ============================================================================ +# IDENTITIES TABLE (Authentication) +# ============================================================================ +identities_table = Table( + "identities", + metadata, + Column("id", String, primary_key=True), # UUID as string + Column("user_id", String, ForeignKey("users.id", ondelete="CASCADE"), nullable=False), + Column("provider", String(50), nullable=False), # "orcid", "google", etc. + Column("external_id", String(255), nullable=False), # ORCiD ID, Google ID, etc. + Column("metadata", JSON, nullable=True), # Provider-specific data (name, email) + Column("created_at", DateTime(timezone=True), nullable=False), + UniqueConstraint("provider", "external_id", name="uq_identity_provider_external"), +) + +Index("ix_identities_user_id", identities_table.c.user_id) + + +# ============================================================================ +# REFRESH TOKENS TABLE (Authentication) +# ============================================================================ +refresh_tokens_table = Table( + "refresh_tokens", + metadata, + Column("id", String, primary_key=True), # UUID as string + Column("user_id", String, ForeignKey("users.id", ondelete="CASCADE"), nullable=False), + Column("token_hash", String(64), nullable=False), # SHA256 hash + Column("family_id", String, nullable=False), # UUID - for theft detection + Column("expires_at", DateTime(timezone=True), nullable=False), + Column("created_at", DateTime(timezone=True), nullable=False), + Column("revoked_at", DateTime(timezone=True), nullable=True), +) + +Index("ix_refresh_tokens_user_id", refresh_tokens_table.c.user_id) +Index("ix_refresh_tokens_token_hash", refresh_tokens_table.c.token_hash) +Index("ix_refresh_tokens_family_id", refresh_tokens_table.c.family_id) diff --git a/server/pyproject.toml b/server/pyproject.toml index be6bb29..09609ca 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "rich>=14.2.0", "asyncpg>=0.31.0", "psycopg2-binary>=2.9.11", + "pyjwt>=2.11.0", ] [project.scripts] diff --git a/server/tests/conftest.py b/server/tests/conftest.py index e69de29..cf9c15f 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -0,0 +1,7 @@ +"""Global test fixtures.""" + +import os + +# Set JWT secret before any test modules import Config +# This must happen at module load time, not in a fixture +os.environ.setdefault("OSA_AUTH__JWT__SECRET", "test-secret-for-unit-tests-min-32") diff --git a/server/tests/unit/application/api/v1/routes/test_auth_state.py b/server/tests/unit/application/api/v1/routes/test_auth_state.py new file mode 100644 index 0000000..ebe63ca --- /dev/null +++ b/server/tests/unit/application/api/v1/routes/test_auth_state.py @@ -0,0 +1,157 @@ +"""Unit tests for OAuth state token signing/verification.""" + +import time + +import pytest + +from osa.config import JwtConfig +from osa.domain.auth.service.token import STATE_EXPIRY_SECONDS, TokenService + + +@pytest.fixture +def token_service() -> TokenService: + """Create a TokenService with test config.""" + config = JwtConfig( + secret="test-secret-key-for-signing-min-32", + algorithm="HS256", + access_token_expire_minutes=15, + refresh_token_expire_days=7, + ) + return TokenService(_config=config) + + +@pytest.fixture +def token_service_alt_secret() -> TokenService: + """Create a TokenService with a different secret.""" + config = JwtConfig( + secret="different-secret-key-for-testing", + algorithm="HS256", + access_token_expire_minutes=15, + refresh_token_expire_days=7, + ) + return TokenService(_config=config) + + +class TestSignedStateCreation: + """Tests for TokenService.create_oauth_state.""" + + def test_creates_state_with_redirect_uri(self, token_service: TokenService): + """Should create a signed state containing the redirect URI and provider.""" + redirect_uri = "https://example.com/callback" + provider = "orcid" + + state = token_service.create_oauth_state(redirect_uri, provider) + + # State should be format: payload.signature + assert "." in state + parts = state.split(".") + assert len(parts) == 2 + + def test_different_nonces_produce_different_states(self, token_service: TokenService): + """Each state should have a unique nonce.""" + redirect_uri = "https://example.com" + + state1 = token_service.create_oauth_state(redirect_uri, "orcid") + state2 = token_service.create_oauth_state(redirect_uri, "orcid") + + assert state1 != state2 + + def test_state_is_url_safe(self, token_service: TokenService): + """State should only contain URL-safe characters.""" + redirect_uri = "https://example.com/path?query=value" + + state = token_service.create_oauth_state(redirect_uri, "orcid") + + # URL-safe base64 uses only these characters + allowed = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.") + assert all(c in allowed for c in state) + + +class TestSignedStateVerification: + """Tests for TokenService.verify_oauth_state.""" + + def test_verifies_valid_state(self, token_service: TokenService): + """Should return (redirect_uri, provider) for valid state.""" + redirect_uri = "https://example.com/after-login" + provider = "orcid" + + state = token_service.create_oauth_state(redirect_uri, provider) + result = token_service.verify_oauth_state(state) + + assert result is not None + assert result == (redirect_uri, provider) + + def test_rejects_tampered_payload(self, token_service: TokenService): + """Should reject state with tampered payload.""" + state = token_service.create_oauth_state("https://example.com", "orcid") + + # Tamper with the payload (change a character) + parts = state.split(".") + tampered_payload = "x" + parts[0][1:] + tampered_state = f"{tampered_payload}.{parts[1]}" + + result = token_service.verify_oauth_state(tampered_state) + assert result is None + + def test_rejects_tampered_signature(self, token_service: TokenService): + """Should reject state with tampered signature.""" + state = token_service.create_oauth_state("https://example.com", "orcid") + + # Tamper with the signature + parts = state.split(".") + tampered_sig = "x" + parts[1][1:] + tampered_state = f"{parts[0]}.{tampered_sig}" + + result = token_service.verify_oauth_state(tampered_state) + assert result is None + + def test_rejects_wrong_secret( + self, token_service: TokenService, token_service_alt_secret: TokenService + ): + """Should reject state signed with different secret.""" + state = token_service.create_oauth_state("https://example.com", "orcid") + + result = token_service_alt_secret.verify_oauth_state(state) + assert result is None + + def test_rejects_expired_state(self, token_service: TokenService, monkeypatch): + """Should reject expired state.""" + state = token_service.create_oauth_state("https://example.com", "orcid") + + # Fast-forward time past expiry + future_time = time.time() + STATE_EXPIRY_SECONDS + 1 + monkeypatch.setattr(time, "time", lambda: future_time) + + result = token_service.verify_oauth_state(state) + assert result is None + + def test_rejects_malformed_state(self, token_service: TokenService): + """Should reject malformed state strings.""" + # No dot separator + assert token_service.verify_oauth_state("nodot") is None + + # Empty parts + assert token_service.verify_oauth_state(".") is None + assert token_service.verify_oauth_state("payload.") is None + assert token_service.verify_oauth_state(".signature") is None + + # Too many parts + assert token_service.verify_oauth_state("a.b.c") is None + + # Invalid base64 + assert token_service.verify_oauth_state("!!!.???") is None + + def test_rejects_empty_state(self, token_service: TokenService): + """Should reject empty state.""" + assert token_service.verify_oauth_state("") is None + + def test_handles_special_characters_in_redirect_uri(self, token_service: TokenService): + """Should handle redirect URIs with special characters.""" + redirect_uri = "https://example.com/path?foo=bar&baz=qux#fragment" + provider = "orcid" + + state = token_service.create_oauth_state(redirect_uri, provider) + result = token_service.verify_oauth_state(state) + + assert result is not None + assert result == (redirect_uri, provider) diff --git a/server/tests/unit/domain/auth/__init__.py b/server/tests/unit/domain/auth/__init__.py new file mode 100644 index 0000000..0392972 --- /dev/null +++ b/server/tests/unit/domain/auth/__init__.py @@ -0,0 +1 @@ +"""Unit tests for auth domain.""" diff --git a/server/tests/unit/domain/auth/test_auth_service.py b/server/tests/unit/domain/auth/test_auth_service.py new file mode 100644 index 0000000..bef7f98 --- /dev/null +++ b/server/tests/unit/domain/auth/test_auth_service.py @@ -0,0 +1,393 @@ +"""Unit tests for AuthService.""" + +from datetime import UTC, datetime, timedelta +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 + +import pytest + +from osa.config import JwtConfig +from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.token import RefreshToken +from osa.domain.auth.model.user import User +from osa.domain.auth.model.value import IdentityId, RefreshTokenId, TokenFamilyId, UserId +from osa.domain.auth.port.identity_provider import IdentityInfo +from osa.domain.auth.service.auth import AuthService +from osa.domain.auth.service.token import TokenService +from osa.domain.shared.error import InvalidStateError + + +def make_auth_service( + user_repo: AsyncMock | None = None, + identity_repo: AsyncMock | None = None, + refresh_token_repo: AsyncMock | None = None, + token_service: TokenService | None = None, + outbox: AsyncMock | None = None, +) -> AuthService: + """Create an AuthService with mocked dependencies.""" + if user_repo is None: + user_repo = AsyncMock() + if identity_repo is None: + identity_repo = AsyncMock() + if refresh_token_repo is None: + refresh_token_repo = AsyncMock() + if token_service is None: + config = JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=7, + ) + token_service = TokenService(_config=config) + if outbox is None: + outbox = AsyncMock() + + return AuthService( + _user_repo=user_repo, + _identity_repo=identity_repo, + _refresh_token_repo=refresh_token_repo, + _token_service=token_service, + _outbox=outbox, + ) + + +def make_identity_provider(identity_info: IdentityInfo | None = None) -> MagicMock: + """Create a mock identity provider.""" + provider = MagicMock() + provider.provider_name = "orcid" + + if identity_info is None: + identity_info = IdentityInfo( + provider="orcid", + external_id="0000-0001-2345-6789", + display_name="Jane Doe", + email=None, + raw_data={"name": "Jane Doe", "orcid": "0000-0001-2345-6789"}, + ) + + provider.exchange_code = AsyncMock(return_value=identity_info) + provider.get_authorization_url = MagicMock(return_value="https://orcid.org/oauth/authorize?...") + return provider + + +class TestAuthServiceInitiateLogin: + """Tests for AuthService.initiate_login.""" + + @pytest.mark.asyncio + async def test_initiate_login_returns_authorization_url(self): + """initiate_login should return the provider's authorization URL.""" + service = make_auth_service() + provider = make_identity_provider() + provider.get_authorization_url.return_value = "https://orcid.org/oauth?state=abc" + + url = await service.initiate_login( + provider=provider, + state="test-state", + redirect_uri="http://localhost/callback", + ) + + assert url == "https://orcid.org/oauth?state=abc" + provider.get_authorization_url.assert_called_once_with( + "test-state", "http://localhost/callback" + ) + + +class TestAuthServiceCompleteOAuth: + """Tests for AuthService.complete_oauth.""" + + @pytest.mark.asyncio + async def test_complete_oauth_creates_new_user(self): + """complete_oauth should create user and identity for new user.""" + user_repo = AsyncMock() + user_repo.get.return_value = None # No existing user + + identity_repo = AsyncMock() + identity_repo.get_by_provider_and_external_id.return_value = None # No existing identity + + refresh_token_repo = AsyncMock() + + service = make_auth_service( + user_repo=user_repo, + identity_repo=identity_repo, + refresh_token_repo=refresh_token_repo, + ) + provider = make_identity_provider() + + user, identity, access_token, refresh_token = await service.complete_oauth( + provider=provider, + code="auth-code", + redirect_uri="http://localhost/callback", + ) + + # Should create user and identity + user_repo.save.assert_called_once() + identity_repo.save.assert_called_once() + refresh_token_repo.save.assert_called_once() + + # Should return valid data + assert user.display_name == "Jane Doe" + assert identity.provider == "orcid" + assert identity.external_id == "0000-0001-2345-6789" + assert isinstance(access_token, str) + assert isinstance(refresh_token, str) + + @pytest.mark.asyncio + async def test_complete_oauth_returns_existing_user(self): + """complete_oauth should return existing user if identity exists.""" + existing_user = User( + id=UserId(uuid4()), + display_name="Existing User", + created_at=datetime.now(UTC), + updated_at=None, + ) + existing_identity = Identity( + id=IdentityId(uuid4()), + user_id=existing_user.id, + provider="orcid", + external_id="0000-0001-2345-6789", + metadata=None, + created_at=datetime.now(UTC), + ) + + user_repo = AsyncMock() + user_repo.get.return_value = existing_user + + identity_repo = AsyncMock() + identity_repo.get_by_provider_and_external_id.return_value = existing_identity + + refresh_token_repo = AsyncMock() + + service = make_auth_service( + user_repo=user_repo, + identity_repo=identity_repo, + refresh_token_repo=refresh_token_repo, + ) + provider = make_identity_provider() + + user, identity, _, _ = await service.complete_oauth( + provider=provider, + code="auth-code", + redirect_uri="http://localhost/callback", + ) + + # Should NOT create new user/identity + user_repo.save.assert_not_called() + identity_repo.save.assert_not_called() + + # Should return existing user + assert user.id == existing_user.id + assert identity.id == existing_identity.id + + +class TestAuthServiceRefreshTokens: + """Tests for AuthService.refresh_tokens.""" + + @pytest.mark.asyncio + async def test_refresh_tokens_issues_new_tokens(self): + """refresh_tokens should issue new access and refresh tokens.""" + user = User( + id=UserId(uuid4()), + display_name="Test User", + created_at=datetime.now(UTC), + updated_at=None, + ) + identity = Identity( + id=IdentityId(uuid4()), + user_id=user.id, + provider="orcid", + external_id="0000-0001-2345-6789", + metadata=None, + created_at=datetime.now(UTC), + ) + old_token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=user.id, + token_hash="old-hash", + family_id=TokenFamilyId(uuid4()), + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC), + revoked_at=None, + ) + + user_repo = AsyncMock() + user_repo.get.return_value = user + + identity_repo = AsyncMock() + identity_repo.get_by_user_id.return_value = [identity] + + refresh_token_repo = AsyncMock() + refresh_token_repo.get_by_token_hash.return_value = old_token + + service = make_auth_service( + user_repo=user_repo, + identity_repo=identity_repo, + refresh_token_repo=refresh_token_repo, + ) + + returned_user, access_token, new_refresh_token = await service.refresh_tokens( + "raw-refresh-token" + ) + + # Should save new refresh token + assert refresh_token_repo.save.call_count == 2 # Once for revoking old, once for new + + # Should return valid data + assert returned_user.id == user.id + assert isinstance(access_token, str) + assert isinstance(new_refresh_token, str) + + @pytest.mark.asyncio + async def test_refresh_tokens_revokes_old_token(self): + """refresh_tokens should revoke the old refresh token.""" + user = User( + id=UserId(uuid4()), + display_name="Test User", + created_at=datetime.now(UTC), + updated_at=None, + ) + identity = Identity( + id=IdentityId(uuid4()), + user_id=user.id, + provider="orcid", + external_id="0000-0001-2345-6789", + metadata=None, + created_at=datetime.now(UTC), + ) + old_token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=user.id, + token_hash="old-hash", + family_id=TokenFamilyId(uuid4()), + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC), + revoked_at=None, + ) + + user_repo = AsyncMock() + user_repo.get.return_value = user + + identity_repo = AsyncMock() + identity_repo.get_by_user_id.return_value = [identity] + + refresh_token_repo = AsyncMock() + refresh_token_repo.get_by_token_hash.return_value = old_token + + service = make_auth_service( + user_repo=user_repo, + identity_repo=identity_repo, + refresh_token_repo=refresh_token_repo, + ) + + await service.refresh_tokens("raw-refresh-token") + + # The old token should be revoked + assert old_token.is_revoked is True + + @pytest.mark.asyncio + async def test_refresh_tokens_rejects_invalid_token(self): + """refresh_tokens should raise for unknown refresh token.""" + refresh_token_repo = AsyncMock() + refresh_token_repo.get_by_token_hash.return_value = None + + service = make_auth_service(refresh_token_repo=refresh_token_repo) + + with pytest.raises(InvalidStateError) as exc_info: + await service.refresh_tokens("invalid-token") + + assert exc_info.value.code == "invalid_refresh_token" + + @pytest.mark.asyncio + async def test_refresh_tokens_detects_reuse_and_revokes_family(self): + """refresh_tokens should revoke entire family if revoked token is reused.""" + user_id = UserId(uuid4()) + family_id = TokenFamilyId(uuid4()) + + # Token that was already revoked (potential theft) + revoked_token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=user_id, + token_hash="revoked-hash", + family_id=family_id, + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC), + revoked_at=datetime.now(UTC) - timedelta(hours=1), # Already revoked + ) + + refresh_token_repo = AsyncMock() + refresh_token_repo.get_by_token_hash.return_value = revoked_token + refresh_token_repo.revoke_family.return_value = 3 # 3 tokens revoked + + service = make_auth_service(refresh_token_repo=refresh_token_repo) + + with pytest.raises(InvalidStateError) as exc_info: + await service.refresh_tokens("stolen-token") + + assert exc_info.value.code == "token_family_revoked" + + # Should revoke entire family + refresh_token_repo.revoke_family.assert_called_once_with(family_id) + + @pytest.mark.asyncio + async def test_refresh_tokens_rejects_expired_token(self): + """refresh_tokens should raise for expired refresh token.""" + expired_token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=UserId(uuid4()), + token_hash="expired-hash", + family_id=TokenFamilyId(uuid4()), + expires_at=datetime.now(UTC) - timedelta(hours=1), # Expired + created_at=datetime.now(UTC) - timedelta(days=8), + revoked_at=None, + ) + + refresh_token_repo = AsyncMock() + refresh_token_repo.get_by_token_hash.return_value = expired_token + + service = make_auth_service(refresh_token_repo=refresh_token_repo) + + with pytest.raises(InvalidStateError) as exc_info: + await service.refresh_tokens("expired-token") + + assert exc_info.value.code == "refresh_token_expired" + + +class TestAuthServiceLogout: + """Tests for AuthService.logout.""" + + @pytest.mark.asyncio + async def test_logout_revokes_token_family(self): + """logout should revoke the entire token family.""" + family_id = TokenFamilyId(uuid4()) + token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=UserId(uuid4()), + token_hash="token-hash", + family_id=family_id, + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC), + revoked_at=None, + ) + + refresh_token_repo = AsyncMock() + refresh_token_repo.get_by_token_hash.return_value = token + refresh_token_repo.revoke_family.return_value = 1 + + service = make_auth_service(refresh_token_repo=refresh_token_repo) + + result = await service.logout("raw-refresh-token") + + assert result is True + refresh_token_repo.revoke_family.assert_called_once_with(family_id) + + @pytest.mark.asyncio + async def test_logout_succeeds_for_unknown_token(self): + """logout should succeed even if token is not found.""" + refresh_token_repo = AsyncMock() + refresh_token_repo.get_by_token_hash.return_value = None + + service = make_auth_service(refresh_token_repo=refresh_token_repo) + + result = await service.logout("unknown-token") + + assert result is True + refresh_token_repo.revoke_family.assert_not_called() diff --git a/server/tests/unit/domain/auth/test_command_handlers.py b/server/tests/unit/domain/auth/test_command_handlers.py new file mode 100644 index 0000000..9179e15 --- /dev/null +++ b/server/tests/unit/domain/auth/test_command_handlers.py @@ -0,0 +1,332 @@ +"""Unit tests for auth command handlers.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 + +import pytest + +from osa.config import JwtConfig +from osa.domain.auth.command.login import ( + CompleteOAuth, + CompleteOAuthHandler, + InitiateLogin, + InitiateLoginHandler, +) +from osa.domain.auth.command.token import ( + Logout, + LogoutHandler, + RefreshTokens, + RefreshTokensHandler, +) +from osa.domain.auth.event import UserAuthenticated, UserLoggedOut +from osa.domain.auth.model.identity import Identity +from osa.domain.auth.model.user import User +from osa.domain.auth.model.value import IdentityId, UserId +from osa.domain.auth.service.token import TokenService + + +def make_token_service() -> TokenService: + """Create a TokenService with test config.""" + config = JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=7, + ) + return TokenService(_config=config) + + +def make_identity_provider() -> MagicMock: + """Create a mock identity provider.""" + provider = MagicMock() + provider.provider_name = "orcid" + provider.get_authorization_url = MagicMock( + return_value="https://orcid.org/oauth/authorize?state=xyz" + ) + return provider + + +def make_provider_registry(identity_provider: MagicMock | None = None) -> MagicMock: + """Create a mock provider registry.""" + if identity_provider is None: + identity_provider = make_identity_provider() + registry = MagicMock() + registry.get.return_value = identity_provider + registry.is_available.return_value = True + registry.available_providers.return_value = ["orcid"] + return registry + + +class TestInitiateLoginHandler: + """Tests for InitiateLoginHandler.""" + + @pytest.mark.asyncio + async def test_run_returns_authorization_url(self): + """Handler should return authorization URL from identity provider.""" + identity_provider = make_identity_provider() + provider_registry = make_provider_registry(identity_provider) + token_service = make_token_service() + + handler = InitiateLoginHandler( + provider_registry=provider_registry, + token_service=token_service, + ) + + result = await handler.run( + InitiateLogin( + callback_url="http://localhost/callback", + final_redirect_uri="http://localhost/dashboard", + provider="orcid", + ) + ) + + assert result.authorization_url == "https://orcid.org/oauth/authorize?state=xyz" + identity_provider.get_authorization_url.assert_called_once() + + @pytest.mark.asyncio + async def test_run_creates_signed_state(self): + """Handler should create signed state token with final redirect URI and provider.""" + identity_provider = make_identity_provider() + provider_registry = make_provider_registry(identity_provider) + token_service = make_token_service() + + handler = InitiateLoginHandler( + provider_registry=provider_registry, + token_service=token_service, + ) + + await handler.run( + InitiateLogin( + callback_url="http://localhost/callback", + final_redirect_uri="http://localhost/dashboard", + provider="orcid", + ) + ) + + # Verify state was passed to identity provider + call_args = identity_provider.get_authorization_url.call_args + state = call_args[1]["state"] if "state" in call_args[1] else call_args[0][0] + + # Verify state can be decoded to get back the redirect URI and provider + result = token_service.verify_oauth_state(state) + assert result is not None + redirect_uri, provider = result + assert redirect_uri == "http://localhost/dashboard" + assert provider == "orcid" + + +class TestCompleteOAuthHandler: + """Tests for CompleteOAuthHandler.""" + + @pytest.mark.asyncio + async def test_run_emits_user_authenticated_event(self): + """Handler should emit UserAuthenticated event on successful OAuth.""" + user = User( + id=UserId(uuid4()), + display_name="Jane Doe", + created_at=datetime.now(UTC), + updated_at=None, + ) + identity = Identity( + id=IdentityId(uuid4()), + user_id=user.id, + provider="orcid", + external_id="0000-0001-2345-6789", + metadata=None, + created_at=datetime.now(UTC), + ) + + auth_service = AsyncMock() + auth_service.complete_oauth.return_value = ( + user, + identity, + "access-token", + "refresh-token", + ) + + provider_registry = make_provider_registry() + token_service = make_token_service() + outbox = AsyncMock() + + handler = CompleteOAuthHandler( + auth_service=auth_service, + provider_registry=provider_registry, + token_service=token_service, + outbox=outbox, + ) + + await handler.run( + CompleteOAuth( + code="auth-code", + callback_url="http://localhost/callback", + provider="orcid", + ) + ) + + # Verify UserAuthenticated event was emitted + outbox.append.assert_called_once() + event = outbox.append.call_args[0][0] + assert isinstance(event, UserAuthenticated) + assert event.user_id == str(user.id) + assert event.provider == "orcid" + assert event.external_id == "0000-0001-2345-6789" + + @pytest.mark.asyncio + async def test_run_returns_user_info_and_tokens(self): + """Handler should return user info and tokens.""" + user = User( + id=UserId(uuid4()), + display_name="Jane Doe", + created_at=datetime.now(UTC), + updated_at=None, + ) + identity = Identity( + id=IdentityId(uuid4()), + user_id=user.id, + provider="orcid", + external_id="0000-0001-2345-6789", + metadata=None, + created_at=datetime.now(UTC), + ) + + auth_service = AsyncMock() + auth_service.complete_oauth.return_value = ( + user, + identity, + "access-token", + "refresh-token", + ) + + provider_registry = make_provider_registry() + token_service = make_token_service() + outbox = AsyncMock() + + handler = CompleteOAuthHandler( + auth_service=auth_service, + provider_registry=provider_registry, + token_service=token_service, + outbox=outbox, + ) + + result = await handler.run( + CompleteOAuth( + code="auth-code", + callback_url="http://localhost/callback", + provider="orcid", + ) + ) + + assert result.user_id == str(user.id) + assert result.display_name == "Jane Doe" + assert result.provider == "orcid" + assert result.external_id == "0000-0001-2345-6789" + assert result.access_token == "access-token" + assert result.refresh_token == "refresh-token" + assert result.expires_in == 60 * 60 # 60 minutes in seconds + + +class TestRefreshTokensHandler: + """Tests for RefreshTokensHandler.""" + + @pytest.mark.asyncio + async def test_run_returns_new_tokens(self): + """Handler should return new tokens from auth service.""" + user = User( + id=UserId(uuid4()), + display_name="Test User", + created_at=datetime.now(UTC), + updated_at=None, + ) + + auth_service = AsyncMock() + auth_service.refresh_tokens.return_value = ( + user, + "new-access-token", + "new-refresh-token", + ) + + token_service = make_token_service() + + handler = RefreshTokensHandler( + auth_service=auth_service, + token_service=token_service, + ) + + result = await handler.run(RefreshTokens(refresh_token="old-refresh-token")) + + assert result.access_token == "new-access-token" + assert result.refresh_token == "new-refresh-token" + assert result.expires_in == 60 * 60 # 60 minutes in seconds + + auth_service.refresh_tokens.assert_called_once_with("old-refresh-token") + + +class TestLogoutHandler: + """Tests for LogoutHandler.""" + + @pytest.mark.asyncio + async def test_run_emits_user_logged_out_event(self): + """Handler should emit UserLoggedOut event when user has valid token.""" + user_id = UserId(uuid4()) + + auth_service = AsyncMock() + auth_service.get_user_id_from_refresh_token.return_value = user_id + auth_service.logout.return_value = True + + outbox = AsyncMock() + + handler = LogoutHandler( + auth_service=auth_service, + outbox=outbox, + ) + + result = await handler.run(Logout(refresh_token="refresh-token")) + + assert result.success is True + + # Verify UserLoggedOut event was emitted + outbox.append.assert_called_once() + event = outbox.append.call_args[0][0] + assert isinstance(event, UserLoggedOut) + assert event.user_id == str(user_id) + + @pytest.mark.asyncio + async def test_run_does_not_emit_event_for_unknown_token(self): + """Handler should not emit event if token is unknown.""" + auth_service = AsyncMock() + auth_service.get_user_id_from_refresh_token.return_value = None + auth_service.logout.return_value = True + + outbox = AsyncMock() + + handler = LogoutHandler( + auth_service=auth_service, + outbox=outbox, + ) + + result = await handler.run(Logout(refresh_token="unknown-token")) + + assert result.success is True + + # Should NOT emit event for unknown token + outbox.append.assert_not_called() + + @pytest.mark.asyncio + async def test_run_returns_success(self): + """Handler should return success status from auth service.""" + auth_service = AsyncMock() + auth_service.get_user_id_from_refresh_token.return_value = UserId(uuid4()) + auth_service.logout.return_value = True + + outbox = AsyncMock() + + handler = LogoutHandler( + auth_service=auth_service, + outbox=outbox, + ) + + result = await handler.run(Logout(refresh_token="refresh-token")) + + assert result.success is True + auth_service.logout.assert_called_once_with("refresh-token") diff --git a/server/tests/unit/domain/auth/test_refresh_token.py b/server/tests/unit/domain/auth/test_refresh_token.py new file mode 100644 index 0000000..6179fbf --- /dev/null +++ b/server/tests/unit/domain/auth/test_refresh_token.py @@ -0,0 +1,227 @@ +"""Unit tests for RefreshToken entity.""" + +from datetime import UTC, datetime, timedelta +from uuid import uuid4 + + +from osa.domain.auth.model.token import RefreshToken +from osa.domain.auth.model.value import RefreshTokenId, TokenFamilyId, UserId + + +class TestRefreshTokenCreate: + """Tests for RefreshToken.create factory method.""" + + def test_create_sets_all_fields(self): + """create should set all required fields.""" + user_id = UserId(uuid4()) + token_hash = "a" * 64 + family_id = TokenFamilyId(uuid4()) + + token = RefreshToken.create( + user_id=user_id, + token_hash=token_hash, + family_id=family_id, + expires_in_days=7, + ) + + assert token.user_id == user_id + assert token.token_hash == token_hash + assert token.family_id == family_id + assert token.revoked_at is None + + def test_create_generates_id(self): + """create should generate a unique ID.""" + user_id = UserId(uuid4()) + + token1 = RefreshToken.create(user_id, "a" * 64, TokenFamilyId(uuid4())) + token2 = RefreshToken.create(user_id, "b" * 64, TokenFamilyId(uuid4())) + + assert token1.id != token2.id + + def test_create_sets_expiry_in_future(self): + """create should set expires_at in the future.""" + user_id = UserId(uuid4()) + now = datetime.now(UTC) + + token = RefreshToken.create( + user_id=user_id, + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + expires_in_days=7, + ) + + assert token.expires_at > now + # Should be roughly 7 days from now (allowing small margin) + expected = now + timedelta(days=7) + assert abs((token.expires_at - expected).total_seconds()) < 1 + + def test_create_sets_created_at(self): + """create should set created_at to current time.""" + now = datetime.now(UTC) + + token = RefreshToken.create( + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + ) + + assert abs((token.created_at - now).total_seconds()) < 1 + + +class TestRefreshTokenIsValid: + """Tests for RefreshToken.is_valid property.""" + + def make_token( + self, + expires_at: datetime | None = None, + revoked_at: datetime | None = None, + ) -> RefreshToken: + """Create a token with specified expiry and revocation.""" + if expires_at is None: + expires_at = datetime.now(UTC) + timedelta(days=7) + + return RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + expires_at=expires_at, + created_at=datetime.now(UTC), + revoked_at=revoked_at, + ) + + def test_is_valid_true_for_fresh_token(self): + """is_valid should be True for non-expired, non-revoked token.""" + token = self.make_token() + + assert token.is_valid is True + + def test_is_valid_false_when_expired(self): + """is_valid should be False when token is expired.""" + expired_at = datetime.now(UTC) - timedelta(hours=1) + token = self.make_token(expires_at=expired_at) + + assert token.is_valid is False + + def test_is_valid_false_when_revoked(self): + """is_valid should be False when token is revoked.""" + token = self.make_token(revoked_at=datetime.now(UTC)) + + assert token.is_valid is False + + def test_is_valid_false_when_both_expired_and_revoked(self): + """is_valid should be False when both expired and revoked.""" + token = self.make_token( + expires_at=datetime.now(UTC) - timedelta(hours=1), + revoked_at=datetime.now(UTC) - timedelta(hours=2), + ) + + assert token.is_valid is False + + +class TestRefreshTokenIsRevoked: + """Tests for RefreshToken.is_revoked property.""" + + def test_is_revoked_false_initially(self): + """is_revoked should be False when revoked_at is None.""" + token = RefreshToken.create( + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + ) + + assert token.is_revoked is False + + def test_is_revoked_true_when_set(self): + """is_revoked should be True when revoked_at is set.""" + token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC), + revoked_at=datetime.now(UTC), + ) + + assert token.is_revoked is True + + +class TestRefreshTokenIsExpired: + """Tests for RefreshToken.is_expired property.""" + + def test_is_expired_false_for_future_expiry(self): + """is_expired should be False when expires_at is in the future.""" + token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + expires_at=datetime.now(UTC) + timedelta(days=7), + created_at=datetime.now(UTC), + revoked_at=None, + ) + + assert token.is_expired is False + + def test_is_expired_true_for_past_expiry(self): + """is_expired should be True when expires_at is in the past.""" + token = RefreshToken( + id=RefreshTokenId(uuid4()), + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + expires_at=datetime.now(UTC) - timedelta(hours=1), + created_at=datetime.now(UTC) - timedelta(days=8), + revoked_at=None, + ) + + assert token.is_expired is True + + +class TestRefreshTokenRevoke: + """Tests for RefreshToken.revoke method.""" + + def test_revoke_sets_revoked_at(self): + """revoke should set revoked_at to current time.""" + token = RefreshToken.create( + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + ) + assert token.revoked_at is None + + now = datetime.now(UTC) + token.revoke() + + assert token.revoked_at is not None + assert abs((token.revoked_at - now).total_seconds()) < 1 + + def test_revoke_idempotent(self): + """revoke should not change revoked_at if already revoked.""" + token = RefreshToken.create( + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + ) + + token.revoke() + first_revoked_at = token.revoked_at + + # Second revoke should not change the timestamp + token.revoke() + + assert token.revoked_at == first_revoked_at + + def test_revoke_makes_token_invalid(self): + """revoke should make is_valid return False.""" + token = RefreshToken.create( + user_id=UserId(uuid4()), + token_hash="a" * 64, + family_id=TokenFamilyId(uuid4()), + ) + assert token.is_valid is True + + token.revoke() + + assert token.is_valid is False diff --git a/server/tests/unit/domain/auth/test_token_service.py b/server/tests/unit/domain/auth/test_token_service.py new file mode 100644 index 0000000..5776511 --- /dev/null +++ b/server/tests/unit/domain/auth/test_token_service.py @@ -0,0 +1,250 @@ +"""Unit tests for TokenService JWT creation and validation.""" + +import time +from uuid import uuid4 + +import jwt +import pytest + +from osa.config import JwtConfig +from osa.domain.auth.model.value import ProviderIdentity, UserId +from osa.domain.auth.service.token import TokenService + + +class TestTokenServiceAccessToken: + """Tests for JWT access token creation and validation.""" + + def make_service(self, secret: str = "test-secret-key-256-bits-long-xx") -> TokenService: + """Create a TokenService with test config.""" + config = JwtConfig( + secret=secret, + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=7, + ) + return TokenService(_config=config) + + def test_create_access_token_returns_valid_jwt(self): + """create_access_token should return a decodable JWT.""" + service = self.make_service() + user_id = UserId(uuid4()) + identity = ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789") + + token = service.create_access_token(user_id, identity) + + # Should be decodable + payload = jwt.decode( + token, + "test-secret-key-256-bits-long-xx", + algorithms=["HS256"], + audience="authenticated", + ) + assert payload["sub"] == str(user_id) + assert payload["provider"] == "orcid" + assert payload["external_id"] == "0000-0001-2345-6789" + assert payload["aud"] == "authenticated" + + def test_create_access_token_includes_expiry(self): + """create_access_token should set exp claim.""" + service = self.make_service() + user_id = UserId(uuid4()) + identity = ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789") + + token = service.create_access_token(user_id, identity) + + payload = jwt.decode( + token, + "test-secret-key-256-bits-long-xx", + algorithms=["HS256"], + audience="authenticated", + ) + assert "exp" in payload + assert "iat" in payload + # Expiry should be ~60 minutes from now + assert payload["exp"] > payload["iat"] + assert payload["exp"] - payload["iat"] == 60 * 60 # 60 minutes in seconds + + def test_create_access_token_includes_jti(self): + """create_access_token should include unique jti claim.""" + service = self.make_service() + user_id = UserId(uuid4()) + identity = ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789") + + token1 = service.create_access_token(user_id, identity) + token2 = service.create_access_token(user_id, identity) + + payload1 = jwt.decode( + token1, + "test-secret-key-256-bits-long-xx", + algorithms=["HS256"], + audience="authenticated", + ) + payload2 = jwt.decode( + token2, + "test-secret-key-256-bits-long-xx", + algorithms=["HS256"], + audience="authenticated", + ) + + assert "jti" in payload1 + assert "jti" in payload2 + assert payload1["jti"] != payload2["jti"] + + def test_create_access_token_with_additional_claims(self): + """create_access_token should include additional claims if provided.""" + service = self.make_service() + user_id = UserId(uuid4()) + identity = ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789") + + token = service.create_access_token( + user_id, + identity, + additional_claims={"custom": "value"}, + ) + + payload = jwt.decode( + token, + "test-secret-key-256-bits-long-xx", + algorithms=["HS256"], + audience="authenticated", + ) + assert payload["custom"] == "value" + + def test_validate_access_token_returns_payload(self): + """validate_access_token should return decoded payload.""" + service = self.make_service() + user_id = UserId(uuid4()) + identity = ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789") + + token = service.create_access_token(user_id, identity) + payload = service.validate_access_token(token) + + assert payload["sub"] == str(user_id) + assert payload["provider"] == "orcid" + assert payload["external_id"] == "0000-0001-2345-6789" + + def test_validate_access_token_rejects_invalid_token(self): + """validate_access_token should raise for invalid tokens.""" + service = self.make_service() + + with pytest.raises(jwt.InvalidTokenError): + service.validate_access_token("invalid-token") + + def test_validate_access_token_rejects_wrong_secret(self): + """validate_access_token should reject tokens signed with wrong secret.""" + service1 = self.make_service(secret="secret-one-that-is-long-enough!!") + service2 = self.make_service(secret="secret-two-that-is-long-enough!!") + identity = ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789") + + token = service1.create_access_token(UserId(uuid4()), identity) + + with pytest.raises(jwt.InvalidTokenError): + service2.validate_access_token(token) + + def test_validate_access_token_rejects_expired_token(self): + """validate_access_token should reject expired tokens.""" + config = JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + access_token_expire_minutes=0, # Immediate expiry + refresh_token_expire_days=7, + ) + service = TokenService(_config=config) + identity = ProviderIdentity(provider="orcid", external_id="0000-0001-2345-6789") + + # Create token that's already expired + token = service.create_access_token(UserId(uuid4()), identity) + + # Small delay to ensure expiry + time.sleep(0.1) + + with pytest.raises(jwt.ExpiredSignatureError): + service.validate_access_token(token) + + +class TestTokenServiceRefreshToken: + """Tests for opaque refresh token creation.""" + + def make_service(self) -> TokenService: + """Create a TokenService with test config.""" + config = JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=7, + ) + return TokenService(_config=config) + + def test_create_refresh_token_returns_tuple(self): + """create_refresh_token should return (raw_token, token_hash).""" + service = self.make_service() + + raw_token, token_hash = service.create_refresh_token() + + assert isinstance(raw_token, str) + assert isinstance(token_hash, str) + assert len(raw_token) > 0 + assert len(token_hash) == 64 # SHA256 hex = 64 chars + + def test_create_refresh_token_unique_each_time(self): + """create_refresh_token should generate unique tokens.""" + service = self.make_service() + + raw1, hash1 = service.create_refresh_token() + raw2, hash2 = service.create_refresh_token() + + assert raw1 != raw2 + assert hash1 != hash2 + + def test_hash_token_consistent(self): + """hash_token should produce consistent hashes.""" + raw_token = "test-token-value" + + hash1 = TokenService.hash_token(raw_token) + hash2 = TokenService.hash_token(raw_token) + + assert hash1 == hash2 + assert len(hash1) == 64 + + def test_hash_token_different_for_different_tokens(self): + """hash_token should produce different hashes for different tokens.""" + hash1 = TokenService.hash_token("token-one") + hash2 = TokenService.hash_token("token-two") + + assert hash1 != hash2 + + def test_created_refresh_token_hash_matches(self): + """The hash from create_refresh_token should match hash_token(raw).""" + service = self.make_service() + + raw_token, token_hash = service.create_refresh_token() + + assert TokenService.hash_token(raw_token) == token_hash + + +class TestTokenServiceProperties: + """Tests for TokenService property accessors.""" + + def test_access_token_expire_seconds(self): + """access_token_expire_seconds should convert minutes to seconds.""" + config = JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + access_token_expire_minutes=30, + refresh_token_expire_days=7, + ) + service = TokenService(_config=config) + + assert service.access_token_expire_seconds == 30 * 60 + + def test_refresh_token_expire_days(self): + """refresh_token_expire_days should return configured value.""" + config = JwtConfig( + secret="test-secret-key-256-bits-long-xx", + algorithm="HS256", + access_token_expire_minutes=60, + refresh_token_expire_days=14, + ) + service = TokenService(_config=config) + + assert service.refresh_token_expire_days == 14 diff --git a/server/uv.lock b/server/uv.lock index e888c64..cdf84bd 100644 --- a/server/uv.lock +++ b/server/uv.lock @@ -1588,6 +1588,7 @@ dependencies = [ { name = "psycopg2-binary", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pydantic", extra = ["email"], marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pydantic-settings", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pyjwt", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "rich", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "sentence-transformers", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "sqlalchemy", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -1623,6 +1624,7 @@ requires-dist = [ { name = "psycopg2-binary", specifier = ">=2.9.11" }, { name = "pydantic", extras = ["email"], specifier = ">=2.12.4" }, { name = "pydantic-settings", specifier = ">=2.12.0" }, + { name = "pyjwt", specifier = ">=2.11.0" }, { name = "rich", specifier = ">=14.2.0" }, { name = "sentence-transformers", specifier = ">=5.2.0" }, { name = "sqlalchemy", specifier = ">=2.0.44" }, @@ -1948,6 +1950,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyjwt" +version = "2.11.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/b46fa56bf322901eee5b0454a34343cdbdae202cd421775a8ee4e42fd519/pyjwt-2.11.0.tar.gz", hash = "sha256:35f95c1f0fbe5d5ba6e43f00271c275f7a1a4db1dab27bf708073b75318ea623", size = 98019, upload-time = "2026-01-30T19:59:55.694Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" }, +] + [[package]] name = "pypika" version = "0.50.0" diff --git a/web/src/app/auth/callback/page.tsx b/web/src/app/auth/callback/page.tsx new file mode 100644 index 0000000..700ed82 --- /dev/null +++ b/web/src/app/auth/callback/page.tsx @@ -0,0 +1,86 @@ +'use client'; + +import { Suspense, useEffect, useRef } from 'react'; +import { useRouter, useSearchParams } from 'next/navigation'; +import Link from 'next/link'; +import { OSAClient, parseAuthCallback } from '@/lib/sdk'; + +function AuthCallbackContent() { + const router = useRouter(); + const searchParams = useSearchParams(); + const processedRef = useRef(false); + + // Check for error in URL search params + const urlError = searchParams.get('error'); + const errorDescription = searchParams.get('error_description'); + const error = urlError ? (errorDescription || urlError) : null; + + useEffect(() => { + // Prevent double processing in strict mode + if (processedRef.current || error) return; + + // Check for auth data in hash + const hash = window.location.hash; + if (!hash || !hash.includes('auth=')) { + return; + } + + // Parse and store auth data + const params = parseAuthCallback(hash); + if (!params) { + return; + } + + processedRef.current = true; + + // Store in client + const client = new OSAClient({ baseUrl: '/api/v1' }); + client.handleAuthCallback(hash); + + // Redirect to home (or wherever user came from) + router.push('/'); + }, [router, error]); + + if (error) { + return ( +
+

Authentication Error

+

{error}

+ + Return to Home + +
+ ); + } + + return ( +
+

Completing sign in...

+
+ ); +} + +export default function AuthCallbackPage() { + return ( + +

Loading...

+ + } + > + +
+ ); +} diff --git a/web/src/app/layout.tsx b/web/src/app/layout.tsx index dc5ed1a..15f9a7d 100644 --- a/web/src/app/layout.tsx +++ b/web/src/app/layout.tsx @@ -1,6 +1,7 @@ import type { Metadata } from 'next'; import { Geist, Geist_Mono } from 'next/font/google'; import './globals.css'; +import { AuthProvider } from '@/components/auth/AuthProvider'; import { Header } from '@/components/layout/Header'; import { Footer } from '@/components/layout/Footer'; @@ -22,6 +23,9 @@ export const metadata: Metadata = { keywords: ['biology', 'genomics', 'GEO', 'research', 'scientific data', 'semantic search'], }; +// API URL: use env var for dev (different port), relative path for prod (reverse proxy) +const apiBaseUrl = process.env.NEXT_PUBLIC_API_URL || '/api/v1'; + export default function RootLayout({ children, }: Readonly<{ @@ -30,9 +34,11 @@ export default function RootLayout({ return ( -
- {children} -