diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 72f8391..a27823c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -203,11 +203,15 @@ jobs: OSA_AUTH__JWT__SECRET: test-secret-for-integration-tests-minimum-32-chars TEST: "1" - # Build & push Docker image (only on main, gated on all server checks) + # Build & push Docker image (main + PRs onto main, gated on all server checks) image: name: Server - Image needs: [changes, server-lint, server-typecheck, server-test, server-contract, server-integration] - if: github.ref == 'refs/heads/main' && needs.changes.outputs.server == 'true' + if: >- + needs.changes.outputs.server == 'true' && ( + github.ref == 'refs/heads/main' || + (github.event_name == 'pull_request' && github.base_ref == 'main') + ) uses: ./.github/workflows/image.yml permissions: contents: read diff --git a/.github/workflows/image.yml b/.github/workflows/image.yml index 37efb00..e965840 100644 --- a/.github/workflows/image.yml +++ b/.github/workflows/image.yml @@ -35,6 +35,9 @@ jobs: if [[ "${{ github.event_name }}" == "release" ]]; then TAGS="${TAGS},${{ env.IMAGE }}:${{ github.event.release.tag_name }}" fi + if [[ -n "${{ github.event.pull_request.number }}" ]]; then + TAGS="${TAGS},${{ env.IMAGE }}:pr-${{ github.event.pull_request.number }}" + fi echo "tags=${TAGS}" >> "$GITHUB_OUTPUT" - uses: docker/build-push-action@v6 diff --git a/deploy/docker-compose.dev.yml b/deploy/docker-compose.dev.yml index ce5db40..d5ff008 100644 --- a/deploy/docker-compose.dev.yml +++ b/deploy/docker-compose.dev.yml @@ -20,6 +20,7 @@ services: OSA_DATABASE__URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-osa}@db:5432/${POSTGRES_DB:-osa} OSA_DATA_DIR: /data OSA_CONFIG_FILE: /app/osa.yaml + OSA_BASE_URL: http://localhost:8000 OSA_LOGGING__LEVEL: ${LOG_LEVEL:-DEBUG} WATCHFILES_FORCE_POLLING: "true" entrypoint: [] diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 172633f..3be0e9a 100644 
--- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -21,6 +21,7 @@ services: environment: OSA_DATABASE__URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-osa}@db:5432/${POSTGRES_DB:-osa} OSA_DATA_DIR: /data + OSA_BASE_URL: ${OSA_BASE_URL:-http://localhost:8000} OSA_LOGGING__LEVEL: ${LOG_LEVEL:-INFO} OSA_AUTH__JWT__SECRET: ${JWT_SECRET:-change-me-in-production-must-be-32-chars-long} depends_on: diff --git a/server/Justfile b/server/Justfile index 5310a3f..eb93bc9 100644 --- a/server/Justfile +++ b/server/Justfile @@ -36,6 +36,7 @@ test-cov: # Run linter and type checker lint: + @just fix uv run ruff check osa uv run ty check osa diff --git a/server/migrations/versions/add_ingest_runs.py b/server/migrations/versions/add_ingest_runs.py new file mode 100644 index 0000000..6a23b7c --- /dev/null +++ b/server/migrations/versions/add_ingest_runs.py @@ -0,0 +1,93 @@ +"""add_ingest_runs + +Add ingest_runs table for bulk ingestion tracking. + +Revision ID: add_ingest_runs +Revises: source_agnostic_records +Create Date: 2026-03-25 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "add_ingest_runs" +down_revision: Union[str, Sequence[str], None] = "source_agnostic_records" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "ingest_runs", + sa.Column("srn", sa.String(), primary_key=True), + sa.Column( + "convention_srn", + sa.String(), + sa.ForeignKey("conventions.srn"), + nullable=False, + ), + sa.Column( + "status", + sa.String(32), + nullable=False, + server_default=sa.text("'pending'"), + ), + sa.Column( + "ingestion_finished", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + sa.Column( + "batches_ingested", + sa.Integer(), + nullable=False, + server_default=sa.text("0"), + ), + sa.Column( + "batches_completed", + sa.Integer(), + nullable=False, + server_default=sa.text("0"), + ), + sa.Column( + "published_count", + sa.Integer(), + nullable=False, + server_default=sa.text("0"), + ), + sa.Column( + "batch_size", + sa.Integer(), + nullable=False, + server_default=sa.text("1000"), + ), + sa.Column("record_limit", sa.Integer(), nullable=True), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.CheckConstraint( + "status IN ('pending', 'running', 'completed', 'failed')", + name="ingest_runs_status_check", + ), + ) + + op.create_index( + "idx_ingest_runs_convention", + "ingest_runs", + ["convention_srn"], + ) + op.create_index( + "idx_ingest_runs_status", + "ingest_runs", + ["status"], + ) + + +def downgrade() -> None: + op.drop_index("idx_ingest_runs_status", table_name="ingest_runs") + op.drop_index("idx_ingest_runs_convention", table_name="ingest_runs") + op.drop_table("ingest_runs") diff --git a/server/osa/application/api/rest/app.py b/server/osa/application/api/rest/app.py index 4cbd273..0f41bc0 100644 --- a/server/osa/application/api/rest/app.py +++ b/server/osa/application/api/rest/app.py @@ -1,4 
+1,5 @@ import logging +import sys from contextlib import asynccontextmanager from typing import Any @@ -16,6 +17,7 @@ depositions, discovery, events, + ingestions, health, ontologies, records, @@ -25,7 +27,7 @@ validation, ) from osa.application.di import create_container -from osa.config import Config, configure_logging +from osa.config import Config from osa.domain.shared.authorization.startup import validate_all_handlers from osa.domain.shared.error import OSAError from osa.domain.shared.event import EventHandler @@ -80,9 +82,41 @@ def create_app( # Pydantic Settings populates from env vars at runtime config = Config() # type: ignore[call-arg] - # Configure logging early - configure_logging(config.logging) - logger.info("Starting OSA server: %s v%s", config.name, config.version) + # Configure logfire as the single logging system + import logging as _logging + + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + + from osa.infrastructure.logging import OSAConsoleExporter + + logfire.configure( + send_to_logfire="if-token-present", + service_name=config.name, + console=False, # Disable default console — we use OSAConsoleExporter + inspect_arguments=False, + additional_span_processors=[ + SimpleSpanProcessor( + OSAConsoleExporter( + output=sys.stderr, + include_timestamp=True, + min_log_level=config.logging.level, + ) + ), + ], + ) + + # Route Python logging through logfire so old-style logger.info() calls + # appear in the same output stream + root = _logging.getLogger() + root.setLevel(config.logging.level.upper()) + for h in root.handlers[:]: + root.removeHandler(h) + root.addHandler(logfire.LogfireLoggingHandler()) + + # Suppress duplicate access logs — logfire FastAPI instrumentation handles HTTP logging + _logging.getLogger("uvicorn.access").setLevel(_logging.WARNING) + + logfire.info("Starting OSA server: {name} v{version}", name=config.name, version=config.version) # Validate all handlers have authorization declarations (fail fast) 
validate_all_handlers() @@ -94,9 +128,11 @@ def create_app( lifespan=lifespan, ) - # Instrument FastAPI for automatic tracing of HTTP requests logfire.instrument_httpx() - logfire.instrument_fastapi(app_instance) + logfire.instrument_fastapi( + app_instance, + excluded_urls="/api/v1/health", + ) # Setup dependency injection container = create_container( @@ -117,6 +153,7 @@ def create_app( app_instance.include_router(schemas.router, prefix="/api/v1") app_instance.include_router(conventions.router, prefix="/api/v1") app_instance.include_router(depositions.router, prefix="/api/v1") + app_instance.include_router(ingestions.router, prefix="/api/v1") app_instance.include_router(validation.router, prefix="/api/v1") app_instance.include_router(discovery.router, prefix="/api/v1") diff --git a/server/osa/application/api/v1/routes/ingestions.py b/server/osa/application/api/v1/routes/ingestions.py new file mode 100644 index 0000000..86a6f50 --- /dev/null +++ b/server/osa/application/api/v1/routes/ingestions.py @@ -0,0 +1,20 @@ +"""Ingest REST routes.""" + +from dishka.integrations.fastapi import DishkaRoute, FromDishka +from fastapi import APIRouter + +from osa.domain.ingest.command.start_ingest import ( + IngestRunCreated, + StartIngest, + StartIngestHandler, +) + +router = APIRouter(prefix="/ingestions", tags=["Ingestions"], route_class=DishkaRoute) + + +@router.post("", response_model=IngestRunCreated, status_code=201) +async def start_ingest( + body: StartIngest, + handler: FromDishka[StartIngestHandler], +) -> IngestRunCreated: + return await handler.run(body) diff --git a/server/osa/application/di.py b/server/osa/application/di.py index 1a96763..f7e7ff3 100644 --- a/server/osa/application/di.py +++ b/server/osa/application/di.py @@ -17,7 +17,7 @@ from osa.infrastructure.index.di import IndexProvider from osa.infrastructure.k8s.di import RunnerProvider from osa.infrastructure.persistence import PersistenceProvider -from osa.infrastructure.source.di import SourceProvider 
+from osa.infrastructure.ingest.di import IngestProvider from osa.util.di.scope import Scope from osa.util.paths import OSAPaths @@ -44,7 +44,7 @@ def create_container( PersistenceProvider(), RunnerProvider(), IndexProvider(), - SourceProvider(), + IngestProvider(), EventProvider(extra_handlers=extra_handlers), HttpProvider(), DepositionProvider(), diff --git a/server/osa/config.py b/server/osa/config.py index bf399c8..ffeda85 100644 --- a/server/osa/config.py +++ b/server/osa/config.py @@ -1,12 +1,13 @@ +from logfire import LevelName import logging import os import re import sys from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, Annotated import yaml -from pydantic import BaseModel, field_validator, model_validator +from pydantic import BaseModel, BeforeValidator, field_validator, model_validator from pydantic_settings import BaseSettings, PydanticBaseSettingsSource from typing_extensions import Self @@ -63,7 +64,9 @@ class DatabaseConfig(BaseModel): class LoggingConfig(BaseModel): """Logging configuration (nested in Config, uses env_nested_delimiter).""" - level: str = "DEBUG" # Root log level (DEBUG for development) + level: Annotated[ + LevelName, BeforeValidator(lambda v: v.lower() if isinstance(v, str) else v) + ] = "debug" # Root log level (DEBUG for development) format: str = "%(asctime)s %(levelname)-8s [%(name)s] %(message)s" date_format: str = "%Y-%m-%d %H:%M:%S" @@ -81,6 +84,7 @@ class WorkerConfig(BaseModel): poll_interval: float = 0.5 # Seconds between outbox polls batch_size: int = 100 # Maximum events to fetch per poll cycle + hook_concurrency: int = 8 # Number of concurrent hook workers class K8sConfig(BaseModel): @@ -225,7 +229,8 @@ class Config(BaseSettings): name: str = "Open Science Archive" version: str = "0.0.1" description: str = "An open platform for depositing scientific data" - domain: str = "localhost" # Node domain for SRN construction + domain: str = "localhost" # Node identity for SRN 
construction (DNS name) + base_url: str = "" # Public URL where users reach the server (e.g. http://localhost:8000) # These are BaseModel, so env_nested_delimiter handles their env vars frontend: Frontend = Frontend() @@ -243,15 +248,26 @@ class Config(BaseSettings): "env_nested_delimiter": "__", # Allows OSA_DATABASE__URL override } - @property - def base_url(self) -> str: - """Public base URL derived from domain. HTTPS unless localhost.""" - scheme = "http" if self.domain == "localhost" else "https" - return f"{scheme}://{self.domain}" + @model_validator(mode="after") + def derive_base_url(self) -> Self: + """Derive base_url from domain if not explicitly set. + + For non-localhost domains, HTTPS on port 443 is assumed. + For localhost, base_url must be set explicitly (port matters). + """ + if self.base_url: + return self + if self.domain == "localhost": + raise ValueError( + "OSA_BASE_URL is required when domain is localhost " + "(e.g. OSA_BASE_URL=http://localhost:8000)" + ) + self.base_url = f"https://{self.domain}" + return self @model_validator(mode="after") def derive_frontend_url(self) -> Self: - """Derive frontend URL from domain if still the default localhost value.""" + """Derive frontend URL from base_url if still the default localhost value.""" if self.frontend.url == "http://localhost:3000": self.frontend = Frontend(url=self.base_url) return self @@ -366,5 +382,6 @@ def configure_logging(config: LoggingConfig) -> None: logging.getLogger("asyncio").setLevel(logging.WARNING) logging.getLogger("aiosqlite").setLevel(logging.WARNING) logging.getLogger("apscheduler").setLevel(logging.WARNING) # Suppress job completion spam + logging.getLogger("uvicorn.access").setLevel(logging.WARNING) # Logfire handles HTTP logging logging.debug("Logging configured: level=%s, file=%s", config.level, config.file) diff --git a/server/osa/domain/deposition/command/create_convention.py b/server/osa/domain/deposition/command/create_convention.py index 2e35dc8..c50059b 100644 
--- a/server/osa/domain/deposition/command/create_convention.py +++ b/server/osa/domain/deposition/command/create_convention.py @@ -8,7 +8,7 @@ from osa.domain.shared.authorization.gate import public from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.model.hook import HookDefinition -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.srn import ConventionSRN, SchemaSRN @@ -21,7 +21,7 @@ class CreateConvention(Command): file_requirements: FileRequirements description: str | None = None hooks: list[HookDefinition] = [] - source: SourceDefinition | None = None + ingester: IngesterDefinition | None = None class ConventionCreated(Result): @@ -44,7 +44,7 @@ async def run(self, cmd: CreateConvention) -> ConventionCreated: file_requirements=cmd.file_requirements, description=cmd.description, hooks=cmd.hooks, - source=cmd.source, + ingester=cmd.ingester, ) return ConventionCreated( srn=convention.srn, diff --git a/server/osa/domain/deposition/handler/__init__.py b/server/osa/domain/deposition/handler/__init__.py index b1218b2..16ec113 100644 --- a/server/osa/domain/deposition/handler/__init__.py +++ b/server/osa/domain/deposition/handler/__init__.py @@ -1,8 +1,5 @@ """Deposition domain event handlers.""" -from osa.domain.deposition.handler.create_deposition_from_source import ( - CreateDepositionFromSource, -) from osa.domain.deposition.handler.return_to_draft import ReturnToDraft -__all__ = ["CreateDepositionFromSource", "ReturnToDraft"] +__all__ = ["ReturnToDraft"] diff --git a/server/osa/domain/deposition/handler/create_deposition_from_source.py b/server/osa/domain/deposition/handler/create_deposition_from_source.py deleted file mode 100644 index 769e34a..0000000 --- a/server/osa/domain/deposition/handler/create_deposition_from_source.py +++ /dev/null @@ -1,50 +0,0 @@ -"""CreateDepositionFromSource — creates a deposition from a 
source record.""" - -import logging -from pathlib import Path - -from osa.domain.auth.model.value import SYSTEM_USER_ID -from osa.domain.deposition.port.storage import FileStoragePort -from osa.domain.deposition.service.deposition import DepositionService -from osa.domain.shared.event import EventHandler -from osa.domain.source.event.source_record_ready import SourceRecordReady - -logger = logging.getLogger(__name__) - - -class CreateDepositionFromSource(EventHandler[SourceRecordReady]): - """Creates a deposition when a source record is ready. - - Replaces the direct DepositionService calls that used to live - in SourceService — now the source domain only emits events and - the deposition domain reacts. - """ - - deposition_service: DepositionService - file_storage: FileStoragePort - - async def handle(self, event: SourceRecordReady) -> None: - """Create deposition, set metadata, move files, and submit.""" - dep = await self.deposition_service.create( - convention_srn=event.convention_srn, - owner_id=SYSTEM_USER_ID, - ) - - await self.deposition_service.update_metadata( - srn=dep.srn, - metadata=event.metadata, - ) - - await self.file_storage.move_source_files_to_deposition( - staging_dir=Path(event.staging_dir), - source_id=event.source_id, - deposition_srn=dep.srn, - ) - - await self.deposition_service.submit(srn=dep.srn) - - logger.info( - "Created deposition %s from source record %s", - dep.srn, - event.source_id, - ) diff --git a/server/osa/domain/deposition/model/convention.py b/server/osa/domain/deposition/model/convention.py index c1a2508..63c1e77 100644 --- a/server/osa/domain/deposition/model/convention.py +++ b/server/osa/domain/deposition/model/convention.py @@ -3,7 +3,7 @@ from osa.domain.deposition.model.value import FileRequirements from osa.domain.shared.model.aggregate import Aggregate from osa.domain.shared.model.hook import HookDefinition -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import 
IngesterDefinition from osa.domain.shared.model.srn import ConventionSRN, SchemaSRN @@ -16,5 +16,5 @@ class Convention(Aggregate): schema_srn: SchemaSRN file_requirements: FileRequirements hooks: list[HookDefinition] = [] - source: SourceDefinition | None = None + ingester: IngesterDefinition | None = None created_at: datetime diff --git a/server/osa/domain/deposition/port/storage.py b/server/osa/domain/deposition/port/storage.py index ad8f7be..7476c48 100644 --- a/server/osa/domain/deposition/port/storage.py +++ b/server/osa/domain/deposition/port/storage.py @@ -11,9 +11,8 @@ class FileStoragePort(Port, Protocol): """Storage operations scoped to the deposition domain. - Hook output, hook features, and source staging methods have been - moved to their respective domain ports (HookStoragePort, - FeatureStoragePort, SourceStoragePort). + Hook output and hook features methods have been moved to their + respective domain ports (HookStoragePort, FeatureStoragePort). """ @abstractmethod diff --git a/server/osa/domain/deposition/query/get_convention.py b/server/osa/domain/deposition/query/get_convention.py index 27fcb4a..b39e467 100644 --- a/server/osa/domain/deposition/query/get_convention.py +++ b/server/osa/domain/deposition/query/get_convention.py @@ -4,7 +4,7 @@ from osa.domain.deposition.service.convention import ConventionService from osa.domain.shared.authorization.gate import public from osa.domain.shared.model.hook import HookDefinition -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.srn import ConventionSRN, SchemaSRN from osa.domain.shared.query import Query, QueryHandler, Result @@ -20,7 +20,7 @@ class ConventionDetail(Result): schema_srn: SchemaSRN file_requirements: FileRequirements hooks: list[HookDefinition] - source: SourceDefinition | None = None + ingester: IngesterDefinition | None = None created_at: datetime @@ -37,6 +37,6 @@ async def run(self, 
cmd: GetConvention) -> ConventionDetail: schema_srn=conv.schema_srn, file_requirements=conv.file_requirements, hooks=conv.hooks, - source=conv.source, + ingester=conv.ingester, created_at=conv.created_at, ) diff --git a/server/osa/domain/deposition/service/convention.py b/server/osa/domain/deposition/service/convention.py index 3c2951d..79492e9 100644 --- a/server/osa/domain/deposition/service/convention.py +++ b/server/osa/domain/deposition/service/convention.py @@ -10,7 +10,7 @@ from osa.domain.shared.error import NotFoundError from osa.domain.shared.event import EventId from osa.domain.shared.model.hook import HookDefinition -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.srn import ConventionSRN, Domain, LocalId, Semver from osa.domain.shared.outbox import Outbox from osa.domain.shared.service import Service @@ -30,7 +30,7 @@ async def create_convention( file_requirements: FileRequirements, description: str | None = None, hooks: list[HookDefinition] | None = None, - source: SourceDefinition | None = None, + ingester: IngesterDefinition | None = None, ) -> Convention: """Create a convention with an inline schema. @@ -59,7 +59,7 @@ async def create_convention( schema_srn=created_schema.srn, file_requirements=file_requirements, hooks=hooks or [], - source=source, + ingester=ingester, created_at=datetime.now(UTC), ) diff --git a/server/osa/domain/feature/event/convention_ready.py b/server/osa/domain/feature/event/convention_ready.py index 1c7d7d0..42b627e 100644 --- a/server/osa/domain/feature/event/convention_ready.py +++ b/server/osa/domain/feature/event/convention_ready.py @@ -7,8 +7,7 @@ class ConventionReady(Event): """Emitted when feature tables have been created for a convention. - Downstream handlers (e.g. TriggerInitialSourceRun) react to this - to kick off initial source runs, knowing that feature tables are ready. 
+ Downstream handlers react to this knowing that feature tables are ready. """ id: EventId diff --git a/server/osa/domain/feature/handler/__init__.py b/server/osa/domain/feature/handler/__init__.py index b9db31d..542e540 100644 --- a/server/osa/domain/feature/handler/__init__.py +++ b/server/osa/domain/feature/handler/__init__.py @@ -1,6 +1,7 @@ """Feature domain event handlers.""" from osa.domain.feature.handler.create_feature_tables import CreateFeatureTables +from osa.domain.feature.handler.insert_batch_features import InsertBatchFeatures from osa.domain.feature.handler.insert_record_features import InsertRecordFeatures -__all__ = ["CreateFeatureTables", "InsertRecordFeatures"] +__all__ = ["CreateFeatureTables", "InsertBatchFeatures", "InsertRecordFeatures"] diff --git a/server/osa/domain/feature/handler/insert_batch_features.py b/server/osa/domain/feature/handler/insert_batch_features.py new file mode 100644 index 0000000..cf7f366 --- /dev/null +++ b/server/osa/domain/feature/handler/insert_batch_features.py @@ -0,0 +1,72 @@ +"""InsertBatchFeatures — bulk feature insertion for ingest batches.""" + +from osa.domain.feature.port.storage import FeatureStoragePort +from osa.domain.feature.service.feature import FeatureService +from osa.domain.ingest.event.events import IngestBatchPublished +from osa.domain.shared.event import EventHandler +from osa.infrastructure.logging import get_logger +from osa.infrastructure.storage.layout import StorageLayout + +log = get_logger(__name__) + + +class InsertBatchFeatures(EventHandler[IngestBatchPublished]): + """Reads hook outputs for an ingest batch and inserts features in bulk. + + Handles IngestBatchPublished (batch-level event) rather than + per-record RecordPublished. Uses read_batch_outcomes to parse + the JSONL output format (not the single-record features.json). 
+ """ + + feature_service: FeatureService + feature_storage: FeatureStoragePort + layout: StorageLayout + + async def handle(self, event: IngestBatchPublished) -> None: + if not event.expected_features or not event.published_srns: + return + + batch_output_dir = str( + self.layout.ingest_batch_dir(event.ingest_run_srn, event.batch_index) + ) + + total_inserted = 0 + skipped_dupes = 0 + + for hook_name in event.expected_features: + # Read JSONL outcomes for this hook + outcomes = await self.feature_storage.read_batch_outcomes(batch_output_dir, hook_name) + + # Insert features for each published record that passed this hook. + # Map upstream source ID → published record SRN so features + # are keyed by the record SRN, not the upstream ID. + for upstream_id, outcome in outcomes.items(): + if outcome.status != "passed" or not outcome.features: + continue + + record_srn = event.upstream_to_record_srn.get(upstream_id) + if not record_srn: + # Expected for cross-batch duplicates — the record was already + # published in an earlier batch, so ON CONFLICT DO NOTHING + # skipped it and features were already inserted then. 
+ skipped_dupes += 1 + continue + + count = await self.feature_service.insert_features( + hook_name=hook_name, + record_srn=record_srn, + rows=outcome.features, + ) + total_inserted += count + + short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + dupe_msg = f", {skipped_dupes} duplicates skipped" if skipped_dupes else "" + log.info( + "[{short_id}] batch {batch_index}: inserted {total_inserted} feature rows ({hook_count} hooks{dupe_msg})", + short_id=short_id, + batch_index=event.batch_index, + total_inserted=total_inserted, + hook_count=len(event.expected_features), + dupe_msg=dupe_msg, + ingest_run_srn=event.ingest_run_srn, + ) diff --git a/server/osa/domain/feature/port/storage.py b/server/osa/domain/feature/port/storage.py index 4590de5..64caa81 100644 --- a/server/osa/domain/feature/port/storage.py +++ b/server/osa/domain/feature/port/storage.py @@ -4,6 +4,7 @@ from typing import Any, Protocol from osa.domain.shared.port import Port +from osa.domain.validation.model.batch_outcome import BatchRecordOutcome, HookRecordId class FeatureStoragePort(Port, Protocol): @@ -29,3 +30,16 @@ async def read_hook_features( async def hook_features_exist(self, hook_output_dir: str, feature_name: str) -> bool: """Check whether features.json exists in a hook's output directory.""" ... + + @abstractmethod + async def read_batch_outcomes( + self, output_dir: str, hook_name: str + ) -> dict[HookRecordId, BatchRecordOutcome]: + """Read JSONL batch outputs (features/rejections/errors) for a hook. + + Parses features.jsonl, rejections.jsonl, and errors.jsonl from the + hook's output directory. Each record appears in exactly one file. + + Returns a dict keyed by record ID. + """ + ... 
diff --git a/server/osa/domain/source/port/__init__.py b/server/osa/domain/ingest/__init__.py similarity index 100% rename from server/osa/domain/source/port/__init__.py rename to server/osa/domain/ingest/__init__.py diff --git a/server/tests/unit/domain/source/__init__.py b/server/osa/domain/ingest/command/__init__.py similarity index 100% rename from server/tests/unit/domain/source/__init__.py rename to server/osa/domain/ingest/command/__init__.py diff --git a/server/osa/domain/ingest/command/start_ingest.py b/server/osa/domain/ingest/command/start_ingest.py new file mode 100644 index 0000000..1b65bbd --- /dev/null +++ b/server/osa/domain/ingest/command/start_ingest.py @@ -0,0 +1,47 @@ +"""StartIngest command — initiates a bulk ingestion run for a convention.""" + +from osa.domain.auth.model.role import Role +from osa.domain.shared.authorization.gate import at_least +from osa.domain.shared.command import Command, CommandHandler, Result + + +class StartIngest(Command): + """Start an ingest run for a convention.""" + + convention_srn: str + batch_size: int = 1000 + limit: int | None = None # Max total records to ingest (None = unlimited) + + +class IngestRunCreated(Result): + """Result of starting an ingest run.""" + + srn: str + convention_srn: str + status: str + started_at: str + + +class StartIngestHandler(CommandHandler[StartIngest, IngestRunCreated]): + """Thin command handler — delegates to IngestService.""" + + __auth__ = at_least(Role.ADMIN) + + from osa.domain.auth.model.principal import Principal + from osa.domain.ingest.service.ingest import IngestService + + principal: Principal + service: IngestService + + async def run(self, cmd: StartIngest) -> IngestRunCreated: + ingest_run = await self.service.start_ingest( + convention_srn=cmd.convention_srn, + batch_size=cmd.batch_size, + limit=cmd.limit, + ) + return IngestRunCreated( + srn=ingest_run.srn, + convention_srn=ingest_run.convention_srn, + status=ingest_run.status, + 
started_at=ingest_run.started_at.isoformat(), + ) diff --git a/server/osa/domain/ingest/event/__init__.py b/server/osa/domain/ingest/event/__init__.py new file mode 100644 index 0000000..c8b3101 --- /dev/null +++ b/server/osa/domain/ingest/event/__init__.py @@ -0,0 +1,17 @@ +"""Ingest domain events.""" + +from osa.domain.ingest.event.events import ( + HookBatchCompleted, + IngestBatchPublished, + IngestCompleted, + IngestStarted, + IngesterBatchReady, +) + +__all__ = [ + "IngestStarted", + "IngesterBatchReady", + "HookBatchCompleted", + "IngestBatchPublished", + "IngestCompleted", +] diff --git a/server/osa/domain/ingest/event/events.py b/server/osa/domain/ingest/event/events.py new file mode 100644 index 0000000..ee0eaab --- /dev/null +++ b/server/osa/domain/ingest/event/events.py @@ -0,0 +1,60 @@ +"""Ingest domain events — payloads carry path references, not inline data (AD-1).""" + +from osa.domain.shared.event import Event, EventId + + +class IngestStarted(Event): + """Emitted when an ingest run is created. Triggers first ingester pull.""" + + id: EventId + ingest_run_srn: str + convention_srn: str + batch_size: int + + +class IngesterBatchReady(Event): + """Emitted when an ingester container produces a batch of records. + + Batch data is on disk at the path derived from {ingest_run_srn, batch_index}. + """ + + id: EventId + ingest_run_srn: str + batch_index: int + has_more: bool + + +class HookBatchCompleted(Event): + """Emitted when hook processing completes for a batch. + + Outcomes (features/rejections/errors) are on disk at the batch output path. + """ + + id: EventId + ingest_run_srn: str + batch_index: int + + +class IngestBatchPublished(Event): + """Emitted when records from a batch are bulk-published. + + Triggers InsertBatchFeatures for feature insertion. + Batch-level event — NOT per-record (AD-3). 
+ """ + + id: EventId + ingest_run_srn: str + convention_srn: str + batch_index: int + published_srns: list[str] + published_count: int + expected_features: list[str] + upstream_to_record_srn: dict[str, str] # upstream source ID → published record SRN + + +class IngestCompleted(Event): + """Emitted when all batches are processed and the ingest run is complete.""" + + id: EventId + ingest_run_srn: str + total_published: int diff --git a/server/osa/domain/ingest/handler/__init__.py b/server/osa/domain/ingest/handler/__init__.py new file mode 100644 index 0000000..b294305 --- /dev/null +++ b/server/osa/domain/ingest/handler/__init__.py @@ -0,0 +1,7 @@ +"""Ingest domain event handlers.""" + +from osa.domain.ingest.handler.publish_batch import PublishBatch +from osa.domain.ingest.handler.run_hooks import RunHooks +from osa.domain.ingest.handler.run_ingester import RunIngester + +__all__ = ["RunIngester", "RunHooks", "PublishBatch"] diff --git a/server/osa/domain/ingest/handler/publish_batch.py b/server/osa/domain/ingest/handler/publish_batch.py new file mode 100644 index 0000000..a363f70 --- /dev/null +++ b/server/osa/domain/ingest/handler/publish_batch.py @@ -0,0 +1,207 @@ +"""PublishBatch — reads hook outputs, bulk-publishes passing records.""" + +from datetime import UTC, datetime +from uuid import uuid4 + +from osa.domain.deposition.service.convention import ConventionService +from osa.domain.feature.port.storage import FeatureStoragePort +from osa.domain.ingest.event.events import ( + HookBatchCompleted, + IngestBatchPublished, + IngestCompleted, +) +from osa.domain.ingest.model.ingest_run import IngestStatus +from osa.domain.ingest.model.ingester_record import IngesterRecord +from osa.domain.ingest.port.repository import IngestRunRepository +from osa.domain.record.model.draft import RecordDraft +from osa.domain.record.service import RecordService +from osa.domain.shared.error import NotFoundError +from osa.domain.shared.event import EventHandler, EventId +from 
osa.domain.shared.model.source import IngestSource +from osa.domain.shared.model.srn import ConventionSRN +from osa.domain.shared.outbox import Outbox +from osa.infrastructure.logging import get_logger +from osa.infrastructure.storage.layout import StorageLayout + +log = get_logger(__name__) + + +class PublishBatch(EventHandler[HookBatchCompleted]): + """Reads hook outputs, constructs RecordDrafts, bulk-publishes passing records.""" + + ingest_repo: IngestRunRepository + convention_service: ConventionService + record_service: RecordService + feature_storage: FeatureStoragePort + outbox: Outbox + layout: StorageLayout + + async def handle(self, event: HookBatchCompleted) -> None: + ingest_run = await self.ingest_repo.get(event.ingest_run_srn) + if ingest_run is None: + raise NotFoundError(f"Ingest run not found: {event.ingest_run_srn}") + + convention = await self.convention_service.get_convention( + ConventionSRN.parse(ingest_run.convention_srn) + ) + + # Read ingester records from batch dir + batch_dir = self.layout.ingest_batch_dir(event.ingest_run_srn, event.batch_index) + ingester_dir = self.layout.ingest_batch_ingester_dir( + event.ingest_run_srn, event.batch_index + ) + ingester_records = IngesterRecord.from_jsonl(ingester_dir / "records.jsonl") + + # Read hook outcomes for all hooks + expected_features = [h.name for h in convention.hooks] + + # Determine which records passed all hooks (via storage port — works on filesystem + S3) + # TODO: is this efficient, are we hitting S3 a lot? 
+ passed_records = await _get_passed_records( + ingester_records=ingester_records, + batch_dir=str(batch_dir), + hooks=expected_features, + feature_storage=self.feature_storage, + ) + + # Log outcome breakdown per hook + total = len(ingester_records) + for hook_name in expected_features: + outcomes = await self.feature_storage.read_batch_outcomes(str(batch_dir), hook_name) + from osa.domain.validation.model.batch_outcome import OutcomeStatus + + passed = sum(1 for o in outcomes.values() if o.status == OutcomeStatus.PASSED) + rejected = sum(1 for o in outcomes.values() if o.status == OutcomeStatus.REJECTED) + errored = sum(1 for o in outcomes.values() if o.status == OutcomeStatus.ERRORED) + missing = total - len(outcomes) + short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + log.info( + "[{short_id}] batch {batch_index} hook={hook_name}: " + "{passed}/{total} passed, {rejected} rejected, {errored} errored, {missing} missing", + short_id=short_id, + batch_index=event.batch_index, + hook_name=hook_name, + total=total, + passed=passed, + rejected=rejected, + errored=errored, + missing=missing, + ingest_run_srn=event.ingest_run_srn, + ) + + published_count = 0 + if passed_records: + # Construct RecordDrafts + drafts: list[RecordDraft] = [] + for record in passed_records: + drafts.append( + RecordDraft( + source=IngestSource( + id=f"{ingest_run.convention_srn}:{record.source_id}", + ingest_run_srn=ingest_run.srn, + upstream_source=record.source_id, + ), + metadata=record.metadata, + convention_srn=ConventionSRN.parse(ingest_run.convention_srn), + expected_features=expected_features, + ) + ) + + # Bulk publish — ON CONFLICT DO NOTHING skips duplicates + published = await self.record_service.bulk_publish(drafts) + published_srns = [str(r.srn) for r in published] + published_count = len(published) + + # Build upstream ID → record SRN mapping for feature insertion + upstream_to_record_srn: dict[str, str] = {} + for record in published: + # TODO: should we make 
RecordDraft generic over source type so we don't have to check this at runtime? + if not isinstance(record.source, IngestSource): + log.warn( + "Skipping record with unsupported source type: {source_type}", + source_type=type(record.source).__name__, + ) + continue + upstream_to_record_srn[record.source.upstream_source] = str(record.srn) + + log.info( + "[{short_id}] batch {batch_index}: published {published}/{passed} records ({duplicates} duplicates skipped)", + short_id=short_id, + batch_index=event.batch_index, + published=published_count, + passed=len(passed_records), + duplicates=len(drafts) - published_count, + ingest_run_srn=event.ingest_run_srn, + ) + + # Emit IngestBatchPublished for feature insertion + if published_count > 0: + await self.outbox.append( + IngestBatchPublished( + id=EventId(uuid4()), + ingest_run_srn=event.ingest_run_srn, + convention_srn=ingest_run.convention_srn, + batch_index=event.batch_index, + published_srns=published_srns, + published_count=published_count, + expected_features=expected_features, + upstream_to_record_srn=upstream_to_record_srn, + ) + ) + + # Update counters atomically + updated = await self.ingest_repo.increment_completed( + event.ingest_run_srn, + published_count=published_count, + ) + + # Check completion condition + if updated.is_complete and updated.status == IngestStatus.RUNNING: + updated.check_completion(datetime.now(UTC)) + await self.ingest_repo.save(updated) + + await self.outbox.append( + IngestCompleted( + id=EventId(uuid4()), + ingest_run_srn=event.ingest_run_srn, + total_published=updated.published_count, + ) + ) + short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + log.info( + "[{short_id}] COMPLETE: {total_published} records published", + short_id=short_id, + total_published=updated.published_count, + ingest_run_srn=event.ingest_run_srn, + ) + + +async def _get_passed_records( + ingester_records: list[IngesterRecord], + batch_dir: str, + hooks: list[str], + feature_storage: FeatureStoragePort, +) 
-> list[IngesterRecord]: + """Determine which records passed ALL hooks via the storage port.""" + if not hooks: + return ingester_records + + passed_ids: set[str] | None = None + + for hook_name in hooks: + outcomes = await feature_storage.read_batch_outcomes(batch_dir, hook_name) + if not outcomes: + return [] + from osa.domain.validation.model.batch_outcome import OutcomeStatus + + hook_passed = {rid for rid, o in outcomes.items() if o.status == OutcomeStatus.PASSED} + + if passed_ids is None: + passed_ids = hook_passed + else: + passed_ids &= hook_passed + + if not passed_ids: + return [] + + return [r for r in ingester_records if r.source_id in passed_ids] diff --git a/server/osa/domain/ingest/handler/run_hooks.py b/server/osa/domain/ingest/handler/run_hooks.py new file mode 100644 index 0000000..4681962 --- /dev/null +++ b/server/osa/domain/ingest/handler/run_hooks.py @@ -0,0 +1,114 @@ +"""RunHooks — runs hook containers on an ingester batch.""" + +from pathlib import Path +from uuid import uuid4 + +from osa.domain.deposition.service.convention import ConventionService +from osa.domain.ingest.event.events import HookBatchCompleted, IngesterBatchReady +from osa.domain.ingest.model.ingester_record import IngesterRecord +from osa.domain.ingest.port.repository import IngestRunRepository +from osa.domain.shared.error import NotFoundError +from osa.domain.shared.event import EventHandler, EventId +from osa.domain.shared.model.srn import ConventionSRN +from osa.domain.shared.outbox import Outbox +from osa.domain.validation.model.hook_input import HookRecord +from osa.domain.validation.port.hook_runner import HookInputs +from osa.domain.validation.service.hook import HookService +from osa.infrastructure.logging import get_logger +from osa.infrastructure.storage.layout import StorageLayout + +log = get_logger(__name__) + + +class RunHooks(EventHandler[IngesterBatchReady]): + """Runs hook containers on an ingester batch and emits HookBatchCompleted.""" + + 
__claim_timeout__ = 3600.0 + + ingest_repo: IngestRunRepository + convention_service: ConventionService + hook_service: HookService + outbox: Outbox + layout: StorageLayout + + async def handle(self, event: IngesterBatchReady) -> None: + ingest_run = await self.ingest_repo.get(event.ingest_run_srn) + if ingest_run is None: + raise NotFoundError(f"Ingest run not found: {event.ingest_run_srn}") + + convention = await self.convention_service.get_convention( + ConventionSRN.parse(ingest_run.convention_srn) + ) + + # Read records from batch ingester dir + ingester_dir = self.layout.ingest_batch_ingester_dir( + event.ingest_run_srn, event.batch_index + ) + records = IngesterRecord.from_jsonl(ingester_dir / "records.jsonl") + + if not records: + log.warn( + "ingest batch {batch_index}: no records to process", + batch_index=event.batch_index, + ingest_run_srn=event.ingest_run_srn, + ) + + # Build files_dirs from ingester files + files_base = ingester_dir / "files" + files_dirs: dict[str, Path] = {} + if files_base.exists(): + for record in records: + record_files = files_base / record.source_id + if record_files.exists(): + files_dirs[record.source_id] = record_files + + # Convert to HookInputs with size hints and file dirs + inputs = HookInputs( + records=[ + HookRecord( + id=r.source_id, + metadata=r.metadata, + size_hint_mb=r.total_file_mb, + ) + for r in records + ], + run_id=f"{event.ingest_run_srn}_batch{event.batch_index}", + files_dirs=files_dirs, + ) + + # Build work_dirs for each hook + work_dirs: dict[str, Path] = {} + for hook in convention.hooks: + hook_dir = self.layout.ingest_batch_hook_dir( + event.ingest_run_srn, event.batch_index, hook.name + ) + hook_dir.mkdir(parents=True, exist_ok=True) + work_dirs[hook.name] = hook_dir + + # Run all hooks via HookService + results = await self.hook_service.run_hooks_for_batch( + hooks=convention.hooks, + inputs=inputs, + work_dirs=work_dirs, + ) + + short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + for result 
in results: + log.info( + "[{short_id}] batch {batch_index} hook={hook_name}: {status} in {duration:.1f}s", + short_id=short_id, + batch_index=event.batch_index, + hook_name=result.hook_name, + status=result.status.value, + duration=result.duration_seconds, + ingest_run_srn=event.ingest_run_srn, + ) + + # Emit HookBatchCompleted + await self.outbox.append( + HookBatchCompleted( + id=EventId(uuid4()), + ingest_run_srn=event.ingest_run_srn, + batch_index=event.batch_index, + ) + ) diff --git a/server/osa/domain/ingest/handler/run_ingester.py b/server/osa/domain/ingest/handler/run_ingester.py new file mode 100644 index 0000000..4952eda --- /dev/null +++ b/server/osa/domain/ingest/handler/run_ingester.py @@ -0,0 +1,139 @@ +"""RunIngester — runs ingester container on IngestStarted or continuation.""" + +import json +from uuid import uuid4 + +from osa.domain.deposition.service.convention import ConventionService +from osa.domain.ingest.event.events import IngestStarted, IngesterBatchReady +from osa.domain.ingest.model.ingest_run import IngestStatus +from osa.domain.ingest.port.repository import IngestRunRepository +from osa.domain.shared.error import NotFoundError +from osa.domain.shared.event import EventHandler, EventId +from osa.domain.shared.model.srn import ConventionSRN +from osa.domain.shared.outbox import Outbox +from osa.domain.shared.port.ingester_runner import IngesterInputs, IngesterRunner +from osa.infrastructure.logging import get_logger +from osa.infrastructure.storage.layout import StorageLayout + +log = get_logger(__name__) + + +class RunIngester(EventHandler[IngestStarted]): + """Runs ingester container and emits IngesterBatchReady per batch.""" + + __claim_timeout__ = 3600.0 + + ingest_repo: IngestRunRepository + convention_service: ConventionService + ingester_runner: IngesterRunner + outbox: Outbox + layout: StorageLayout + + async def handle(self, event: IngestStarted) -> None: + """Run ingester for the given ingest run and emit IngesterBatchReady. 
+ + TODO: move this log into a service method. + """ + ingest_run = await self.ingest_repo.get(event.ingest_run_srn) + if ingest_run is None: + raise NotFoundError(f"Ingest run not found: {event.ingest_run_srn}") + + if ingest_run.status == IngestStatus.PENDING: + ingest_run.mark_running() + await self.ingest_repo.save(ingest_run) + + convention = await self.convention_service.get_convention( + ConventionSRN.parse(event.convention_srn) + ) + if convention.ingester is None: + raise NotFoundError(f"No ingester for convention {event.convention_srn}") + + batch_index = ingest_run.batches_ingested + + batch_dir = self.layout.ingest_batch_ingester_dir(event.ingest_run_srn, batch_index) + batch_dir.mkdir(parents=True, exist_ok=True) + + session_file = self.layout.ingest_session_file(event.ingest_run_srn) + session = None + if session_file.exists(): + session = json.loads(session_file.read_text()) + + effective_batch_limit = ingest_run.batch_size + if ingest_run.limit is not None: + ingested_so_far = ingest_run.batches_ingested * ingest_run.batch_size + remaining = ingest_run.limit - ingested_so_far + if remaining <= 0: + log.warn( + "Ignoring redelivered IngestStarted — limit already met (batches_ingested={batches_ingested}, limit={limit})", + batches_ingested=ingest_run.batches_ingested, + limit=ingest_run.limit, + ingest_run_srn=event.ingest_run_srn, + ) + return + effective_batch_limit = min(ingest_run.batch_size, remaining) + + inputs = IngesterInputs( + convention_srn=convention.srn, + config=convention.ingester.config, + limit=effective_batch_limit, + session=session, + ) + files_dir = batch_dir / "files" + files_dir.mkdir(parents=True, exist_ok=True) + + output = await self.ingester_runner.run( + ingester=convention.ingester, + inputs=inputs, + files_dir=files_dir, + work_dir=batch_dir, + ) + + records_file = batch_dir / "records.jsonl" + with records_file.open("w") as f: + for record in output.records: + f.write(json.dumps(record) + "\n") + + if output.session: + 
session_file.parent.mkdir(parents=True, exist_ok=True) + session_file.write_text(json.dumps(output.session)) + + has_more = output.session is not None and len(output.records) > 0 + + if has_more and ingest_run.limit is not None: + total_sourced = (ingest_run.batches_ingested + 1) * ingest_run.batch_size + if total_sourced >= ingest_run.limit: + has_more = False + + await self.ingest_repo.increment_batches_ingested( + event.ingest_run_srn, + set_ingestion_finished=not has_more, + ) + + await self.outbox.append( + IngesterBatchReady( + id=EventId(uuid4()), + ingest_run_srn=event.ingest_run_srn, + batch_index=batch_index, + has_more=has_more, + ) + ) + + short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + log.info( + "[{short_id}] batch {batch_index}: pulled {record_count} records (has_more={has_more})", + short_id=short_id, + batch_index=batch_index, + record_count=len(output.records), + has_more=has_more, + ingest_run_srn=event.ingest_run_srn, + ) + + if has_more: + await self.outbox.append( + IngestStarted( + id=EventId(uuid4()), + ingest_run_srn=event.ingest_run_srn, + convention_srn=event.convention_srn, + batch_size=ingest_run.batch_size, + ) + ) diff --git a/server/osa/domain/ingest/model/__init__.py b/server/osa/domain/ingest/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/osa/domain/ingest/model/ingest_run.py b/server/osa/domain/ingest/model/ingest_run.py new file mode 100644 index 0000000..c95f604 --- /dev/null +++ b/server/osa/domain/ingest/model/ingest_run.py @@ -0,0 +1,86 @@ +"""IngestRun aggregate — lean summary tracking a bulk ingestion execution.""" + +from datetime import datetime +from enum import StrEnum + +from osa.domain.shared.error import InvalidStateError +from osa.domain.shared.model.aggregate import Aggregate + + +class IngestStatus(StrEnum): + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + +_VALID_TRANSITIONS: dict[IngestStatus, set[IngestStatus]] = { + 
IngestStatus.PENDING: {IngestStatus.RUNNING, IngestStatus.FAILED}, + IngestStatus.RUNNING: {IngestStatus.COMPLETED, IngestStatus.FAILED}, + IngestStatus.COMPLETED: set(), + IngestStatus.FAILED: set(), +} + + +class IngestRun(Aggregate): + """Lean summary aggregate tracking a bulk ingestion execution. + + No per-record data — batch output directories on disk are the audit trail. + Counter updates use atomic SQL increments in the repository. + """ + + srn: str + convention_srn: str + status: IngestStatus = IngestStatus.PENDING + ingestion_finished: bool = False + batches_ingested: int = 0 + batches_completed: int = 0 + published_count: int = 0 + batch_size: int = 1000 + limit: int | None = None # Max total records (None = unlimited) + started_at: datetime + completed_at: datetime | None = None + + def transition_to(self, new_status: IngestStatus) -> None: + """Transition to a new status, enforcing valid transitions.""" + if new_status not in _VALID_TRANSITIONS[self.status]: + raise InvalidStateError(f"Cannot transition from {self.status} to {new_status}") + self.status = new_status + + def mark_running(self) -> None: + self.transition_to(IngestStatus.RUNNING) + + def mark_failed(self, completed_at: datetime) -> None: + self.transition_to(IngestStatus.FAILED) + self.completed_at = completed_at + + def mark_ingestion_finished(self) -> None: + self.ingestion_finished = True + + def increment_batches_ingested(self) -> None: + self.batches_ingested += 1 + + def record_batch_completed(self, published_count: int) -> None: + """Record a completed batch with its published count. + + In production, counter updates use atomic SQL increments — + this method is for in-memory aggregate state only. 
+ """ + self.batches_completed += 1 + self.published_count += published_count + + @property + def is_complete(self) -> bool: + """Check the completion condition: all sourced batches are completed.""" + return self.ingestion_finished and self.batches_ingested == self.batches_completed + + def check_completion(self, completed_at: datetime) -> bool: + """Check completion condition and transition if met. + + Returns True if the ingest run is now complete. + """ + if self.is_complete and self.status == IngestStatus.RUNNING: + self.transition_to(IngestStatus.COMPLETED) + self.completed_at = completed_at + return True + return False diff --git a/server/osa/domain/ingest/model/ingester_record.py b/server/osa/domain/ingest/model/ingester_record.py new file mode 100644 index 0000000..ebb7512 --- /dev/null +++ b/server/osa/domain/ingest/model/ingester_record.py @@ -0,0 +1,60 @@ +"""IngesterRecord — typed representation of a record from an ingester container.""" + +import json +import logging +from pathlib import Path +from typing import Any + +from osa.domain.shared.model.value import ValueObject + +logger = logging.getLogger(__name__) + + +class IngesterFileRef(ValueObject): + """A reference to a file produced by an ingester container.""" + + name: str + relative_path: str + size_mb: float + + +class IngesterRecord(ValueObject): + """A record produced by an ingester container, parsed from records.jsonl. + + Replaces raw dicts with typed fields so downstream handlers + don't need fragile `.get("source_id", .get("id", ""))` patterns. 
+ """ + + source_id: str + metadata: dict[str, Any] + files: list[IngesterFileRef] = [] + + @property + def total_file_mb(self) -> float: + """Sum of all file sizes in megabytes.""" + return sum(f.size_mb for f in self.files) + + @classmethod + def from_jsonl(cls, path: Path) -> list["IngesterRecord"]: + """Parse records.jsonl into typed IngesterRecord objects.""" + records: list[IngesterRecord] = [] + if not path.exists(): + return records + for line in path.open(): + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + files_raw = data.get("files", []) + files = [IngesterFileRef.model_validate(f) for f in files_raw] + records.append( + IngesterRecord( + source_id=data.get("source_id", data.get("id", "")), + metadata=data.get("metadata", {}), + files=files, + ) + ) + except (json.JSONDecodeError, ValueError): + logger.warning("Skipping malformed ingester record line") + return records diff --git a/server/osa/domain/ingest/port/__init__.py b/server/osa/domain/ingest/port/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/osa/domain/ingest/port/repository.py b/server/osa/domain/ingest/port/repository.py new file mode 100644 index 0000000..f5839d6 --- /dev/null +++ b/server/osa/domain/ingest/port/repository.py @@ -0,0 +1,50 @@ +"""IngestRunRepository port — persistence interface for ingest runs.""" + +from abc import abstractmethod +from typing import Protocol + +from osa.domain.ingest.model.ingest_run import IngestRun +from osa.domain.shared.port import Port + + +class IngestRunRepository(Port, Protocol): + """Persistence interface for IngestRun aggregates. + + Counter updates (batches_completed, published_count) use atomic SQL + increments in the concrete implementation to avoid lost updates under + concurrent PublishBatch workers. + """ + + @abstractmethod + async def save(self, ingest_run: IngestRun) -> None: + """Persist an ingest run (insert or update).""" + ... 
+ + @abstractmethod + async def get(self, srn: str) -> IngestRun | None: + """Get an ingest run by SRN.""" + ... + + @abstractmethod + async def get_running_for_convention(self, convention_srn: str) -> IngestRun | None: + """Get a running ingest run for a convention, if any.""" + ... + + @abstractmethod + async def increment_batches_ingested( + self, srn: str, *, set_ingestion_finished: bool = False + ) -> IngestRun: + """Atomically increment batches_ingested and optionally set ingestion_finished. + + Returns the updated IngestRun with DB-authoritative counter values. + """ + ... + + @abstractmethod + async def increment_completed(self, srn: str, published_count: int) -> IngestRun: + """Atomically increment batches_completed and published_count. + + Returns the updated IngestRun with DB-authoritative counter values + for completion condition checking. + """ + ... diff --git a/server/osa/domain/ingest/service/__init__.py b/server/osa/domain/ingest/service/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/osa/domain/ingest/service/ingest.py b/server/osa/domain/ingest/service/ingest.py new file mode 100644 index 0000000..35c338a --- /dev/null +++ b/server/osa/domain/ingest/service/ingest.py @@ -0,0 +1,87 @@ +"""IngestService — orchestrates ingest lifecycle.""" + +from datetime import UTC, datetime +from uuid import uuid4 + +from osa.domain.deposition.service.convention import ConventionService +from osa.domain.ingest.event.events import IngestStarted +from osa.domain.ingest.model.ingest_run import IngestRun, IngestStatus +from osa.domain.ingest.port.repository import IngestRunRepository +from osa.domain.shared.error import ConflictError, NotFoundError +from osa.domain.shared.event import EventId +from osa.domain.shared.model.srn import ConventionSRN, Domain +from osa.domain.shared.outbox import Outbox +from osa.domain.shared.service import Service +from osa.infrastructure.logging import get_logger + +log = get_logger(__name__) + + +class 
IngestService(Service): + """Orchestrates ingest run creation and lifecycle.""" + + ingest_repo: IngestRunRepository + convention_service: ConventionService + outbox: Outbox + node_domain: Domain + + async def start_ingest( + self, + convention_srn: str, + batch_size: int = 1000, + limit: int | None = None, + ) -> IngestRun: + """Create an ingest run for a convention. + + Validates: + - Convention exists + - Convention has an ingester configured + - No ingest is already running for this convention + """ + parsed_srn = ConventionSRN.parse(convention_srn) + convention = await self.convention_service.get_convention(parsed_srn) + + if convention.ingester is None: + raise NotFoundError( + f"No ingester configured for convention {convention_srn}", + code="no_ingester_configured", + ) + + existing = await self.ingest_repo.get_running_for_convention(convention_srn) + if existing is not None: + raise ConflictError( + f"Ingest already running for convention {convention_srn}", + code="ingest_already_running", + ) + + srn = f"urn:osa:{self.node_domain.root}:ing:{uuid4()}" + now = datetime.now(UTC) + + ingest_run = IngestRun( + srn=srn, + convention_srn=convention_srn, + status=IngestStatus.PENDING, + batch_size=batch_size, + limit=limit, + started_at=now, + ) + + await self.ingest_repo.save(ingest_run) + + await self.outbox.append( + IngestStarted( + id=EventId(uuid4()), + ingest_run_srn=srn, + convention_srn=convention_srn, + batch_size=batch_size, + ) + ) + + log.info( + "ingest started for {convention_srn}", + ingest_run_srn=srn, + convention_srn=convention_srn, + batch_size=batch_size, + limit=limit, + ) + return ingest_run diff --git a/server/osa/domain/record/port/repository.py b/server/osa/domain/record/port/repository.py index 34a1724..345d6e2 100644 --- a/server/osa/domain/record/port/repository.py +++ b/server/osa/domain/record/port/repository.py @@ -12,6 +12,11 @@ class RecordRepository(Port, Protocol): @abstractmethod async def save(self, record: Record) -> None: 
... + @abstractmethod + async def save_many(self, records: list[Record]) -> list[Record]: + """Multi-row INSERT with ON CONFLICT DO NOTHING. Returns inserted records.""" + ... + @abstractmethod async def get(self, srn: RecordSRN) -> Record | None: ... diff --git a/server/osa/domain/record/service/record.py b/server/osa/domain/record/service/record.py index d6d3cc4..e2409ce 100644 --- a/server/osa/domain/record/service/record.py +++ b/server/osa/domain/record/service/record.py @@ -49,6 +49,36 @@ async def get(self, srn: RecordSRN) -> Record: raise NotFoundError(f"Record not found: {srn}") return record + async def bulk_publish(self, drafts: list[RecordDraft]) -> list[Record]: + """Bulk-publish records from an ingest batch. + + Uses save_many() for multi-row INSERT with ON CONFLICT DO NOTHING. + Does NOT emit per-record RecordPublished events — the caller emits + a single IngestBatchPublished event instead (AD-3). + """ + if not drafts: + return [] + + records: list[Record] = [] + for draft in drafts: + record_srn = RecordSRN( + domain=self.node_domain, + id=LocalId(str(uuid4())), + version=RecordVersion(1), + ) + records.append( + Record( + srn=record_srn, + source=draft.source, + convention_srn=draft.convention_srn, + metadata=draft.metadata, + published_at=datetime.now(UTC), + ) + ) + + published = await self.record_repo.save_many(records) + return published + async def publish_record(self, draft: RecordDraft) -> Record: """Create and persist a Record from a draft.""" logger.info(f"Creating record from {draft.source.type} source: {draft.source.id}") diff --git a/server/osa/domain/shared/event.py b/server/osa/domain/shared/event.py index 7c3d565..0fc19b3 100644 --- a/server/osa/domain/shared/event.py +++ b/server/osa/domain/shared/event.py @@ -209,14 +209,11 @@ class EventHandler(Generic[E], metaclass=_EventHandlerMeta): __claim_timeout__: Seconds before claim considered stale (default: 300.0) Example (single event): - class 
TriggerInitialSourceRun(EventHandler[ServerStarted]): - _config: Config - _outbox: Outbox + class HandleRecordPublished(EventHandler[RecordPublished]): + _service: IndexingService - async def handle(self, event: ServerStarted) -> None: - for source in self._config.sources: - if source.initial_run and source.initial_run.enabled: - await self._outbox.append(SourceRequested(...)) + async def handle(self, event: RecordPublished) -> None: + await self._service.index(event.record_srn) Example (batch processing): class VectorIndexHandler(EventHandler[IndexRecord]): @@ -236,6 +233,7 @@ async def handle_batch(self, events: list[IndexRecord]) -> None: __poll_interval__: ClassVar[float] = 0.5 __max_retries__: ClassVar[int] = 3 __claim_timeout__: ClassVar[float] = 300.0 + __concurrency__: ClassVar[int] = 1 async def handle(self, event: E) -> None: """Handle a single event. Override for single-event processing. diff --git a/server/osa/domain/shared/model/hook.py b/server/osa/domain/shared/model/hook.py index 7fe5e29..346ba34 100644 --- a/server/osa/domain/shared/model/hook.py +++ b/server/osa/domain/shared/model/hook.py @@ -5,6 +5,7 @@ (NextflowConfig, TimeSeriesFeatureSpec, …) slot in without touching existing code. """ +import re from typing import Annotated, Any, Literal from pydantic import Field @@ -15,6 +16,45 @@ # Safe for use as PG identifiers, file path components, and env var values. 
PgIdentifier = Annotated[str, Field(pattern=r"^[a-z][a-z0-9_]{0,62}$")] +_MEMORY_RE = re.compile(r"^(\d+(?:\.\d+)?)(g|m|k)?i?$") + +_GIB = 1024 * 1024 * 1024 +_MIB = 1024 * 1024 +_KIB = 1024 + + +def parse_memory(memory: str) -> int: + """Parse memory string like '2g' or '512m' to bytes.""" + match = _MEMORY_RE.match(memory.strip().lower()) + if not match: + raise ValueError(f"Invalid memory format: {memory}") + + amount = float(match.group(1)) + unit = match.group(2) + + match unit: + case "g": + return int(amount * _GIB) + case "m": + return int(amount * _MIB) + case "k": + return int(amount * _KIB) + case None: + return int(amount) + case _: + raise ValueError(f"Unknown memory unit: {unit}") + + +def _format_memory(byte_count: int) -> str: + """Format bytes to a compact memory string (e.g. '2g', '1536m').""" + if byte_count % _GIB == 0: + return f"{byte_count // _GIB}g" + if byte_count % _MIB == 0: + return f"{byte_count // _MIB}m" + if byte_count % _KIB == 0: + return f"{byte_count // _KIB}k" + return str(byte_count) + class ColumnDef(ValueObject): """Definition of a single column in a feature table.""" @@ -32,7 +72,7 @@ class OciLimits(ValueObject): """Resource limits for OCI hook execution.""" timeout_seconds: int = 300 - memory: str = "512m" + memory: str = "1g" cpu: str = "0.5" @@ -78,3 +118,15 @@ class HookDefinition(ValueObject): name: PgIdentifier runtime: Annotated[OciConfig, Field(discriminator="type")] feature: Annotated[TableFeatureSpec, Field(discriminator="kind")] + + def with_memory(self, memory: str) -> "HookDefinition": + """Return a copy with a different memory limit.""" + new_limits = self.runtime.limits.model_copy(update={"memory": memory}) + new_runtime = self.runtime.model_copy(update={"limits": new_limits}) + return self.model_copy(update={"runtime": new_runtime}) + + def with_doubled_memory(self) -> "HookDefinition": + """Return a copy with 2x the current memory limit.""" + current_bytes = parse_memory(self.runtime.limits.memory) + 
doubled = _format_memory(current_bytes * 2) + return self.with_memory(doubled) diff --git a/server/osa/domain/shared/model/source.py b/server/osa/domain/shared/model/source.py index e61fd8e..9d4b1ef 100644 --- a/server/osa/domain/shared/model/source.py +++ b/server/osa/domain/shared/model/source.py @@ -1,4 +1,4 @@ -"""Shared source domain models used across deposition and source domains.""" +"""Shared source domain models used across deposition and ingest domains.""" from typing import Annotated, Any, Literal, Union @@ -7,23 +7,23 @@ from osa.domain.shared.model.value import ValueObject -class SourceLimits(ValueObject): - """Resource limits for source container execution.""" +class IngesterLimits(ValueObject): + """Resource limits for ingester container execution.""" timeout_seconds: int = 3600 - memory: str = "512m" + memory: str = "1g" cpu: str = "0.25" -class SourceScheduleConfig(ValueObject): - """Cron schedule for periodic source runs.""" +class IngesterScheduleConfig(ValueObject): + """Cron schedule for periodic ingester runs.""" cron: str limit: int | None = None class InitialRunConfig(ValueObject): - """Configuration for the first source run on server startup.""" + """Configuration for the first ingester run on server startup.""" limit: int | None = None @@ -51,11 +51,11 @@ class DepositionSource(_RecordSourceBase): type: Literal["deposition"] = "deposition" -class HarvestSource(_RecordSourceBase): - """Record originated from an automated harvest run.""" +class IngestSource(_RecordSourceBase): + """Record originated from an automated ingest run.""" - type: Literal["harvest"] = "harvest" - harvest_run_srn: str + type: Literal["ingest"] = "ingest" + ingest_run_srn: str upstream_source: str @@ -68,21 +68,21 @@ def _record_source_discriminator(v: Any) -> str: RecordSource = Annotated[ Union[ Annotated[DepositionSource, Tag("deposition")], - Annotated[HarvestSource, Tag("harvest")], + Annotated[IngestSource, Tag("ingest")], ], 
Discriminator(_record_source_discriminator), ] -# ── Source runner definitions ── +# ── Ingester runner definitions ── -class SourceDefinition(ValueObject): - """Complete specification for a source: image reference + config + limits.""" +class IngesterDefinition(ValueObject): + """Complete specification for an ingester: image reference + config + limits.""" image: str digest: str config: dict[str, Any] | None = None - limits: SourceLimits = Field(default_factory=SourceLimits) - schedule: SourceScheduleConfig | None = None + limits: IngesterLimits = Field(default_factory=IngesterLimits) + schedule: IngesterScheduleConfig | None = None initial_run: InitialRunConfig | None = None diff --git a/server/osa/domain/source/port/source_runner.py b/server/osa/domain/shared/port/ingester_runner.py similarity index 51% rename from server/osa/domain/source/port/source_runner.py rename to server/osa/domain/shared/port/ingester_runner.py index 44ad49f..c4bb7a0 100644 --- a/server/osa/domain/source/port/source_runner.py +++ b/server/osa/domain/shared/port/ingester_runner.py @@ -1,4 +1,8 @@ -"""SourceRunner port — interface for executing source containers.""" +"""IngesterRunner port — interface for executing ingester containers. + +Relocated from domain/source/ to shared/port/ since both the ingest +domain and infrastructure runners depend on this contract. 
+""" from __future__ import annotations @@ -7,13 +11,13 @@ from pathlib import Path from typing import Any, Protocol -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.srn import ConventionSRN @dataclass(frozen=True) -class SourceInputs: - """Inputs for a source container run.""" +class IngesterInputs: + """Inputs for an ingester container run.""" convention_srn: ConventionSRN config: dict[str, Any] | None = None @@ -24,21 +28,21 @@ class SourceInputs: @dataclass(frozen=True) -class SourceOutput: - """Output from a source container run.""" +class IngesterOutput: + """Output from an ingester container run.""" records: list[dict[str, Any]] # Parsed from records.jsonl session: dict[str, Any] | None # From session.json (continuation) - files_dir: Path # Where source wrote data files + files_dir: Path # Where ingester wrote data files -class SourceRunner(Protocol): - """Protocol for executing source containers.""" +class IngesterRunner(Protocol): + """Protocol for executing ingester containers.""" async def run( self, - source: SourceDefinition, - inputs: SourceInputs, + ingester: IngesterDefinition, + inputs: IngesterInputs, files_dir: Path, work_dir: Path, - ) -> SourceOutput: ... + ) -> IngesterOutput: ... 
diff --git a/server/osa/domain/source/__init__.py b/server/osa/domain/source/__init__.py deleted file mode 100644 index a001339..0000000 --- a/server/osa/domain/source/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Source domain - configuration and orchestration for data sources.""" diff --git a/server/osa/domain/source/event/__init__.py b/server/osa/domain/source/event/__init__.py deleted file mode 100644 index 9d82d2f..0000000 --- a/server/osa/domain/source/event/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Source domain events.""" - -from osa.domain.source.event.source_record_ready import SourceRecordReady -from osa.domain.source.event.source_requested import SourceRequested -from osa.domain.source.event.source_run_completed import SourceRunCompleted - -__all__ = ["SourceRecordReady", "SourceRequested", "SourceRunCompleted"] diff --git a/server/osa/domain/source/event/source_record_ready.py b/server/osa/domain/source/event/source_record_ready.py deleted file mode 100644 index 5d19b0c..0000000 --- a/server/osa/domain/source/event/source_record_ready.py +++ /dev/null @@ -1,21 +0,0 @@ -"""SourceRecordReady event — emitted per record produced by a source container.""" - -from typing import Any - -from osa.domain.shared.event import Event, EventId -from osa.domain.shared.model.srn import ConventionSRN - - -class SourceRecordReady(Event): - """Emitted for each record produced by a source run. - - Replaces direct DepositionService calls in SourceService. - Consumed by CreateDepositionFromSource in the deposition domain. 
- """ - - id: EventId - convention_srn: ConventionSRN - metadata: dict[str, Any] - file_paths: list[str] - source_id: str - staging_dir: str diff --git a/server/osa/domain/source/event/source_requested.py b/server/osa/domain/source/event/source_requested.py deleted file mode 100644 index dc213dd..0000000 --- a/server/osa/domain/source/event/source_requested.py +++ /dev/null @@ -1,28 +0,0 @@ -"""SourceRequested event - triggers pulling from a data source.""" - -from datetime import datetime -from typing import Any - -from osa.domain.shared.event import Event, EventId -from osa.domain.shared.model.srn import ConventionSRN - - -class SourceRequested(Event): - """Emitted when pulling should start for a source. - - The convention SRN identifies which convention (and its SourceDefinition) - to run. The server loads the convention to get the source image/config. - - For chunked processing: - - `offset`: Starting position for this chunk (0 for first chunk) - - `chunk_size`: Number of records to process per chunk - - `session`: Opaque pagination state for efficient continuation - """ - - id: EventId - convention_srn: ConventionSRN - since: datetime | None = None - limit: int | None = None - offset: int = 0 - chunk_size: int = 1000 - session: dict[str, Any] | None = None diff --git a/server/osa/domain/source/event/source_run_completed.py b/server/osa/domain/source/event/source_run_completed.py deleted file mode 100644 index 0b93d93..0000000 --- a/server/osa/domain/source/event/source_run_completed.py +++ /dev/null @@ -1,17 +0,0 @@ -"""SourceRunCompleted event - emitted after a source run finishes.""" - -from datetime import datetime - -from osa.domain.shared.event import Event, EventId -from osa.domain.shared.model.srn import ConventionSRN - - -class SourceRunCompleted(Event): - """Emitted after a source run completes.""" - - id: EventId - convention_srn: ConventionSRN - started_at: datetime - completed_at: datetime - record_count: int - is_final_chunk: bool = True diff 
--git a/server/osa/domain/source/handler/__init__.py b/server/osa/domain/source/handler/__init__.py deleted file mode 100644 index 3cc4466..0000000 --- a/server/osa/domain/source/handler/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Source domain event handlers.""" - -from osa.domain.source.handler.pull_from_source import PullFromSource -from osa.domain.source.handler.trigger_initial_source_run import TriggerInitialSourceRun - -__all__ = ["PullFromSource", "TriggerInitialSourceRun"] diff --git a/server/osa/domain/source/handler/pull_from_source.py b/server/osa/domain/source/handler/pull_from_source.py deleted file mode 100644 index f611c57..0000000 --- a/server/osa/domain/source/handler/pull_from_source.py +++ /dev/null @@ -1,34 +0,0 @@ -"""PullFromSource - handles SourceRequested events.""" - -from osa.domain.deposition.service.convention import ConventionService -from osa.domain.shared.error import NotFoundError -from osa.domain.shared.event import EventHandler -from osa.domain.source.event.source_requested import SourceRequested -from osa.domain.source.service import SourceService - - -class PullFromSource(EventHandler[SourceRequested]): - """Runs a source container and emits per-record events. - - Looks up the convention to get the SourceDefinition, then - delegates to SourceService for container execution. 
- """ - - service: SourceService - convention_service: ConventionService - - async def handle(self, event: SourceRequested) -> None: - """Look up convention, extract source definition, and run source.""" - convention = await self.convention_service.get_convention(event.convention_srn) - if convention.source is None: - raise NotFoundError(f"No source defined for convention {event.convention_srn}") - - await self.service.run_source( - convention_srn=event.convention_srn, - source=convention.source, - since=event.since, - limit=event.limit, - offset=event.offset, - chunk_size=event.chunk_size, - session=event.session, - ) diff --git a/server/osa/domain/source/handler/trigger_initial_source_run.py b/server/osa/domain/source/handler/trigger_initial_source_run.py deleted file mode 100644 index c3a84bd..0000000 --- a/server/osa/domain/source/handler/trigger_initial_source_run.py +++ /dev/null @@ -1,39 +0,0 @@ -"""TriggerInitialSourceRun - triggers source pull when feature tables are ready.""" - -import logging -from uuid import uuid4 - -from osa.domain.deposition.service.convention import ConventionService -from osa.domain.feature.event.convention_ready import ConventionReady -from osa.domain.shared.event import EventHandler, EventId -from osa.domain.shared.outbox import Outbox -from osa.domain.source.event.source_requested import SourceRequested - -logger = logging.getLogger(__name__) - - -class TriggerInitialSourceRun(EventHandler[ConventionReady]): - """Emits SourceRequested when a convention with initial_run is ready. 
- - Part of the convention initialization chain: - ConventionRegistered → CreateFeatureTables → ConventionReady → TriggerInitialSourceRun - """ - - convention_service: ConventionService - outbox: Outbox - - async def handle(self, event: ConventionReady) -> None: - conv = await self.convention_service.get_convention(event.convention_srn) - - if conv.source is None or conv.source.initial_run is None: - return - - logger.info("Source deploy trigger: convention=%s", conv.srn) - - await self.outbox.append( - SourceRequested( - id=EventId(uuid4()), - convention_srn=conv.srn, - limit=conv.source.initial_run.limit, - ) - ) diff --git a/server/osa/domain/source/handler/trigger_source_on_deploy.py b/server/osa/domain/source/handler/trigger_source_on_deploy.py deleted file mode 100644 index 534a287..0000000 --- a/server/osa/domain/source/handler/trigger_source_on_deploy.py +++ /dev/null @@ -1,35 +0,0 @@ -"""TriggerSourceOnDeploy - triggers source pull when a convention with a source is deployed.""" - -import logging -from uuid import uuid4 - -from osa.domain.deposition.event.convention_registered import ConventionRegistered -from osa.domain.deposition.service.convention import ConventionService -from osa.domain.shared.event import EventHandler, EventId -from osa.domain.shared.outbox import Outbox -from osa.domain.source.event.source_requested import SourceRequested - -logger = logging.getLogger(__name__) - - -class TriggerSourceOnDeploy(EventHandler[ConventionRegistered]): - """Emits SourceRequested when a convention with initial_run is deployed.""" - - convention_service: ConventionService - outbox: Outbox - - async def handle(self, event: ConventionRegistered) -> None: - conv = await self.convention_service.get_convention(event.convention_srn) - - if conv.source is None or conv.source.initial_run is None: - return - - logger.info("Source deploy trigger: convention=%s", conv.srn) - - await self.outbox.append( - SourceRequested( - id=EventId(uuid4()), - 
convention_srn=conv.srn, - limit=conv.source.initial_run.limit, - ) - ) diff --git a/server/osa/domain/source/model/__init__.py b/server/osa/domain/source/model/__init__.py deleted file mode 100644 index 7dc2082..0000000 --- a/server/osa/domain/source/model/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Source domain models.""" diff --git a/server/osa/domain/source/port/storage.py b/server/osa/domain/source/port/storage.py deleted file mode 100644 index 821119c..0000000 --- a/server/osa/domain/source/port/storage.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Storage port scoped to the source domain.""" - -from abc import abstractmethod -from pathlib import Path -from typing import Protocol - -from osa.domain.shared.model.srn import ConventionSRN, DepositionSRN -from osa.domain.shared.port import Port - - -class SourceStoragePort(Port, Protocol): - """File storage operations used by the source domain.""" - - @abstractmethod - def get_source_staging_dir(self, convention_srn: ConventionSRN, run_id: str) -> Path: - """Staging dir for source-ingested files, isolated per run.""" - ... - - @abstractmethod - def get_source_output_dir(self, convention_srn: ConventionSRN, run_id: str) -> Path: - """Output dir for a source run (records.jsonl, session.json).""" - ... - - @abstractmethod - async def move_source_files_to_deposition( - self, - staging_dir: Path, - source_id: str, - deposition_srn: DepositionSRN, - ) -> None: - """Move source staging files into the deposition's canonical file location.""" - ... 
diff --git a/server/osa/domain/source/schedule/__init__.py b/server/osa/domain/source/schedule/__init__.py deleted file mode 100644 index d566437..0000000 --- a/server/osa/domain/source/schedule/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Source scheduled tasks.""" - -from osa.domain.source.schedule.source_schedule import SourceSchedule - -__all__ = ["SourceSchedule"] diff --git a/server/osa/domain/source/schedule/source_schedule.py b/server/osa/domain/source/schedule/source_schedule.py deleted file mode 100644 index 7c9c8b2..0000000 --- a/server/osa/domain/source/schedule/source_schedule.py +++ /dev/null @@ -1,60 +0,0 @@ -"""SourceSchedule - scheduled task that emits SourceRequested events.""" - -import logging -from dataclasses import dataclass -from typing import Any -from uuid import uuid4 - -from osa.domain.shared.event import EventId, Schedule -from osa.domain.shared.model.srn import ConventionSRN -from osa.domain.shared.outbox import Outbox -from osa.domain.source.event.source_requested import SourceRequested -from osa.domain.source.event.source_run_completed import SourceRunCompleted - -logger = logging.getLogger(__name__) - - -@dataclass -class SourceSchedule(Schedule): - """Scheduled task that emits SourceRequested events. - - Looks up the last completed source run to determine the `since` timestamp, - then emits a SourceRequested event to trigger a new pull. - """ - - outbox: Outbox - - async def run(self, **params: Any) -> None: - """Emit a SourceRequested event for the given convention. 
- - Params: - convention: Convention SRN string - limit: Optional limit on records to fetch - """ - convention_srn = ConventionSRN.parse(params["convention"]) - limit: int | None = params.get("limit") - - # Look up last completed run for this convention - last_run = await self.outbox.find_latest_where( - SourceRunCompleted, convention_srn=str(convention_srn) - ) - - since = None - if last_run is not None: - since = last_run.completed_at - - logger.info( - "Scheduled source run: convention=%s (since=%s, limit=%s)", - convention_srn, - since, - limit, - ) - - await self.outbox.append( - SourceRequested( - id=EventId(uuid4()), - convention_srn=convention_srn, - since=since, - limit=limit, - ) - ) diff --git a/server/osa/domain/source/service/__init__.py b/server/osa/domain/source/service/__init__.py deleted file mode 100644 index ecefd99..0000000 --- a/server/osa/domain/source/service/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Source service module.""" - -from osa.domain.source.service.source import SourceResult, SourceService - -__all__ = ["SourceService", "SourceResult"] diff --git a/server/osa/domain/source/service/source.py b/server/osa/domain/source/service/source.py deleted file mode 100644 index 03d2040..0000000 --- a/server/osa/domain/source/service/source.py +++ /dev/null @@ -1,154 +0,0 @@ -"""SourceService - orchestrates running OCI source containers.""" - -import logging -from dataclasses import dataclass -from datetime import UTC, datetime -from typing import Any -from uuid import uuid4 - -from osa.domain.shared.event import EventId -from osa.domain.shared.model.source import SourceDefinition -from osa.domain.shared.model.srn import ConventionSRN -from osa.domain.shared.outbox import Outbox -from osa.domain.shared.service import Service -from osa.domain.source.event.source_record_ready import SourceRecordReady -from osa.domain.source.event.source_requested import SourceRequested -from osa.domain.source.event.source_run_completed import 
SourceRunCompleted -from osa.domain.source.port.source_runner import SourceInputs, SourceRunner -from osa.domain.source.port.storage import SourceStoragePort - -logger = logging.getLogger(__name__) - - -@dataclass -class SourceResult: - """Result of a source run.""" - - convention_srn: ConventionSRN - record_count: int - started_at: datetime - completed_at: datetime - - -class SourceService(Service): - """Orchestrates running source containers. - - For each record produced by the source container, emits a - SourceRecordReady event. The deposition domain handles creating - depositions from these events. - """ - - source_runner: SourceRunner - source_storage: SourceStoragePort - outbox: Outbox - - async def run_source( - self, - convention_srn: ConventionSRN, - source: SourceDefinition, - since: datetime | None = None, - limit: int | None = None, - offset: int = 0, - chunk_size: int = 1000, - session: dict[str, Any] | None = None, - ) -> SourceResult: - """Run a source container and emit events for each produced record.""" - started_at = datetime.now(UTC) - run_id = str(uuid4())[:12] - - logger.info( - "Starting source run for %s (run=%s, since=%s, limit=%s, offset=%s)", - convention_srn, - run_id, - since, - limit, - offset, - ) - - # Prepare dirs - staging_dir = self.source_storage.get_source_staging_dir(convention_srn, run_id) - work_dir = self.source_storage.get_source_output_dir(convention_srn, run_id) - - # Build inputs - inputs = SourceInputs( - convention_srn=convention_srn, - config=source.config, - since=since, - limit=limit, - offset=offset, - session=session, - ) - - # Run container - output = await self.source_runner.run( - source=source, - inputs=inputs, - files_dir=staging_dir, - work_dir=work_dir, - ) - - # Emit per-record events - count = 0 - for record_data in output.records: - source_id = record_data.get("source_id", "") - metadata = record_data.get("metadata", {}) - file_paths = record_data.get("file_paths", []) - - await self.outbox.append( - 
SourceRecordReady( - id=EventId(uuid4()), - convention_srn=convention_srn, - metadata=metadata, - file_paths=file_paths, - source_id=source_id, - staging_dir=str(staging_dir), - ) - ) - count += 1 - - if count % 100 == 0: - logger.info(" Emitted %d SourceRecordReady events so far...", count) - - completed_at = datetime.now(UTC) - is_final_chunk = output.session is None or count == 0 - - logger.info( - "Source run completed: %d records (run=%s, is_final=%s)", - count, - run_id, - is_final_chunk, - ) - - # Emit continuation if session exists - if not is_final_chunk: - next_offset = offset + count - logger.info("Emitting continuation event, next_offset=%d", next_offset) - await self.outbox.append( - SourceRequested( - id=EventId(uuid4()), - convention_srn=convention_srn, - since=since, - limit=limit, - offset=next_offset, - chunk_size=chunk_size, - session=output.session, - ) - ) - - await self.outbox.append( - SourceRunCompleted( - id=EventId(uuid4()), - convention_srn=convention_srn, - started_at=started_at, - completed_at=completed_at, - record_count=count, - is_final_chunk=is_final_chunk, - ) - ) - - return SourceResult( - convention_srn=convention_srn, - record_count=count, - started_at=started_at, - completed_at=completed_at, - ) diff --git a/server/osa/domain/validation/model/batch_outcome.py b/server/osa/domain/validation/model/batch_outcome.py new file mode 100644 index 0000000..93f8e0d --- /dev/null +++ b/server/osa/domain/validation/model/batch_outcome.py @@ -0,0 +1,31 @@ +"""Per-record outcome from a batch hook run.""" + +from enum import StrEnum +from typing import Any, NewType + +from osa.domain.shared.model.value import ValueObject + +HookRecordId = NewType("HookRecordId", str) + + +class OutcomeStatus(StrEnum): + """Outcome status for a single record in a batch hook execution.""" + + PASSED = "passed" + REJECTED = "rejected" + ERRORED = "errored" + + +class BatchRecordOutcome(ValueObject): + """Per-record outcome from a batch hook execution. 
+ + Each record in a batch ends up in exactly one of three states: + passed (with features), rejected (with reason), or errored. + """ + + record_id: HookRecordId + status: OutcomeStatus + features: list[dict[str, Any]] = [] + reason: str | None = None + error: str | None = None + retryable: bool = False diff --git a/server/osa/domain/validation/model/hook_input.py b/server/osa/domain/validation/model/hook_input.py new file mode 100644 index 0000000..085c74c --- /dev/null +++ b/server/osa/domain/validation/model/hook_input.py @@ -0,0 +1,16 @@ +"""Value objects for hook input data.""" + +from typing import Any + +from osa.domain.shared.model.value import ValueObject + + +class HookRecord(ValueObject): + """A single record to be processed by a hook. + + Maps to one line in records.jsonl: {"id": "...", "metadata": {...}}. + """ + + id: str + metadata: dict[str, Any] + size_hint_mb: float = 0 diff --git a/server/osa/domain/validation/model/hook_result.py b/server/osa/domain/validation/model/hook_result.py index bd71fcd..ba4b959 100644 --- a/server/osa/domain/validation/model/hook_result.py +++ b/server/osa/domain/validation/model/hook_result.py @@ -11,6 +11,7 @@ class HookStatus(StrEnum): PASSED = "passed" REJECTED = "rejected" FAILED = "failed" + OOM = "oom" class ProgressEntry(ValueObject): @@ -30,3 +31,8 @@ class HookResult(ValueObject): error_message: str | None = None progress: list[ProgressEntry] = Field(default_factory=list) duration_seconds: float + + @property + def oom_killed(self) -> bool: + """Whether this hook was killed by an out-of-memory condition.""" + return self.status == HookStatus.OOM diff --git a/server/osa/domain/validation/port/hook_runner.py b/server/osa/domain/validation/port/hook_runner.py index 703596b..55ff484 100644 --- a/server/osa/domain/validation/port/hook_runner.py +++ b/server/osa/domain/validation/port/hook_runner.py @@ -1,22 +1,28 @@ """Port for executing hooks in OCI containers.""" from abc import abstractmethod -from dataclasses 
import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import Protocol, runtime_checkable from osa.domain.shared.model.hook import HookDefinition from osa.domain.shared.port import Port +from osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.model.hook_result import HookResult @dataclass(frozen=True) class HookInputs: - """Inputs to pass to a hook container.""" + """Inputs to pass to a hook container. - record_json: dict + Uses the unified batch contract: records is a list of HookRecord + (1 for depositions, N for ingests). + files_dirs maps record ID → directory containing that record's files. + """ + + records: list[HookRecord] run_id: str - files_dir: Path | None = None + files_dirs: dict[str, Path] = field(default_factory=dict) config: dict | None = None diff --git a/server/osa/domain/validation/port/storage.py b/server/osa/domain/validation/port/storage.py index 353578b..eb6d34e 100644 --- a/server/osa/domain/validation/port/storage.py +++ b/server/osa/domain/validation/port/storage.py @@ -6,6 +6,7 @@ from osa.domain.shared.model.srn import DepositionSRN from osa.domain.shared.port import Port +from osa.domain.validation.model.batch_outcome import BatchRecordOutcome, HookRecordId class HookStoragePort(Port, Protocol): @@ -20,3 +21,26 @@ def get_hook_output_dir(self, deposition_srn: DepositionSRN, hook_name: str) -> def get_files_dir(self, deposition_id: DepositionSRN) -> Path: """Return the directory containing data files for a deposition.""" ... + + @abstractmethod + def write_checkpoint( + self, work_dir: Path, outcomes: dict[HookRecordId, BatchRecordOutcome] + ) -> None: + """Atomically write checkpoint JSONL to work_dir/_checkpoint.jsonl.""" + ... + + @abstractmethod + def write_batch_outcomes( + self, + work_dir: Path, + outcomes: dict[HookRecordId, BatchRecordOutcome], + ) -> None: + """Write canonical features.jsonl, rejections.jsonl, errors.jsonl.""" + ... 
+ + @abstractmethod + async def read_batch_outcomes( + self, output_dir: str, hook_name: str + ) -> dict[HookRecordId, BatchRecordOutcome]: + """Read JSONL batch outputs (features/rejections/errors) for a hook.""" + ... diff --git a/server/osa/domain/validation/service/hook.py b/server/osa/domain/validation/service/hook.py new file mode 100644 index 0000000..55d2749 --- /dev/null +++ b/server/osa/domain/validation/service/hook.py @@ -0,0 +1,239 @@ +"""HookService — executes hooks with OOM retry and checkpointing. + +Handles both single-record (deposition) and multi-record (ingestion) batches. +On OOM, retries with doubled memory up to MAX_OOM_RETRIES times. +Checkpoints partial progress so crash recovery skips completed records. + +Sorting assumption: hooks process records in input order and write output +incrementally (features.jsonl line by line). Sorting by file size ascending +maximizes checkpoint progress before a potential OOM on a large record. +""" + +import json +from collections.abc import Iterable +from pathlib import Path + +from osa.domain.shared.model.hook import HookDefinition +from osa.domain.shared.service import Service +from osa.domain.validation.model.batch_outcome import ( + BatchRecordOutcome, + HookRecordId, + OutcomeStatus, +) +from osa.domain.validation.model.hook_input import HookRecord +from osa.domain.validation.model.hook_result import HookResult, HookStatus +from osa.domain.validation.port.hook_runner import HookInputs, HookRunner +from osa.domain.validation.port.storage import HookStoragePort +from osa.infrastructure.logging import get_logger + +log = get_logger(__name__) + +MAX_OOM_RETRIES = 3 + + +class HookService(Service): + """Executes a hook with OOM retry, checkpointing, and finalization.""" + + hook_runner: HookRunner + hook_storage: HookStoragePort + + async def run_hook( + self, + hook: HookDefinition, + inputs: HookInputs, + work_dir: Path, + ) -> HookResult: + """Run a single hook against a batch of records, retrying on 
OOM. + + Returns the final HookResult. On success or non-OOM failure, returns + after the first attempt. On OOM, retries with doubled memory up to + MAX_OOM_RETRIES times, checkpointing partial progress between attempts. + """ + records = inputs.records + if not records: + return HookResult( + hook_name=hook.name, + status=HookStatus.PASSED, + duration_seconds=0.0, + ) + + # Load checkpoint (crash recovery) + outcomes = _load_checkpoint(work_dir) + remaining = _sort_by_size(r for r in records if r.id not in outcomes) + + if not remaining: + # All records already checkpointed + self.hook_storage.write_batch_outcomes(work_dir, outcomes) + return HookResult( + hook_name=hook.name, + status=HookStatus.PASSED, + duration_seconds=0.0, + ) + + current_hook = hook + total_duration = 0.0 + + for attempt in range(1 + MAX_OOM_RETRIES): + attempt_inputs = HookInputs( + records=remaining, + run_id=inputs.run_id, + files_dirs=inputs.files_dirs, + config=inputs.config, + ) + + result = await self.hook_runner.run(current_hook, attempt_inputs, work_dir) + total_duration += result.duration_seconds + + # Read any output written by this attempt + new_outcomes = _read_output_dir(work_dir) + for rid, outcome in new_outcomes.items(): + if rid not in outcomes: + outcomes[rid] = outcome + + if result.oom_killed: + # Checkpoint what we have so far + self.hook_storage.write_checkpoint(work_dir, outcomes) + + remaining = _sort_by_size(r for r in records if r.id not in outcomes) + if not remaining: + break + + if attempt < MAX_OOM_RETRIES: + current_hook = current_hook.with_doubled_memory() + log.info( + "OOM retry {attempt}/{max_retries} for hook={hook_name}, memory={memory}, remaining={remaining} records", + attempt=attempt + 1, + max_retries=MAX_OOM_RETRIES, + hook_name=hook.name, + memory=current_hook.runtime.limits.memory, + remaining=len(remaining), + ) + continue + else: + # Exhausted retries — mark remaining as errored + for r in remaining: + outcomes[HookRecordId(r.id)] = 
BatchRecordOutcome( + record_id=HookRecordId(r.id), + status=OutcomeStatus.ERRORED, + error=f"OOM after {MAX_OOM_RETRIES} retries (last limit: {current_hook.runtime.limits.memory})", + ) + self.hook_storage.write_batch_outcomes(work_dir, outcomes) + return HookResult( + hook_name=hook.name, + status=HookStatus.OOM, + error_message=f"OOM exhausted after {MAX_OOM_RETRIES} retries", + duration_seconds=total_duration, + ) + elif result.status == HookStatus.FAILED: + # Non-OOM failure — no retry + self.hook_storage.write_batch_outcomes(work_dir, outcomes) + return result + elif result.status == HookStatus.REJECTED: + # Rejection — no retry, propagate status + self.hook_storage.write_batch_outcomes(work_dir, outcomes) + return HookResult( + hook_name=hook.name, + status=HookStatus.REJECTED, + rejection_reason=result.rejection_reason, + duration_seconds=total_duration, + ) + else: + # Success (PASSED) + break + + # Finalize: write canonical output files + self.hook_storage.write_batch_outcomes(work_dir, outcomes) + _cleanup_checkpoint(work_dir) + + return HookResult( + hook_name=hook.name, + status=HookStatus.PASSED, + duration_seconds=total_duration, + ) + + async def run_hooks_for_batch( + self, + hooks: list[HookDefinition], + inputs: HookInputs, + work_dirs: dict[str, Path], + ) -> list[HookResult]: + """Run multiple hooks sequentially for a batch of records. + + work_dirs maps hook_name → output directory. + """ + results: list[HookResult] = [] + for hook in hooks: + work_dir = work_dirs[hook.name] + result = await self.run_hook(hook, inputs, work_dir) + results.append(result) + return results + + +def _sort_by_size(records: Iterable[HookRecord]) -> list[HookRecord]: + """Sort records by size_hint_mb ascending. 
Skip sort when all sizes are 0.""" + record_list = list(records) + if any(r.size_hint_mb > 0 for r in record_list): + return sorted(record_list, key=lambda r: r.size_hint_mb) + return record_list + + +def _load_checkpoint(work_dir: Path) -> dict[HookRecordId, BatchRecordOutcome]: + """Load checkpoint from _checkpoint.jsonl. Returns empty dict on missing/corrupt.""" + checkpoint_path = work_dir / "_checkpoint.jsonl" + if not checkpoint_path.exists(): + return {} + + outcomes: dict[HookRecordId, BatchRecordOutcome] = {} + for line in checkpoint_path.open(): + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + outcome = BatchRecordOutcome.model_validate(data) + outcomes[outcome.record_id] = outcome + except (json.JSONDecodeError, ValueError): + log.warn("Skipping malformed checkpoint line") + continue + + return outcomes + + +def _read_output_dir(work_dir: Path) -> dict[HookRecordId, BatchRecordOutcome]: + """Read hook output files (features.jsonl, rejections.jsonl, errors.jsonl).""" + output_dir = work_dir / "output" + outcomes: dict[HookRecordId, BatchRecordOutcome] = {} + + for filename, status, field_map in [ + ("features.jsonl", OutcomeStatus.PASSED, {"features": "features"}), + ("rejections.jsonl", OutcomeStatus.REJECTED, {"reason": "reason"}), + ("errors.jsonl", OutcomeStatus.ERRORED, {"error": "error", "retryable": "retryable"}), + ]: + path = output_dir / filename + if not path.exists(): + continue + for line in path.open(): + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + except json.JSONDecodeError: + continue + raw_id = data.get("id") + if not raw_id: + continue + record_id = HookRecordId(raw_id) + kwargs: dict = {"record_id": record_id, "status": status} + for src, dst in field_map.items(): + if src in data: + kwargs[dst] = data[src] + outcomes[record_id] = BatchRecordOutcome(**kwargs) + + return outcomes + + +def _cleanup_checkpoint(work_dir: Path) -> None: + """Remove checkpoint file after 
successful finalization.""" + checkpoint_path = work_dir / "_checkpoint.jsonl" + checkpoint_path.unlink(missing_ok=True) diff --git a/server/osa/domain/validation/service/validation.py b/server/osa/domain/validation/service/validation.py index 0783d9b..a2dcdee 100644 --- a/server/osa/domain/validation/service/validation.py +++ b/server/osa/domain/validation/service/validation.py @@ -18,10 +18,12 @@ RunStatus, ValidationRun, ) +from osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.model.hook_result import HookResult, HookStatus from osa.domain.validation.port.hook_runner import HookInputs, HookRunner from osa.domain.validation.port.repository import ValidationRunRepository from osa.domain.validation.port.storage import HookStoragePort +from osa.domain.validation.service.hook import HookService logger = logging.getLogger(__name__) @@ -63,24 +65,30 @@ async def run_hooks( inputs: HookInputs, hooks: list[HookDefinition], ) -> tuple[ValidationRun, list[HookResult]]: - """Execute hooks sequentially. Halt on reject/fail. + """Execute hooks sequentially with OOM retry. Halt on reject/fail/OOM. Hook outputs are written to durable cold storage under the deposition directory. Feature insertion is deferred to record publication time. + Each hook is executed via HookService which handles OOM retry with memory doubling. 
""" run.status = RunStatus.RUNNING run.started_at = datetime.now(timezone.utc) await self.run_repo.save(run) + hook_service = HookService( + hook_runner=self.hook_runner, + hook_storage=self.hook_storage, + ) + hook_results: list[HookResult] = [] overall_status: RunStatus = RunStatus.COMPLETED for hook in hooks: work_dir = self.hook_storage.get_hook_output_dir(deposition_srn, hook.name) - result = await self.hook_runner.run(hook, inputs, work_dir) + result = await hook_service.run_hook(hook, inputs, work_dir) hook_results.append(result) - if result.status == HookStatus.FAILED: + if result.status in (HookStatus.FAILED, HookStatus.OOM): overall_status = RunStatus.FAILED break if result.status == HookStatus.REJECTED: @@ -101,14 +109,19 @@ async def validate_deposition( metadata: dict[str, Any], hooks: list[HookDefinition], ) -> tuple[ValidationRun, list[HookResult]]: - """Full validation workflow using enriched event data.""" - record_json = {"srn": str(deposition_srn), "metadata": metadata} - run_id = f"{deposition_srn.domain.root}_{deposition_srn.id.root}" + """Full validation workflow using enriched event data. + + Uses the unified batch contract: constructs a 1-record batch for depositions. 
+ """ + local_id = deposition_srn.id.root + record = HookRecord(id=local_id, metadata=metadata) + run_id = f"{deposition_srn.domain.root}_{local_id}" files_dir = self.hook_storage.get_files_dir(deposition_srn) + inputs = HookInputs( - record_json=record_json, + records=[record], run_id=run_id, - files_dir=files_dir, + files_dirs={local_id: files_dir} if files_dir else {}, ) run = await self.create_run(inputs=inputs) diff --git a/server/osa/domain/validation/util/di/provider.py b/server/osa/domain/validation/util/di/provider.py index 1a9235f..707675d 100644 --- a/server/osa/domain/validation/util/di/provider.py +++ b/server/osa/domain/validation/util/di/provider.py @@ -3,12 +3,14 @@ from osa.config import Config from osa.domain.shared.model.srn import Domain from osa.domain.validation.service import ValidationService +from osa.domain.validation.service.hook import HookService from osa.util.di.base import Provider from osa.util.di.scope import Scope class ValidationProvider(Provider): service = provide(ValidationService, scope=Scope.UOW) + hook_service = provide(HookService, scope=Scope.UOW) @provide(scope=Scope.UOW) def get_node_domain(self, config: Config) -> Domain: diff --git a/server/osa/infrastructure/event/di.py b/server/osa/infrastructure/event/di.py index 286b51e..fb665c4 100644 --- a/server/osa/infrastructure/event/di.py +++ b/server/osa/infrastructure/event/di.py @@ -5,17 +5,21 @@ from dishka import AsyncContainer, provide +from osa.config import Config from osa.domain.curation.handler import AutoApproveCuration -from osa.domain.deposition.handler import CreateDepositionFromSource, ReturnToDraft -from osa.domain.feature.handler import CreateFeatureTables, InsertRecordFeatures +from osa.domain.deposition.handler import ReturnToDraft +from osa.domain.feature.handler import ( + CreateFeatureTables, + InsertBatchFeatures, + InsertRecordFeatures, +) +from osa.domain.ingest.handler import PublishBatch, RunHooks, RunIngester from osa.domain.record.handler import 
ConvertDepositionToRecord from osa.domain.shared.event import EventHandler from osa.domain.shared.event_log import EventLog from osa.domain.shared.model.subscription_registry import SubscriptionRegistry from osa.domain.shared.outbox import Outbox from osa.domain.shared.port.event_repository import EventRepository -from osa.domain.source.handler import PullFromSource, TriggerInitialSourceRun -from osa.domain.source.schedule import SourceSchedule from osa.domain.validation.handler import ValidateDeposition from osa.infrastructure.event.worker import WorkerPool from osa.util.di.base import Provider @@ -32,13 +36,14 @@ # Feature handlers (must run before source triggers) CreateFeatureTables, InsertRecordFeatures, - # Source handlers - TriggerInitialSourceRun, - PullFromSource, + InsertBatchFeatures, + # Ingest handlers + RunIngester, + RunHooks, + PublishBatch, # Validation handlers ValidateDeposition, # Deposition handlers - CreateDepositionFromSource, ReturnToDraft, # Curation handlers AutoApproveCuration, @@ -107,9 +112,6 @@ def get_outbox(self, repo: EventRepository, registry: SubscriptionRegistry) -> O def get_event_log(self, repo: EventRepository) -> EventLog: return EventLog(repo) - # UOW-scoped provider for SourceSchedule - source_schedule = provide(SourceSchedule, scope=Scope.UOW) - @provide(scope=Scope.APP) def get_handler_types(self) -> HandlerTypes: """Return all handler types (core + extra) for WorkerPool registration.""" @@ -130,12 +132,13 @@ def get_worker_pool( self, container: AsyncContainer, handler_types: HandlerTypes, + config: Config, ) -> WorkerPool: """WorkerPool with pull-based event handlers.""" pool = WorkerPool(container=container, stale_claim_interval=60.0) for handler_type in handler_types: - pool.register(handler_type) + pool.register(handler_type, config=config) logger.info(f"WorkerPool created with {len(pool.workers)} workers") return pool diff --git a/server/osa/infrastructure/event/worker.py b/server/osa/infrastructure/event/worker.py 
index 24924f6..7721c2a 100644 --- a/server/osa/infrastructure/event/worker.py +++ b/server/osa/infrastructure/event/worker.py @@ -4,7 +4,10 @@ import logging from contextlib import AsyncExitStack from dataclasses import dataclass, field -from typing import Any, NewType +from typing import TYPE_CHECKING, Any, NewType + +if TYPE_CHECKING: + from osa.config import Config from apscheduler import AsyncScheduler from apscheduler.triggers.cron import CronTrigger @@ -45,8 +48,10 @@ class name as its consumer_group. Deliveries are claimed per consumer group, enabling multiple handlers to independently process the same event. """ - def __init__(self, handler_type: type[EventHandler[Any]]) -> None: + def __init__(self, handler_type: type[EventHandler[Any]], *, instance_id: int = 0) -> None: self._handler_type = handler_type + self._instance_id = instance_id + # All instances share the same consumer group so SKIP LOCKED distributes work self._consumer_group = handler_type.__name__ # Read config from handler classvars @@ -73,8 +78,10 @@ def __init__(self, handler_type: type[EventHandler[Any]]) -> None: @property def name(self) -> str: - """Worker name (handler class name).""" - return self._handler_type.__name__ + """Worker name (handler class name + instance suffix if concurrent).""" + if self._instance_id == 0: + return self._handler_type.__name__ + return f"{self._handler_type.__name__}-{self._instance_id}" @property def consumer_group(self) -> str: @@ -105,9 +112,11 @@ def start(self) -> asyncio.Task: if self._container is None: raise RuntimeError("Container not set. 
Call set_container() first.") + import logfire + self._shutdown = False self._task = asyncio.create_task(self._run(), name=f"worker-{self.name}") - logger.info(f"Worker '{self.name}' started") + logfire.info("worker started: {worker_name}", worker_name=self.name) return self._task def stop(self) -> None: @@ -236,14 +245,45 @@ def workers(self) -> list[Worker]: """List of managed workers.""" return self._workers - def register(self, handler_type: type[EventHandler[Any]]) -> Worker: - """Register an EventHandler type and create a Worker for it.""" - worker = Worker(handler_type) - if self._container is not None: - worker.set_container(self._container) - self._workers.append(worker) - logger.debug(f"Registered handler '{handler_type.__name__}' as worker") - return worker + def register( + self, + handler_type: type[EventHandler[Any]], + config: "Config | None" = None, + ) -> Worker: + """Register an EventHandler type and create Worker(s) for it. + + Concurrency is determined by (in priority order): + 1. Config override (e.g. ``config.worker.hook_concurrency`` for RunHooks) + 2. Handler classvar ``__concurrency__`` + 3. Default of 1 + + Multiple workers share the same consumer group so deliveries are + distributed across them via FOR UPDATE SKIP LOCKED. 
+ """ + concurrency = getattr(handler_type, "__concurrency__", 1) + + # Apply config overrides + if config is not None: + from osa.domain.ingest.handler.run_hooks import RunHooks + + if handler_type is RunHooks: + concurrency = config.worker.hook_concurrency + + first_worker = None + for i in range(concurrency): + worker = Worker(handler_type, instance_id=i) + if self._container is not None: + worker.set_container(self._container) + self._workers.append(worker) + if first_worker is None: + first_worker = worker + if concurrency > 1: + logger.debug( + f"Registered handler '{handler_type.__name__}' with {concurrency} concurrent workers" + ) + else: + logger.debug(f"Registered handler '{handler_type.__name__}' as worker") + return first_worker # type: ignore[return-value] def add_worker(self, worker: Worker) -> None: """Add a worker to the pool.""" @@ -311,36 +351,7 @@ async def start(self) -> None: async def _build_schedules_from_conventions(self) -> list[ScheduleConfig]: """Query conventions with sources and build schedule configs.""" - if self._container is None: - return [] - - from osa.domain.deposition.service.convention import ConventionService - from osa.domain.source.schedule import SourceSchedule as SourceScheduleType - - configs: list[ScheduleConfig] = [] - try: - async with self._container(scope=Scope.UOW, context={Identity: System()}) as scope: - convention_service = await scope.get(ConventionService) - conventions = await convention_service.list_conventions_with_source() - - for conv in conventions: - if conv.source is None or conv.source.schedule is None: - continue - configs.append( - ScheduleConfig( - schedule_type=SourceScheduleType, - cron=conv.source.schedule.cron, - id=f"source-{conv.srn}", - params={ - "convention": str(conv.srn), - "limit": conv.source.schedule.limit, - }, - ) - ) - except Exception as e: - logger.warning(f"Failed to build schedules from conventions: {e}") - - return configs + return [] async def stop(self, timeout: float = 
30.0) -> None: """Stop all workers gracefully.""" diff --git a/server/osa/infrastructure/ingest/__init__.py b/server/osa/infrastructure/ingest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/osa/infrastructure/ingest/di.py b/server/osa/infrastructure/ingest/di.py new file mode 100644 index 0000000..045ac82 --- /dev/null +++ b/server/osa/infrastructure/ingest/di.py @@ -0,0 +1,45 @@ +"""Dependency injection provider for ingest domain.""" + +from dishka import provide +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.deposition.service.convention import ConventionService +from osa.domain.ingest.command.start_ingest import StartIngestHandler +from osa.domain.ingest.port.repository import IngestRunRepository +from osa.domain.ingest.service.ingest import IngestService +from osa.domain.shared.model.srn import Domain +from osa.domain.shared.outbox import Outbox +from osa.infrastructure.persistence.repository.ingest import PostgresIngestRunRepository +from osa.infrastructure.storage.layout import StorageLayout +from osa.util.di.base import Provider +from osa.util.di.scope import Scope +from osa.util.paths import OSAPaths + + +class IngestProvider(Provider): + """Provides IngestService, IngestRunRepository, StorageLayout, and StartIngestHandler.""" + + @provide(scope=Scope.APP) + def get_storage_layout(self, paths: OSAPaths) -> StorageLayout: + return StorageLayout(paths.data_dir) + + @provide(scope=Scope.UOW) + def get_ingest_repo(self, session: AsyncSession) -> IngestRunRepository: + return PostgresIngestRunRepository(session) + + @provide(scope=Scope.UOW) + def get_ingest_service( + self, + ingest_repo: IngestRunRepository, + convention_service: ConventionService, + outbox: Outbox, + node_domain: Domain, + ) -> IngestService: + return IngestService( + ingest_repo=ingest_repo, + convention_service=convention_service, + outbox=outbox, + node_domain=node_domain, + ) + + start_ingest_handler = provide(StartIngestHandler, 
scope=Scope.UOW) diff --git a/server/osa/infrastructure/k8s/__init__.py b/server/osa/infrastructure/k8s/__init__.py index 2511c39..36bc8cf 100644 --- a/server/osa/infrastructure/k8s/__init__.py +++ b/server/osa/infrastructure/k8s/__init__.py @@ -1,6 +1,6 @@ """Kubernetes runner infrastructure. kubernetes-asyncio is an optional dependency. Modules that require it -(di.py, runner.py, source_runner.py, health.py) perform lazy imports +(di.py, runner.py, ingester_runner.py, health.py) perform lazy imports and raise ConfigurationError if the package is not installed. """ diff --git a/server/osa/infrastructure/k8s/di.py b/server/osa/infrastructure/k8s/di.py index ebde8d3..1b6d99b 100644 --- a/server/osa/infrastructure/k8s/di.py +++ b/server/osa/infrastructure/k8s/di.py @@ -13,10 +13,10 @@ from dishka import activate, provide from osa.config import Config -from osa.domain.source.port.source_runner import SourceRunner +from osa.domain.shared.port.ingester_runner import IngesterRunner from osa.domain.validation.port.hook_runner import HookRunner +from osa.infrastructure.oci.ingester_runner import OciIngesterRunner from osa.infrastructure.oci.runner import OciHookRunner -from osa.infrastructure.oci.source_runner import OciSourceRunner from osa.infrastructure.s3.client import S3Client from osa.util.di.base import Provider from osa.util.di.markers import K8S @@ -62,12 +62,12 @@ def get_hook_runner_oci( return OciHookRunner(docker=docker, host_data_dir=config.host_data_dir) @provide(scope=Scope.UOW) - def get_source_runner_oci( + def get_ingester_runner_oci( self, docker: aiodocker.Docker, config: Config, - ) -> SourceRunner: - return OciSourceRunner(docker=docker, host_data_dir=config.host_data_dir) + ) -> IngesterRunner: + return OciIngesterRunner(docker=docker, host_data_dir=config.host_data_dir) # ------------------------------------------------------------------ # K8s backend (activated when config.runner.backend == "k8s") @@ -133,12 +133,12 @@ def get_hook_runner_k8s( 
return K8sHookRunner(api_client=k8s_api_client, config=config.runner.k8s, s3=s3) @provide(when=K8S, scope=Scope.UOW) - def get_source_runner_k8s( + def get_ingester_runner_k8s( self, k8s_api_client: ApiClient, config: Config, s3: S3Client, - ) -> SourceRunner: - from osa.infrastructure.k8s.source_runner import K8sSourceRunner + ) -> IngesterRunner: + from osa.infrastructure.k8s.ingester_runner import K8sIngesterRunner - return K8sSourceRunner(api_client=k8s_api_client, config=config.runner.k8s, s3=s3) + return K8sIngesterRunner(api_client=k8s_api_client, config=config.runner.k8s, s3=s3) diff --git a/server/osa/infrastructure/k8s/source_runner.py b/server/osa/infrastructure/k8s/ingester_runner.py similarity index 88% rename from server/osa/infrastructure/k8s/source_runner.py rename to server/osa/infrastructure/k8s/ingester_runner.py index 39cb877..fa29ec1 100644 --- a/server/osa/infrastructure/k8s/source_runner.py +++ b/server/osa/infrastructure/k8s/ingester_runner.py @@ -1,4 +1,4 @@ -"""Kubernetes Job-based source runner.""" +"""Kubernetes Job-based ingester runner.""" from __future__ import annotations @@ -11,9 +11,9 @@ from osa.config import K8sConfig from osa.domain.shared.error import ExternalServiceError, InfrastructureError -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.srn import ConventionSRN -from osa.domain.source.port.source_runner import SourceInputs, SourceOutput, SourceRunner +from osa.domain.shared.port.ingester_runner import IngesterInputs, IngesterOutput, IngesterRunner from osa.infrastructure.k8s.errors import classify_api_error from osa.infrastructure.k8s.naming import job_name, label_value, sanitize_label from osa.infrastructure.runner_utils import ( @@ -31,7 +31,7 @@ SCHEDULING_TIMEOUT = 120 -class K8sSourceRunner(SourceRunner): +class K8sIngesterRunner(IngesterRunner): """Executes sources as Kubernetes Jobs. 
Key differences from K8sHookRunner: @@ -54,11 +54,11 @@ def _s3_prefix(self, work_dir: Path, subdir: str) -> str: async def run( self, - source: SourceDefinition, - inputs: SourceInputs, + ingester: IngesterDefinition, + inputs: IngesterInputs, files_dir: Path, work_dir: Path, - ) -> SourceOutput: + ) -> IngesterOutput: try: from kubernetes_asyncio.client import BatchV1Api, CoreV1Api except ImportError: @@ -74,8 +74,8 @@ async def run( # Write input files to S3 (container reads them via PVC/S3 CSI) input_prefix = self._s3_prefix(work_dir, "input") - if inputs.config or source.config: - config = {**(source.config or {}), **(inputs.config or {})} + if inputs.config or ingester.config: + config = {**(ingester.config or {}), **(inputs.config or {})} await self._s3.put_object(f"{input_prefix}/config.json", json.dumps(config)) if inputs.session: @@ -84,7 +84,7 @@ async def run( return await self._run_job( batch_api, core_api, - source, + ingester, inputs, work_dir, files_dir, @@ -95,21 +95,21 @@ async def _run_job( self, batch_api: BatchV1Api, core_api: CoreV1Api, - source: SourceDefinition, - inputs: SourceInputs, + ingester: IngesterDefinition, + inputs: IngesterInputs, work_dir: Path, files_dir: Path, *, convention_srn: ConventionSRN | None = None, - ) -> SourceOutput: - """Core Job lifecycle for source execution.""" + ) -> IngesterOutput: + """Core Job lifecycle for ingester execution.""" namespace = self._config.namespace job_name_to_watch = None try: # Check for existing Jobs existing = await self._check_existing_job( - batch_api, namespace, convention_srn, source.digest + batch_api, namespace, convention_srn, ingester.digest ) if existing == "succeeded": @@ -119,7 +119,7 @@ async def _run_job( job_name_to_watch = existing.split(":", 1)[1] else: spec = self._build_job_spec( - source, + ingester, work_dir=work_dir, files_dir=files_dir, inputs=inputs, @@ -129,11 +129,11 @@ async def _run_job( await batch_api.create_namespaced_job(namespace, spec) logger.info( - 
"Created K8s source Job", + "Created K8s ingester Job", extra={ "job_name": job_name_to_watch, "namespace": namespace, - "image": f"{source.image}@{source.digest}", + "image": f"{ingester.image}@{ingester.digest}", }, ) @@ -145,7 +145,7 @@ async def _run_job( batch_api, job_name_to_watch, namespace, - timeout_seconds=source.limits.timeout_seconds + 30, + timeout_seconds=ingester.limits.timeout_seconds + 30, ) if result == "succeeded": @@ -161,7 +161,7 @@ async def _run_job( return output # Failed — diagnose and raise - await self._diagnose_and_raise(core_api, job_name_to_watch, namespace, source, result) + await self._diagnose_and_raise(core_api, job_name_to_watch, namespace, ingester, result) # unreachable but satisfies type checker raise ExternalServiceError("Source failed") @@ -169,7 +169,7 @@ async def _run_job( if job_name_to_watch: await self._cleanup_job(batch_api, job_name_to_watch, namespace) - async def _parse_source_output(self, work_dir: Path, files_dir: Path) -> SourceOutput: + async def _parse_source_output(self, work_dir: Path, files_dir: Path) -> IngesterOutput: from osa.infrastructure.runner_utils import ( parse_records_from_s3, parse_session_from_s3, @@ -178,7 +178,7 @@ async def _parse_source_output(self, work_dir: Path, files_dir: Path) -> SourceO output_prefix = self._s3_prefix(work_dir, "output") records = await parse_records_from_s3(self._s3, output_prefix) session = await parse_session_from_s3(self._s3, output_prefix) - return SourceOutput(records=records, session=session, files_dir=files_dir) + return IngesterOutput(records=records, session=session, files_dir=files_dir) async def _check_existing_job( self, @@ -187,7 +187,7 @@ async def _check_existing_job( convention_srn: ConventionSRN | None, digest: str = "", ) -> str | None: - label_parts = ["osa.io/role=source"] + label_parts = ["osa.io/role=ingester"] if convention_srn is not None: label_parts.append(f"osa.io/convention={label_value(convention_srn)}") if digest: @@ -208,11 +208,11 @@ 
async def _check_existing_job( def _build_job_spec( self, - source: SourceDefinition, + ingester: IngesterDefinition, *, work_dir: Path, files_dir: Path, - inputs: SourceInputs | None = None, + inputs: IngesterInputs | None = None, convention_srn: ConventionSRN | None = None, ) -> V1Job: from kubernetes_asyncio.client import ( @@ -234,15 +234,15 @@ def _build_job_spec( V1VolumeMount, ) - name = job_name("source", "src", str(convention_srn) if convention_srn else "unknown") + name = job_name("ingester", "ing", str(convention_srn) if convention_srn else "unknown") relative_work = self._relative_path(work_dir) input_subpath = f"{relative_work}/input" output_subpath = f"{relative_work}/output" relative_files = self._relative_path(files_dir) labels: dict[str, str] = { - "osa.io/role": "source", - "osa.io/digest": sanitize_label(source.digest), + "osa.io/role": "ingester", + "osa.io/digest": sanitize_label(ingester.digest), } if convention_srn is not None: labels["osa.io/convention"] = label_value(convention_srn) @@ -278,13 +278,13 @@ def _build_job_spec( ] container = V1Container( - name="source", - image=f"{source.image}@{source.digest}", + name="ingester", + image=f"{ingester.image}@{ingester.digest}", env=env, resources=V1ResourceRequirements( limits={ - "memory": to_k8s_quantity(source.limits.memory), - "cpu": source.limits.cpu, + "memory": to_k8s_quantity(ingester.limits.memory), + "cpu": ingester.limits.cpu, }, ), security_context=V1SecurityContext( @@ -320,7 +320,7 @@ def _build_job_spec( metadata=V1ObjectMeta(name=name, namespace=self._config.namespace, labels=labels), spec=V1JobSpec( backoff_limit=0, - active_deadline_seconds=SCHEDULING_TIMEOUT + source.limits.timeout_seconds, + active_deadline_seconds=SCHEDULING_TIMEOUT + ingester.limits.timeout_seconds, ttl_seconds_after_finished=self._config.job_ttl_seconds, template=V1PodTemplateSpec( metadata=V1ObjectMeta(labels=labels), @@ -420,12 +420,14 @@ async def _diagnose_and_raise( core_api: CoreV1Api, job_name: 
str, namespace: str, - source: SourceDefinition, + ingester: IngesterDefinition, failure_info: str, ) -> None: """Determine failure reason and raise appropriate error.""" if "DeadlineExceeded" in failure_info: - raise ExternalServiceError(f"Source timed out after {source.limits.timeout_seconds}s") + raise ExternalServiceError( + f"Ingester timed out after {ingester.limits.timeout_seconds}s" + ) try: label_selector = f"job-name={job_name}" @@ -457,4 +459,4 @@ async def _cleanup_job(self, batch_api: BatchV1Api, job_name: str, namespace: st except Exception as exc: if getattr(exc, "status", None) == 404: return - logger.warning("Failed to clean up K8s source Job", extra={"job_name": job_name}) + logger.warning("Failed to clean up K8s ingester Job", extra={"job_name": job_name}) diff --git a/server/osa/infrastructure/k8s/runner.py b/server/osa/infrastructure/k8s/runner.py index 457c799..6d5ca8a 100644 --- a/server/osa/infrastructure/k8s/runner.py +++ b/server/osa/infrastructure/k8s/runner.py @@ -71,7 +71,9 @@ async def run( # Write input files to S3 (container reads them via PVC/S3 CSI) input_prefix = self._s3_prefix(work_dir, "input") - await self._s3.put_object(f"{input_prefix}/record.json", json.dumps(inputs.record_json)) + # Write records.jsonl (unified batch contract) + records_jsonl = "\n".join(json.dumps(r.model_dump()) for r in inputs.records) + "\n" + await self._s3.put_object(f"{input_prefix}/records.jsonl", records_jsonl) if inputs.config or hook.runtime.config: config = {**hook.runtime.config, **(inputs.config or {})} await self._s3.put_object(f"{input_prefix}/config.json", json.dumps(config)) @@ -113,11 +115,17 @@ async def _run_job( job_name_to_watch = existing.split(":", 1)[1] else: # Create new Job (no existing or failed) + # Mount the parent of all per-record file dirs — works for + # both depositions (one subdir) and ingests (N subdirs) + files_dir = None + if inputs.files_dirs: + first_dir = next(iter(inputs.files_dirs.values())) + files_dir = 
first_dir.parent spec = self._build_job_spec( hook, work_dir, run_id=inputs.run_id, - files_dir=inputs.files_dir, + files_dir=files_dir, ) job_name_to_watch = spec.metadata.name @@ -262,6 +270,7 @@ def _build_job_spec( V1VolumeMount(name="tmp", mount_path="/tmp"), ] + # Mount per-record file directories (ingest: multiple, deposition: one) if files_dir: relative_files = self._relative_path(files_dir) mounts.append( @@ -468,7 +477,7 @@ async def _diagnose_failure( if getattr(terminated, "reason", None) == "OOMKilled": return HookResult( hook_name=hook.name, - status=HookStatus.FAILED, + status=HookStatus.OOM, error_message="Hook killed by OOM", duration_seconds=duration, ) diff --git a/server/osa/infrastructure/logging.py b/server/osa/infrastructure/logging.py new file mode 100644 index 0000000..90070b0 --- /dev/null +++ b/server/osa/infrastructure/logging.py @@ -0,0 +1,168 @@ +"""OSA logging — custom logfire console exporter and structured logger. + +Provides: +- ``OSAConsoleExporter``: logfire console formatter with aligned columns + (timestamp, level, module, message) and indented continuation lines. +- ``get_logger(name)``: thin wrapper around logfire that auto-tags with module name. 
+""" + +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, Any, cast + +import logfire as _logfire +from logfire._internal.exporters.console import ( + ATTRIBUTES_TAGS_KEY, + ONE_SECOND_IN_NANOSECONDS, + SimpleConsoleSpanExporter, + _ERROR_LEVEL, + _WARN_LEVEL, +) + +if TYPE_CHECKING: + from logfire._internal.exporters.console import Record, TextParts + + +_LEVEL_NAMES: dict[int, str] = { + 0: "TRACE", + 1: "DEBUG", + 5: "DEBUG", + 9: "INFO ", + 13: "WARN ", + 17: "ERROR", + 21: "FATAL", +} + +# Module column width (inside brackets) +_MODULE_WIDTH = 20 +# Fixed prefix: HH:MM:SS.mmm(12) + space(1) + LEVEL(5) + space(1) + [module](22) + space(1) = 42 +_PREFIX_WIDTH = 42 + + +class OSAConsoleExporter(SimpleConsoleSpanExporter): + """Logfire console exporter with aligned columns. + + Format: ``HH:MM:SS.mmm LEVEL module.name message`` + + Continuation lines are indented to align with the message column. + Module name is extracted from ``_tags`` and shortened for readability. 
+ """ + + def _span_text_parts(self, span: Record, indent: int) -> tuple[str, TextParts]: + parts: TextParts = [] + + # Timestamp + if self._include_timestamp: + ts = datetime.fromtimestamp(span.timestamp / ONE_SECOND_IN_NANOSECONDS) + ts_str = f"{ts:%H:%M:%S.%f}"[:-3] + parts += [(ts_str, "green"), (" ", "")] + + # Level (fixed 5 chars) + level: int = span.level + level_name = _LEVEL_NAMES.get(level, f"L{level:<3}") + if level >= _ERROR_LEVEL: + level_style = "red" + elif level >= _WARN_LEVEL: + level_style = "yellow" + else: + level_style = "dim" + parts += [(level_name, level_style), (" ", "")] + + # Module tag in brackets, fixed width + if self._include_tags: + tags = span.attributes.get(ATTRIBUTES_TAGS_KEY) + if tags: + tag = cast("list[str]", tags)[0] + short = _shorten_module(tag) + bracketed = f"[{short}]" + parts += [(f"{bracketed:<{_MODULE_WIDTH + 2}}", "cyan"), (" ", "")] + else: + parts += [(" " * (_MODULE_WIDTH + 3), "")] + + if indent: + parts += [(indent * " ", "")] + + # Message with aligned continuation lines + msg: str = span.message + pad = " " * _PREFIX_WIDTH + msg = msg.replace("\n", "\n" + pad) + + if level >= _ERROR_LEVEL: + parts += [(msg, "red")] + elif level >= _WARN_LEVEL: + parts += [(msg, "yellow")] + else: + parts += [(msg, "")] + + return msg, parts + + +def _shorten_module(name: str) -> str: + """Shorten module path to fit ~20 chars. 
+ + ``osa.domain.ingest.handler.run_ingester`` → ``ingest.run_ingester`` + ``osa.infrastructure.oci.runner`` → ``infra.oci.runner`` + ``osa.domain.feature.handler.insert_batch_features`` → ``feat.ins_batch_feat`` + """ + short = ( + name.replace("osa.domain.", "") + .replace("osa.infrastructure.", "infra.") + .replace(".handler.", ".") + .replace(".service.", ".") + .replace(".util.", ".") + ) + if len(short) <= _MODULE_WIDTH: + return short + # Truncate: keep first and last segment, abbreviate middle + parts = short.split(".") + if len(parts) <= 2: + return short[:_MODULE_WIDTH] + # Keep first and last, drop middle segments until it fits + first, *middle, last = parts + while middle and len(f"{first}.{'.'.join(middle)}.{last}") > _MODULE_WIDTH: + middle.pop(0) + result = f"{first}.{'.'.join(middle)}.{last}" if middle else f"{first}.{last}" + return result[:_MODULE_WIDTH] + + +class Logger: + """Structured logger that wraps logfire with automatic module tagging. + + Usage:: + + from osa.infrastructure.logging import get_logger + + log = get_logger(__name__) + log.info("batch {idx}: pulled {n} records", idx=0, n=101) + + Produces structured logfire spans with ``_tags=[module_name]``, + preserving key-value attributes for logfire cloud while showing + the module name in console output. + """ + + __slots__ = ("_tags",) + + def __init__(self, name: str) -> None: + self._tags = [name] + + def info(self, msg: str, **kwargs: Any) -> None: + _logfire.info(msg, _tags=self._tags, **kwargs) + + def warn(self, msg: str, **kwargs: Any) -> None: + _logfire.warn(msg, _tags=self._tags, **kwargs) + + def error(self, msg: str, **kwargs: Any) -> None: + _logfire.error(msg, _tags=self._tags, **kwargs) + + def debug(self, msg: str, **kwargs: Any) -> None: + _logfire.debug(msg, _tags=self._tags, **kwargs) + + +def get_logger(name: str) -> Logger: + """Create a structured logger for a module. + + Args: + name: Module name, typically ``__name__``. 
+ """ + return Logger(name) diff --git a/server/osa/infrastructure/oci/di.py b/server/osa/infrastructure/oci/di.py index 5c51ceb..973a158 100644 --- a/server/osa/infrastructure/oci/di.py +++ b/server/osa/infrastructure/oci/di.py @@ -4,10 +4,10 @@ from dishka import provide from osa.config import Config -from osa.domain.source.port.source_runner import SourceRunner +from osa.domain.shared.port.ingester_runner import IngesterRunner from osa.domain.validation.port.hook_runner import HookRunner +from osa.infrastructure.oci.ingester_runner import OciIngesterRunner from osa.infrastructure.oci.runner import OciHookRunner -from osa.infrastructure.oci.source_runner import OciSourceRunner from osa.util.di.base import Provider from osa.util.di.scope import Scope @@ -24,5 +24,5 @@ def get_hook_runner(self, docker: aiodocker.Docker, config: Config) -> HookRunne return OciHookRunner(docker=docker, host_data_dir=config.host_data_dir) @provide(scope=Scope.UOW) - def get_source_runner(self, docker: aiodocker.Docker, config: Config) -> SourceRunner: - return OciSourceRunner(docker=docker, host_data_dir=config.host_data_dir) + def get_ingester_runner(self, docker: aiodocker.Docker, config: Config) -> IngesterRunner: + return OciIngesterRunner(docker=docker, host_data_dir=config.host_data_dir) diff --git a/server/osa/infrastructure/oci/source_runner.py b/server/osa/infrastructure/oci/ingester_runner.py similarity index 76% rename from server/osa/infrastructure/oci/source_runner.py rename to server/osa/infrastructure/oci/ingester_runner.py index 2f261ce..bb926ee 100644 --- a/server/osa/infrastructure/oci/source_runner.py +++ b/server/osa/infrastructure/oci/ingester_runner.py @@ -1,4 +1,4 @@ -"""OCI source runner using aiodocker.""" +"""OCI ingester runner using aiodocker.""" import asyncio import json @@ -11,8 +11,8 @@ import logfire from osa.domain.shared.error import ExternalServiceError -from osa.domain.shared.model.source import SourceDefinition -from 
osa.domain.source.port.source_runner import SourceInputs, SourceOutput, SourceRunner +from osa.domain.shared.model.source import IngesterDefinition +from osa.domain.shared.port.ingester_runner import IngesterInputs, IngesterOutput, IngesterRunner from osa.infrastructure.runner_utils import ( parse_memory, parse_records_file, @@ -20,13 +20,13 @@ ) -class OciSourceRunner(SourceRunner): - """Executes sources in OCI containers via aiodocker. +class OciIngesterRunner(IngesterRunner): + """Executes ingesters in OCI containers via aiodocker. Key differences from OciHookRunner: - - Network access enabled (sources call upstream APIs) + - Network access enabled (ingesters call upstream APIs) - Three bind mounts: $OSA_IN (ro), $OSA_OUT (rw), $OSA_FILES (rw) - - No ReadonlyRootfs (sources may need writable FS for pip cache, etc.) + - No ReadonlyRootfs (ingesters may need writable FS for pip cache, etc.) - Higher default limits (3600s timeout, 4g memory) - Output is records.jsonl (line-delimited JSON), not features.json @@ -47,12 +47,12 @@ def __init__( async def run( self, - source: SourceDefinition, - inputs: SourceInputs, + ingester: IngesterDefinition, + inputs: IngesterInputs, files_dir: Path, work_dir: Path, - ) -> SourceOutput: - timeout = source.limits.timeout_seconds + ) -> IngesterOutput: + timeout = ingester.limits.timeout_seconds from shutil import rmtree @@ -68,8 +68,8 @@ def _force_remove(func, path, exc): container_output = work_dir / "output" container_output.mkdir(parents=True, exist_ok=True) try: - if inputs.config or source.config: - config = {**(source.config or {}), **(inputs.config or {})} + if inputs.config or ingester.config: + config = {**(ingester.config or {}), **(inputs.config or {})} (staging_dir / "config.json").write_text(json.dumps(config)) if inputs.session: @@ -80,9 +80,9 @@ def _force_remove(func, path, exc): try: async def _resolve_and_run(): - image_ref = await self._resolve_image(source.image, source.digest) + image_ref = await 
self._resolve_image(ingester.image, ingester.digest) return await self._run_container( - image_ref, staging_dir, files_dir, container_output, source, inputs + image_ref, staging_dir, files_dir, container_output, ingester, inputs ) result = await asyncio.wait_for( @@ -93,12 +93,12 @@ async def _resolve_and_run(): except asyncio.TimeoutError: duration = time.monotonic() - start_time logfire.error( - "Source timed out", - image=source.image, + "Ingester timed out", + image=ingester.image, timeout=timeout, duration=duration, ) - raise ExternalServiceError(f"Source timed out after {timeout}s") + raise ExternalServiceError(f"Ingester timed out after {timeout}s") finally: rmtree(staging_dir, onexc=_force_remove) @@ -108,9 +108,9 @@ async def _run_container( staging_dir: Path, files_dir: Path, output_dir: Path, - source: SourceDefinition, - inputs: SourceInputs, - ) -> SourceOutput: + ingester: IngesterDefinition, + inputs: IngesterInputs, + ) -> IngesterOutput: container = None try: # Build env vars @@ -137,11 +137,11 @@ async def _run_container( "Env": env, "HostConfig": { "Binds": binds, - "Memory": parse_memory(source.limits.memory), - "MemorySwap": parse_memory(source.limits.memory), - "NanoCpus": int(float(source.limits.cpu) * 1e9), + "Memory": parse_memory(ingester.limits.memory), + "MemorySwap": parse_memory(ingester.limits.memory), + "NanoCpus": int(float(ingester.limits.cpu) * 1e9), # No NetworkMode: "none" — sources need network access - # No ReadonlyRootfs — sources may need writable FS + # No ReadonlyRootfs — ingesters may need writable FS "CapDrop": ["ALL"], "SecurityOpt": ["no-new-privileges"], "PidsLimit": 256, @@ -159,19 +159,21 @@ async def _run_container( oom_killed = inspect_data.get("State", {}).get("OOMKilled", False) if oom_killed: - raise ExternalServiceError("Source killed by OOM") + raise ExternalServiceError("Ingester killed by OOM") if exit_code != 0: logs = await container.log(stdout=True, stderr=True) logs_str = "".join(logs) if logs else "" - 
raise ExternalServiceError(f"Source exited with code {exit_code}: {logs_str[:500]}") + raise ExternalServiceError( + f"Ingester exited with code {exit_code}: {logs_str[:500]}" + ) records = parse_records_file(output_dir) session = parse_session_file(output_dir) - return SourceOutput(records=records, session=session, files_dir=files_dir) + return IngesterOutput(records=records, session=session, files_dir=files_dir) except aiodocker.DockerError as e: - logfire.error("Docker error running source", error=str(e)) + logfire.error("Docker error running ingester", error=str(e)) raise ExternalServiceError(f"Docker error: {e}") from e finally: if container is not None: @@ -210,6 +212,6 @@ async def _resolve_image(self, image: str, digest: str) -> str: pass # Pull from registry as last resort - logfire.info("Pulling source image", image=image) + logfire.info("Pulling ingester image", image=image) await self._docker.images.pull(image) return image diff --git a/server/osa/infrastructure/oci/runner.py b/server/osa/infrastructure/oci/runner.py index 7878991..05414d3 100644 --- a/server/osa/infrastructure/oci/runner.py +++ b/server/osa/infrastructure/oci/runner.py @@ -4,16 +4,16 @@ import json import os import stat +import sys import time from pathlib import Path from shutil import rmtree import aiodocker -import logfire - from osa.domain.shared.model.hook import HookDefinition from osa.domain.validation.model.hook_result import HookResult, HookStatus from osa.domain.validation.port.hook_runner import HookInputs, HookRunner +from osa.infrastructure.logging import get_logger from osa.infrastructure.runner_utils import ( detect_rejection, parse_memory, @@ -27,6 +27,9 @@ def _force_remove(func, path, exc): func(path) +log = get_logger(__name__) + + class OciHookRunner(HookRunner): """Executes hooks in OCI containers via aiodocker.""" @@ -54,14 +57,19 @@ async def run( container_output = work_dir / "output" container_output.mkdir(parents=True, exist_ok=True) try: - (staging_dir / 
"record.json").write_text(json.dumps(inputs.record_json)) - # Pre-create files mountpoint so nested bind works with ReadonlyRootfs - (staging_dir / "files").mkdir(exist_ok=True) + # Write records.jsonl (unified batch contract) + with (staging_dir / "records.jsonl").open("w") as f: + for record in inputs.records: + f.write(json.dumps(record.model_dump()) + "\n") if inputs.config or hook.runtime.config: config = {**hook.runtime.config, **(inputs.config or {})} (staging_dir / "config.json").write_text(json.dumps(config)) + # Create files directory structure: $OSA_FILES/{id}/ per record + files_base = staging_dir / "files" + files_base.mkdir(exist_ok=True) + start_time = time.monotonic() try: @@ -69,7 +77,12 @@ async def run( async def _resolve_and_run(): image_ref = await self._resolve_image(hook.runtime.image, hook.runtime.digest) return await self._run_container( - image_ref, staging_dir, inputs.files_dir, container_output, hook + image_ref, + staging_dir, + inputs.files_dirs, + container_output, + hook, + files_base, ) result = await asyncio.wait_for( @@ -87,7 +100,7 @@ async def _resolve_and_run(): ) except asyncio.TimeoutError: duration = time.monotonic() - start_time - logfire.error( + log.error( "Hook timed out", hook=hook.name, run_id=inputs.run_id, @@ -106,19 +119,28 @@ async def _run_container( self, image_ref: str, staging_dir: Path, - files_dir: Path | None, + files_dirs: dict[str, Path], output_dir: Path, hook: HookDefinition, + files_base: Path, ) -> dict: container = None try: - # Nested bind-mounts: staging at /osa/in:ro, files at /osa/in/files:ro + # Bind mounts: staging at /osa/in:ro, files at /osa/files:ro, output at /osa/out:rw binds = [ f"{self._host_path(staging_dir)}:/osa/in:ro", f"{self._host_path(output_dir)}:/osa/out:rw", ] - if files_dir and files_dir.exists(): - binds.append(f"{self._host_path(files_dir)}:/osa/in/files:ro") + + # Mount per-record file directories under /osa/files/{id}/ + # Sanitize IDs to avoid colons breaking Docker's bind 
mount syntax + if files_dirs: + for record_id, fdir in files_dirs.items(): + if fdir and fdir.exists(): + safe_id = record_id.replace(":", "_").replace("@", "_") + binds.append(f"{self._host_path(fdir)}:/osa/files/{safe_id}:ro") + elif files_base.exists(): + binds.append(f"{self._host_path(files_base)}:/osa/files:ro") # todo: use pydantic config = { @@ -126,6 +148,7 @@ async def _run_container( "Env": [ "OSA_IN=/osa/in", "OSA_OUT=/osa/out", + "OSA_FILES=/osa/files", f"OSA_HOOK_NAME={hook.name}", ], "User": "65534:65534", @@ -154,9 +177,23 @@ async def _run_container( oom_killed = inspect_data.get("State", {}).get("OOMKilled", False) if oom_killed: + # Grab tail of container logs before deletion + try: + tail_logs = await container.log(stdout=True, stderr=True, tail=3) + tail_text = "".join(tail_logs).strip() if tail_logs else "" + except Exception: + tail_text = "" + log.error( + "OOM: hook={hook_name} limit={memory}", + hook_name=hook.name, + memory=hook.runtime.limits.memory, + ) + if tail_text: + for line in tail_text.splitlines(): + print(f" OOM [{hook.name}] {line}", file=sys.stderr, flush=True) return { - "status": HookStatus.FAILED, - "error_message": "Hook killed by OOM", + "status": HookStatus.OOM, + "error_message": f"Hook killed by OOM (limit: {hook.runtime.limits.memory})", } # Parse progress file @@ -186,13 +223,13 @@ async def _run_container( } except aiodocker.DockerError as e: - logfire.error("Docker error running hook", error=str(e)) + log.error("Docker error running hook", error=str(e)) return { "status": HookStatus.FAILED, "error_message": f"Docker error: {e}", } except Exception as e: - logfire.error("Unexpected error running hook", error=str(e)) + log.error("Unexpected error running hook", error=str(e)) return { "status": HookStatus.FAILED, "error_message": f"Unexpected error: {e}", @@ -229,6 +266,6 @@ async def _resolve_image(self, image: str, digest: str) -> str: pass # Pull from registry as last resort - logfire.info("Pulling hook image", 
image=image) + log.info("Pulling hook image", image=image) await self._docker.images.pull(image) return image diff --git a/server/osa/infrastructure/persistence/adapter/storage.py b/server/osa/infrastructure/persistence/adapter/storage.py index 04c01da..f1e4dcd 100644 --- a/server/osa/infrastructure/persistence/adapter/storage.py +++ b/server/osa/infrastructure/persistence/adapter/storage.py @@ -1,6 +1,7 @@ import hashlib import json import logging +import os import shutil import tempfile from collections.abc import AsyncIterator @@ -12,6 +13,11 @@ from osa.domain.deposition.port.storage import FileStoragePort from osa.domain.shared.error import InfrastructureError from osa.domain.shared.model.srn import ConventionSRN, DepositionSRN +from osa.domain.validation.model.batch_outcome import ( + BatchRecordOutcome, + HookRecordId, + OutcomeStatus, +) logger = logging.getLogger(__name__) @@ -19,7 +25,7 @@ class FilesystemStorageAdapter(FileStoragePort): """Local filesystem adapter satisfying all domain storage ports. - Implements FileStoragePort (deposition files), SourceStoragePort, + Implements FileStoragePort (deposition files), HookStoragePort, and FeatureStoragePort via structural subtyping. 
""" @@ -58,6 +64,8 @@ def get_hook_output_root(self, source_type: str, source_id: str) -> str: if source_type == "deposition": srn = DepositionSRN.parse(source_id) return str(self._dep_dir(srn)) + if source_type == "ingest": + return str(self.base_path / "ingests" / source_id) raise ValueError(f"Unknown source type: {source_type}") async def read_hook_features( @@ -189,3 +197,102 @@ async def move_source_files_to_deposition( # Clean up empty source_id directory if source_files_dir.exists(): source_files_dir.rmdir() + + async def read_batch_outcomes( + self, output_dir: str, hook_name: str + ) -> dict[HookRecordId, BatchRecordOutcome]: + """Read JSONL batch outputs from the filesystem, streaming line-by-line.""" + hook_output = Path(output_dir) / "hooks" / hook_name / "output" + outcomes: dict[HookRecordId, BatchRecordOutcome] = {} + + _parse_batch_output_files(hook_output, outcomes) + + return outcomes + + def write_checkpoint( + self, work_dir: Path, outcomes: dict[HookRecordId, BatchRecordOutcome] + ) -> None: + """Atomically write checkpoint JSONL via os.replace().""" + checkpoint_path = work_dir / "_checkpoint.jsonl" + tmp_path = work_dir / "_checkpoint.jsonl.tmp" + with tmp_path.open("w") as f: + for outcome in outcomes.values(): + f.write(outcome.model_dump_json() + "\n") + os.replace(tmp_path, checkpoint_path) + + def write_batch_outcomes( + self, + work_dir: Path, + outcomes: dict[HookRecordId, BatchRecordOutcome], + ) -> None: + """Write canonical features.jsonl, rejections.jsonl, errors.jsonl.""" + output_dir = work_dir / "output" + output_dir.mkdir(parents=True, exist_ok=True) + + features: list[str] = [] + rejections: list[str] = [] + errors: list[str] = [] + + for outcome in outcomes.values(): + row: dict[str, Any] = {"id": outcome.record_id} + if outcome.status == OutcomeStatus.PASSED: + row["features"] = outcome.features + features.append(json.dumps(row)) + elif outcome.status == OutcomeStatus.REJECTED: + row["reason"] = outcome.reason + 
rejections.append(json.dumps(row)) + elif outcome.status == OutcomeStatus.ERRORED: + row["error"] = outcome.error + row["retryable"] = outcome.retryable + errors.append(json.dumps(row)) + + for filename, lines in [ + ("features.jsonl", features), + ("rejections.jsonl", rejections), + ("errors.jsonl", errors), + ]: + if lines: + (output_dir / filename).write_text("\n".join(lines) + "\n") + + +# ── Shared parsing ────────────────────────────────────────────────────── + + +_FILE_STATUS_MAP: list[tuple[str, OutcomeStatus, dict[str, str]]] = [ + ("features.jsonl", OutcomeStatus.PASSED, {"features": "features"}), + ("rejections.jsonl", OutcomeStatus.REJECTED, {"reason": "reason"}), + ("errors.jsonl", OutcomeStatus.ERRORED, {"error": "error", "retryable": "retryable"}), +] + + +def _parse_batch_output_files( + output_dir: Path, outcomes: dict[HookRecordId, BatchRecordOutcome] +) -> None: + """Parse features/rejections/errors JSONL files into BatchRecordOutcome dict.""" + for filename, status, field_map in _FILE_STATUS_MAP: + path = output_dir / filename + if not path.exists(): + continue + with path.open() as f: + for line in f: + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + except json.JSONDecodeError: + logger.warning("Skipping malformed JSON line in %s", filename) + continue + raw_id = data.get("id") + if not raw_id: + logger.warning("Skipping JSONL line without 'id' in %s", filename) + continue + record_id = HookRecordId(raw_id) + kwargs: dict[str, Any] = { + "record_id": record_id, + "status": status, + } + for src, dst in field_map.items(): + if src in data: + kwargs[dst] = data[src] + outcomes[record_id] = BatchRecordOutcome(**kwargs) diff --git a/server/osa/infrastructure/persistence/di.py b/server/osa/infrastructure/persistence/di.py index 4b4d75b..569945f 100644 --- a/server/osa/infrastructure/persistence/di.py +++ b/server/osa/infrastructure/persistence/di.py @@ -17,7 +17,6 @@ from osa.domain.record.query.get_record import 
GetRecordHandler from osa.domain.record.service import RecordService from osa.infrastructure.persistence.adapter.feature_reader import PostgresFeatureReader -from osa.domain.source.port.storage import SourceStoragePort from osa.domain.feature.port.storage import FeatureStoragePort from osa.domain.validation.port.storage import HookStoragePort from osa.domain.semantics.port.ontology_repository import OntologyRepository @@ -132,10 +131,6 @@ def get_file_storage_s3(self, config: Config, s3: "S3Client") -> FileStoragePort return S3StorageAdapter(s3=s3, data_mount_path=config.runner.k8s.data_mount_path) - @provide(scope=Scope.APP) - def get_source_storage(self, file_storage: FileStoragePort) -> SourceStoragePort: - return file_storage # type: ignore[return-value] - @provide(scope=Scope.APP) def get_hook_storage(self, file_storage: FileStoragePort) -> HookStoragePort: return file_storage # type: ignore[return-value] diff --git a/server/osa/infrastructure/persistence/repository/convention.py b/server/osa/infrastructure/persistence/repository/convention.py index 570e6a3..626f4cc 100644 --- a/server/osa/infrastructure/persistence/repository/convention.py +++ b/server/osa/infrastructure/persistence/repository/convention.py @@ -7,7 +7,7 @@ from osa.domain.deposition.model.value import FileRequirements from osa.domain.deposition.port.convention_repository import ConventionRepository from osa.domain.shared.model.hook import HookDefinition -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.srn import ConventionSRN, SchemaSRN from osa.infrastructure.persistence.tables import conventions_table @@ -20,7 +20,7 @@ def _convention_to_row(convention: Convention) -> dict[str, Any]: "schema_srn": str(convention.schema_srn), "file_requirements": convention.file_requirements.model_dump(), "hooks": [h.model_dump() for h in convention.hooks], - "source": convention.source.model_dump() if 
convention.source else None, + "source": convention.ingester.model_dump() if convention.ingester else None, "created_at": convention.created_at, } @@ -34,7 +34,7 @@ def _row_to_convention(row: dict[str, Any]) -> Convention: schema_srn=SchemaSRN.parse(row["schema_srn"]), file_requirements=FileRequirements.model_validate(row["file_requirements"]), hooks=[HookDefinition.model_validate(h) for h in (row.get("hooks") or [])], - source=SourceDefinition.model_validate(source_data) if source_data else None, + ingester=IngesterDefinition.model_validate(source_data) if source_data else None, created_at=row["created_at"], ) diff --git a/server/osa/infrastructure/persistence/repository/ingest.py b/server/osa/infrastructure/persistence/repository/ingest.py new file mode 100644 index 0000000..94d9beb --- /dev/null +++ b/server/osa/infrastructure/persistence/repository/ingest.py @@ -0,0 +1,129 @@ +"""PostgreSQL implementation of IngestRunRepository.""" + +import logging + +from sqlalchemy import select, update +from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.ingest.model.ingest_run import IngestRun, IngestStatus +from osa.domain.ingest.port.repository import IngestRunRepository +from osa.infrastructure.persistence.tables import ingest_runs_table + +logger = logging.getLogger(__name__) + + +class PostgresIngestRunRepository(IngestRunRepository): + """PostgreSQL implementation with atomic counter updates.""" + + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def save(self, ingest_run: IngestRun) -> None: + """Insert or update an ingest run.""" + values = { + "srn": ingest_run.srn, + "convention_srn": ingest_run.convention_srn, + "status": ingest_run.status.value, + "ingestion_finished": ingest_run.ingestion_finished, + "batches_ingested": ingest_run.batches_ingested, + "batches_completed": ingest_run.batches_completed, + "published_count": ingest_run.published_count, + 
"batch_size": ingest_run.batch_size, + "record_limit": ingest_run.limit, + "started_at": ingest_run.started_at, + "completed_at": ingest_run.completed_at, + } + stmt = ( + insert(ingest_runs_table) + .values(**values) + .on_conflict_do_update( + index_elements=["srn"], + set_=values, + ) + ) + await self._session.execute(stmt) + await self._session.flush() + + async def get(self, srn: str) -> IngestRun | None: + stmt = select(ingest_runs_table).where(ingest_runs_table.c.srn == srn) + result = await self._session.execute(stmt) + row = result.mappings().first() + if row is None: + return None + return _row_to_ingest_run(dict(row)) + + async def get_running_for_convention(self, convention_srn: str) -> IngestRun | None: + stmt = ( + select(ingest_runs_table) + .where(ingest_runs_table.c.convention_srn == convention_srn) + .where( + ingest_runs_table.c.status.in_( + [IngestStatus.PENDING.value, IngestStatus.RUNNING.value] + ) + ) + .limit(1) + ) + result = await self._session.execute(stmt) + row = result.mappings().first() + if row is None: + return None + return _row_to_ingest_run(dict(row)) + + async def increment_batches_ingested( + self, srn: str, *, set_ingestion_finished: bool = False + ) -> IngestRun: + """Atomically increment batches_ingested.""" + t = ingest_runs_table + values = { + "batches_ingested": t.c.batches_ingested + 1, + } + if set_ingestion_finished: + values["ingestion_finished"] = True + + stmt = update(t).where(t.c.srn == srn).values(**values).returning(*t.c) + result = await self._session.execute(stmt) + await self._session.flush() + row = result.mappings().first() + if row is None: + from osa.domain.shared.error import NotFoundError + + raise NotFoundError(f"Ingest run not found: {srn}") + return _row_to_ingest_run(dict(row)) + + async def increment_completed(self, srn: str, published_count: int) -> IngestRun: + """Atomically increment batches_completed and published_count.""" + t = ingest_runs_table + stmt = ( + update(t) + .where(t.c.srn == 
srn) + .values( + batches_completed=t.c.batches_completed + 1, + published_count=t.c.published_count + published_count, + ) + .returning(*t.c) + ) + result = await self._session.execute(stmt) + await self._session.flush() + row = result.mappings().first() + if row is None: + from osa.domain.shared.error import NotFoundError + + raise NotFoundError(f"Ingest run not found: {srn}") + return _row_to_ingest_run(dict(row)) + + +def _row_to_ingest_run(row: dict) -> IngestRun: + return IngestRun( + srn=row["srn"], + convention_srn=row["convention_srn"], + status=IngestStatus(row["status"]), + ingestion_finished=row["ingestion_finished"], + batches_ingested=row["batches_ingested"], + batches_completed=row["batches_completed"], + published_count=row["published_count"], + batch_size=row["batch_size"], + limit=row.get("record_limit"), + started_at=row["started_at"], + completed_at=row.get("completed_at"), + ) diff --git a/server/osa/infrastructure/persistence/repository/record.py b/server/osa/infrastructure/persistence/repository/record.py index 22763b9..52a0176 100644 --- a/server/osa/infrastructure/persistence/repository/record.py +++ b/server/osa/infrastructure/persistence/repository/record.py @@ -1,6 +1,7 @@ """PostgreSQL implementation of RecordRepository.""" -from sqlalchemy import func, insert, select +from sqlalchemy import func, select, text +from sqlalchemy.dialects.postgresql import insert from sqlalchemy.ext.asyncio import AsyncSession from osa.domain.record.model.aggregate import Record @@ -23,6 +24,30 @@ async def save(self, record: Record) -> None: await self.session.execute(stmt) await self.session.flush() + async def save_many(self, records: list[Record]) -> list[Record]: + """Multi-row INSERT with ON CONFLICT DO NOTHING. + + Returns the records that were actually inserted (duplicates are skipped). 
+ """ + if not records: + return [] + values = [record_to_dict(r) for r in records] + stmt = ( + insert(records_table) + .values(values) + .on_conflict_do_nothing( + index_elements=[ + text("(source->>'type')"), + text("(source->>'id')"), + ], + ) + .returning(records_table.c.srn) + ) + result = await self.session.execute(stmt) + await self.session.flush() + inserted_srns = {row[0] for row in result.fetchall()} + return [r for r in records if str(r.srn) in inserted_srns] + async def get(self, srn: RecordSRN) -> Record | None: """Get a record by SRN.""" stmt = select(records_table).where(records_table.c.srn == str(srn)) diff --git a/server/osa/infrastructure/persistence/tables.py b/server/osa/infrastructure/persistence/tables.py index de6e670..4f3545a 100644 --- a/server/osa/infrastructure/persistence/tables.py +++ b/server/osa/infrastructure/persistence/tables.py @@ -266,7 +266,7 @@ Column("schema_srn", String, nullable=False), # Reference to schemas.srn Column("file_requirements", JSON, nullable=False), # FileRequirements as dict Column("hooks", JSON, nullable=False, default=[]), # List of HookDefinition dicts - Column("source", JSON, nullable=True), # SourceDefinition as dict + Column("source", JSON, nullable=True), # IngesterDefinition as dict Column("created_at", DateTime(timezone=True), nullable=False), ) @@ -304,6 +304,29 @@ Index("ix_role_assignments_user_id", role_assignments_table.c.user_id) +# ============================================================================ +# INGEST RUNS TABLE (Ingest) +# ============================================================================ +ingest_runs_table = Table( + "ingest_runs", + metadata, + Column("srn", String, primary_key=True), + Column("convention_srn", String, ForeignKey("conventions.srn"), nullable=False), + Column("status", String(32), nullable=False, server_default=text("'pending'")), + Column("ingestion_finished", Boolean, nullable=False, server_default=text("false")), + Column("batches_ingested", 
Integer, nullable=False, server_default=text("0")), + Column("batches_completed", Integer, nullable=False, server_default=text("0")), + Column("published_count", Integer, nullable=False, server_default=text("0")), + Column("batch_size", Integer, nullable=False, server_default=text("1000")), + Column("record_limit", Integer, nullable=True), + Column("started_at", DateTime(timezone=True), nullable=False), + Column("completed_at", DateTime(timezone=True), nullable=True), +) + +Index("idx_ingest_runs_convention", ingest_runs_table.c.convention_srn) +Index("idx_ingest_runs_status", ingest_runs_table.c.status) + + # ============================================================================ # DEVICE AUTHORIZATIONS TABLE (Authentication - OAuth Device Flow) # ============================================================================ diff --git a/server/osa/infrastructure/runner_utils.py b/server/osa/infrastructure/runner_utils.py index 336a93d..b7340be 100644 --- a/server/osa/infrastructure/runner_utils.py +++ b/server/osa/infrastructure/runner_utils.py @@ -49,26 +49,14 @@ def detect_rejection(progress: list[ProgressEntry]) -> tuple[bool, str | None]: def parse_memory(memory: str) -> int: - """Parse memory string like '2g' or '512m' to bytes.""" - memory = memory.strip().lower() - match = re.match(r"^(\d+(?:\.\d+)?)(g|m|k)?i?$", memory) - if not match: - raise ValueError(f"Invalid memory format: {memory}") + """Parse memory string like '2g' or '512m' to bytes. - amount = float(match.group(1)) - unit = match.group(2) + .. deprecated:: + Use ``osa.domain.shared.model.hook.parse_memory`` instead. 
+ """ + from osa.domain.shared.model.hook import parse_memory as _parse_memory - match unit: - case "g": - return int(amount * 1024 * 1024 * 1024) - case "m": - return int(amount * 1024 * 1024) - case "k": - return int(amount * 1024) - case None: - return int(amount) - case _: - raise ValueError(f"Unknown memory unit: {unit}") + return _parse_memory(memory) _MEMORY_RE = re.compile(r"^(\d+(?:\.\d+)?)(g|m|k)?i?$") @@ -123,7 +111,7 @@ def relative_path(path: Path, data_mount_path: str) -> str: def parse_records_file(output_dir: Path) -> list[dict[str, Any]]: - """Parse records.jsonl from source output directory.""" + """Parse records.jsonl from ingester output directory.""" import logfire records: list[dict[str, Any]] = [] diff --git a/server/osa/infrastructure/s3/storage.py b/server/osa/infrastructure/s3/storage.py index 2b37eda..fc1ebd8 100644 --- a/server/osa/infrastructure/s3/storage.py +++ b/server/osa/infrastructure/s3/storage.py @@ -16,6 +16,11 @@ from osa.domain.deposition.port.storage import FileStoragePort from osa.domain.shared.error import InfrastructureError, NotFoundError from osa.domain.shared.model.srn import ConventionSRN, DepositionSRN +from osa.domain.validation.model.batch_outcome import ( + BatchRecordOutcome, + HookRecordId, + OutcomeStatus, +) from osa.infrastructure.runner_utils import relative_path from osa.infrastructure.s3.client import S3Client @@ -25,7 +30,7 @@ class S3StorageAdapter(FileStoragePort): """S3-backed adapter satisfying all domain storage ports. - Implements FileStoragePort, SourceStoragePort, HookStoragePort, + Implements FileStoragePort, HookStoragePort, and FeatureStoragePort via structural subtyping — same as FilesystemStorageAdapter but using S3 API instead of POSIX calls. 
@@ -115,7 +120,7 @@ async def delete_files_for_deposition( prefix = f"{self._dep_prefix(deposition_id)}/" await self._s3.delete_objects(prefix) - # ── SourceStoragePort ──────────────────────────────────────────── + # ── Ingester storage ────────────────────────────────────────────── def get_source_staging_dir(self, convention_srn: ConventionSRN, run_id: str) -> Path: """Return path for PVC subpath computation (no I/O).""" @@ -143,7 +148,7 @@ async def move_source_files_to_deposition( source_id: str, deposition_srn: DepositionSRN, ) -> None: - """S3 server-side copy from source staging to deposition files prefix.""" + """S3 server-side copy from ingester staging to deposition files prefix.""" source_prefix = f"{relative_path(staging_dir, self._data_mount_path)}/{source_id}/" dest_prefix = self._files_prefix(deposition_srn) @@ -178,6 +183,8 @@ def get_hook_output_root(self, source_type: str, source_id: str) -> str: srn = DepositionSRN.parse(source_id) safe_id = self._safe_id(srn) return f"{self._data_mount_path}/depositions/{safe_id}" + if source_type == "ingest": + return f"{self._data_mount_path}/ingests/{source_id}" raise ValueError(f"Unknown source type: {source_type}") async def read_hook_features( @@ -200,3 +207,49 @@ async def hook_features_exist(self, hook_output_dir: str, feature_name: str) -> prefix = relative_path(Path(hook_output_dir), self._data_mount_path) key = f"{prefix}/hooks/{feature_name}/output/features.json" return await self._s3.head_object(key) + + async def read_batch_outcomes( + self, output_dir: str, hook_name: str + ) -> dict[HookRecordId, BatchRecordOutcome]: + """Read JSONL batch outputs from S3.""" + prefix = relative_path(Path(output_dir), self._data_mount_path) + hook_prefix = f"{prefix}/hooks/{hook_name}/output" + outcomes: dict[HookRecordId, BatchRecordOutcome] = {} + + file_status_map: list[tuple[str, OutcomeStatus, dict[str, str]]] = [ + ("features.jsonl", OutcomeStatus.PASSED, {"features": "features"}), + ("rejections.jsonl", 
OutcomeStatus.REJECTED, {"reason": "reason"}), +        ("errors.jsonl", OutcomeStatus.ERRORED, {"error": "error", "retryable": "retryable"}), +    ] + +    for filename, status, field_map in file_status_map: +        key = f"{hook_prefix}/{filename}" +        try: +            data_bytes = await self._s3.get_object(key) +        except Exception: +            continue + +        for line in data_bytes.decode().split("\n"): +            line = line.strip() +            if not line: +                continue +            try: +                data = json.loads(line) +            except json.JSONDecodeError: +                logger.warning("Skipping malformed JSON line in %s", filename) +                continue +            raw_id = data.get("id") +            if not raw_id: +                logger.warning("Skipping JSONL line without 'id' in %s", filename) +                continue +            record_id = HookRecordId(raw_id) +            kwargs: dict[str, Any] = { +                "record_id": record_id, +                "status": status, +            } +            for src, dst in field_map.items(): +                if src in data: +                    kwargs[dst] = data[src] +            outcomes[record_id] = BatchRecordOutcome(**kwargs) + +    return outcomes diff --git a/server/osa/infrastructure/source/__init__.py b/server/osa/infrastructure/source/__init__.py deleted file mode 100644 index 701b44d..0000000 --- a/server/osa/infrastructure/source/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Source infrastructure - source implementations.""" diff --git a/server/osa/infrastructure/source/di.py b/server/osa/infrastructure/source/di.py deleted file mode 100644 index eda51bd..0000000 --- a/server/osa/infrastructure/source/di.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Dependency injection provider for sources.""" - -from dishka import provide - -from osa.domain.shared.outbox import Outbox -from osa.domain.source.port.source_runner import SourceRunner -from osa.domain.source.port.storage import SourceStoragePort -from osa.domain.source.service import SourceService -from osa.util.di.base import Provider -from osa.util.di.scope import Scope - - -class SourceProvider(Provider): - """Provides SourceService wired with OCI runner.""" - - @provide(scope=Scope.UOW) - def get_source_service( - self,
source_runner: SourceRunner, - source_storage: SourceStoragePort, - outbox: Outbox, - ) -> SourceService: - return SourceService( - source_runner=source_runner, - source_storage=source_storage, - outbox=outbox, - ) diff --git a/server/osa/infrastructure/storage/__init__.py b/server/osa/infrastructure/storage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/osa/infrastructure/storage/layout.py b/server/osa/infrastructure/storage/layout.py new file mode 100644 index 0000000..a00e734 --- /dev/null +++ b/server/osa/infrastructure/storage/layout.py @@ -0,0 +1,48 @@ +"""Storage layout — single source of truth for directory structure. + +Composable path methods that define where data lives on disk/S3. +Storage adapters and runners consume this instead of hardcoding paths. + +See #106 for the full consolidation plan. Currently covers ingest paths only; +deposition paths will be migrated here in a follow-up. +""" + +from pathlib import Path + + +def _safe_srn(srn: str) -> str: + """Convert an SRN to a filesystem-safe string.""" + return srn.replace(":", "_").replace("@", "_") + + +class StorageLayout: + """Computes storage paths relative to a data root. + + All methods return Path objects. Storage adapters prefix with their + own root (filesystem base_path or S3 key prefix). 
+ """ + + def __init__(self, data_dir: Path) -> None: + self._data_dir = data_dir + + # ── Ingest paths ───────────────────────────────────────────────── + + def ingest_run_dir(self, ingest_run_srn: str) -> Path: + """Root directory for an ingest run.""" + return self._data_dir / "ingests" / _safe_srn(ingest_run_srn) + + def ingest_batch_dir(self, ingest_run_srn: str, batch_index: int) -> Path: + """Directory for a specific batch within an ingest run.""" + return self.ingest_run_dir(ingest_run_srn) / "batches" / str(batch_index) + + def ingest_batch_ingester_dir(self, ingest_run_srn: str, batch_index: int) -> Path: + """Ingester output directory (records.jsonl, files/) for a batch.""" + return self.ingest_batch_dir(ingest_run_srn, batch_index) / "ingester" + + def ingest_batch_hook_dir(self, ingest_run_srn: str, batch_index: int, hook_name: str) -> Path: + """Hook output directory for a batch.""" + return self.ingest_batch_dir(ingest_run_srn, batch_index) / "hooks" / hook_name + + def ingest_session_file(self, ingest_run_srn: str) -> Path: + """Session state file for ingester continuation.""" + return self.ingest_run_dir(ingest_run_srn) / "session.json" diff --git a/server/tests/integration/persistence/test_convention_repo.py b/server/tests/integration/persistence/test_convention_repo.py index bc3ed27..632eb79 100644 --- a/server/tests/integration/persistence/test_convention_repo.py +++ b/server/tests/integration/persistence/test_convention_repo.py @@ -15,10 +15,10 @@ TableFeatureSpec, ) from osa.domain.shared.model.source import ( + IngesterDefinition, + IngesterLimits, + IngesterScheduleConfig, InitialRunConfig, - SourceDefinition, - SourceLimits, - SourceScheduleConfig, ) from osa.domain.shared.model.srn import ConventionSRN, SchemaSRN from osa.infrastructure.persistence.repository.convention import ( @@ -32,7 +32,7 @@ def _make_convention( title: str = "Test Convention", schema_srn: str = "urn:osa:localhost:schema:test-schema-001@1.0.0", hooks: 
list[HookDefinition] | None = None, - source: SourceDefinition | None = None, + ingester: IngesterDefinition | None = None, ) -> Convention: return Convention( srn=ConventionSRN.parse(srn), @@ -46,7 +46,7 @@ def _make_convention( max_file_size=100_000_000, ), hooks=hooks or [], - source=source, + ingester=ingester, created_at=datetime.now(UTC), ) @@ -70,14 +70,14 @@ def _make_hook() -> HookDefinition: ) -def _make_source() -> SourceDefinition: - return SourceDefinition( - image="ghcr.io/example/source:latest", +def _make_ingester() -> IngesterDefinition: + return IngesterDefinition( + image="ghcr.io/example/ingester:latest", digest="sha256:def456", runner="oci", config={"api_key": "test-key"}, - limits=SourceLimits(timeout_seconds=7200, memory="8g", cpu="4.0"), - schedule=SourceScheduleConfig(cron="0 2 * * *", limit=500), + limits=IngesterLimits(timeout_seconds=7200, memory="8g", cpu="4.0"), + schedule=IngesterScheduleConfig(cron="0 2 * * *", limit=500), initial_run=InitialRunConfig(limit=100), ) @@ -88,8 +88,8 @@ async def test_save_and_get(self, pg_session: AsyncSession): """Save a convention and retrieve it — all fields should match.""" repo = PostgresConventionRepository(pg_session) hook = _make_hook() - source = _make_source() - conv = _make_convention(hooks=[hook], source=source) + ingester = _make_ingester() + conv = _make_convention(hooks=[hook], ingester=ingester) await repo.save(conv) await pg_session.commit() @@ -106,12 +106,12 @@ async def test_save_and_get(self, pg_session: AsyncSession): assert got.hooks[0].runtime.digest == hook.runtime.digest assert got.hooks[0].name == "quality_check" assert got.hooks[0].feature.columns[0].name == "score" - assert got.source is not None - assert got.source.image == source.image - assert got.source.schedule is not None - assert got.source.schedule.cron == "0 2 * * *" - assert got.source.initial_run is not None - assert got.source.initial_run.limit == 100 + assert got.ingester is not None + assert got.ingester.image 
== ingester.image + assert got.ingester.schedule is not None + assert got.ingester.schedule.cron == "0 2 * * *" + assert got.ingester.initial_run is not None + assert got.ingester.initial_run.limit == 100 async def test_get_nonexistent_returns_none(self, pg_session: AsyncSession): repo = PostgresConventionRepository(pg_session) @@ -163,14 +163,14 @@ async def test_exists_false(self, pg_session: AsyncSession): repo = PostgresConventionRepository(pg_session) assert await repo.exists(ConventionSRN.parse("urn:osa:localhost:conv:nope@1.0.0")) is False - async def test_convention_without_source(self, pg_session: AsyncSession): - """Source is optional — should be None on retrieval when not set.""" + async def test_convention_without_ingester(self, pg_session: AsyncSession): + """Ingester is optional — should be None on retrieval when not set.""" repo = PostgresConventionRepository(pg_session) - conv = _make_convention(source=None, hooks=[]) + conv = _make_convention(ingester=None, hooks=[]) await repo.save(conv) await pg_session.commit() got = await repo.get(conv.srn) assert got is not None - assert got.source is None + assert got.ingester is None assert got.hooks == [] diff --git a/server/tests/unit/application/__init__.py b/server/tests/unit/application/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/application/api/__init__.py b/server/tests/unit/application/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/application/api/v1/__init__.py b/server/tests/unit/application/api/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/application/test_app_factory.py b/server/tests/unit/application/test_app_factory.py index 6b276b4..f0233d4 100644 --- a/server/tests/unit/application/test_app_factory.py +++ b/server/tests/unit/application/test_app_factory.py @@ -19,14 +19,14 @@ from osa.application.di import create_container from osa.domain.shared.event import 
Event, EventHandler, EventId from osa.domain.shared.model.hook import HookDefinition -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.subscription_registry import SubscriptionRegistry -from osa.domain.source.port.source_runner import SourceInputs, SourceOutput, SourceRunner +from osa.domain.shared.port.ingester_runner import IngesterInputs, IngesterOutput, IngesterRunner from osa.domain.validation.model.hook_result import HookResult, HookStatus from osa.domain.validation.port.hook_runner import HookInputs, HookRunner from osa.infrastructure.event.di import HandlerTypes, _CORE_HANDLERS from osa.infrastructure.oci.runner import OciHookRunner -from osa.infrastructure.oci.source_runner import OciSourceRunner +from osa.infrastructure.oci.ingester_runner import OciIngesterRunner from osa.util.di.base import Provider from osa.util.di.scope import Scope @@ -35,6 +35,7 @@ "OSA_AUTH__JWT__SECRET", "test-secret-that-is-at-least-32-characters-long", ) +os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") # --------------------------------------------------------------------------- @@ -49,17 +50,17 @@ async def run(self, hook: HookDefinition, inputs: HookInputs, work_dir: Path) -> return HookResult(hook_name=hook.name, status=HookStatus.PASSED, duration_seconds=0.0) -class StubSourceRunner: - """Stub SourceRunner for testing provider overrides.""" +class StubIngesterRunner: + """Stub IngesterRunner for testing provider overrides.""" async def run( self, - source: SourceDefinition, - inputs: SourceInputs, + source: IngesterDefinition, + inputs: IngesterInputs, files_dir: Path, work_dir: Path, - ) -> SourceOutput: - return SourceOutput(records=[], session=None, files_dir=files_dir) + ) -> IngesterOutput: + return IngesterOutput(records=[], session=None, files_dir=files_dir) class StubRunnerProvider(Provider): @@ -70,8 +71,8 @@ def get_hook_runner(self) -> HookRunner: 
return StubHookRunner() @provide(scope=Scope.UOW, override=True) - def get_source_runner(self) -> SourceRunner: - return StubSourceRunner() + def get_ingester_runner(self) -> IngesterRunner: + return StubIngesterRunner() # --------------------------------------------------------------------------- @@ -110,12 +111,12 @@ def test_runner_override(self): async def resolve(): async with container(scope=Scope.UOW) as uow: hook = await uow.get(HookRunner) - source = await uow.get(SourceRunner) - return hook, source + ingester = await uow.get(IngesterRunner) + return hook, ingester - hook_runner, source_runner = asyncio.run(resolve()) + hook_runner, ingester_runner = asyncio.run(resolve()) assert isinstance(hook_runner, StubHookRunner) - assert isinstance(source_runner, StubSourceRunner) + assert isinstance(ingester_runner, StubIngesterRunner) def test_default_runners_without_override(self): """Without extra providers, default OCI runners are used.""" @@ -124,12 +125,12 @@ def test_default_runners_without_override(self): async def resolve(): async with container(scope=Scope.UOW) as uow: hook = await uow.get(HookRunner) - source = await uow.get(SourceRunner) - return hook, source + ingester = await uow.get(IngesterRunner) + return hook, ingester - hook_runner, source_runner = asyncio.run(resolve()) + hook_runner, ingester_runner = asyncio.run(resolve()) assert isinstance(hook_runner, OciHookRunner) - assert isinstance(source_runner, OciSourceRunner) + assert isinstance(ingester_runner, OciIngesterRunner) # --------------------------------------------------------------------------- diff --git a/server/tests/unit/config/test_config.py b/server/tests/unit/config/test_config.py index 426950f..e5fb1cb 100644 --- a/server/tests/unit/config/test_config.py +++ b/server/tests/unit/config/test_config.py @@ -24,6 +24,7 @@ def config_from_yaml(data: dict, env_overrides: dict[str, str] | None = None) -> raw = make_config_yaml(data) env = { "OSA_AUTH__JWT__SECRET": 
"test-secret-key-that-is-at-least-32-chars-long", + "OSA_BASE_URL": "http://localhost:8000", **(env_overrides or {}), } diff --git a/server/tests/unit/config/test_paths_config.py b/server/tests/unit/config/test_paths_config.py index 0e5bbc7..8006ebb 100644 --- a/server/tests/unit/config/test_paths_config.py +++ b/server/tests/unit/config/test_paths_config.py @@ -7,6 +7,13 @@ from osa.config import Config +@pytest.fixture(autouse=True) +def _config_defaults(monkeypatch: pytest.MonkeyPatch) -> None: + """Ensure required config env vars are set for all tests.""" + monkeypatch.setenv("OSA_AUTH__JWT__SECRET", "test-secret-key-that-is-at-least-32-chars-long") + monkeypatch.setenv("OSA_BASE_URL", "http://localhost:8000") + + class TestDatabaseUrlDerivation: """Tests for database URL derivation from OSAPaths.""" diff --git a/server/tests/unit/domain/__init__.py b/server/tests/unit/domain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/domain/deposition/test_convention_service_v2.py b/server/tests/unit/domain/deposition/test_convention_service_v2.py index fb88cdc..9663e22 100644 --- a/server/tests/unit/domain/deposition/test_convention_service_v2.py +++ b/server/tests/unit/domain/deposition/test_convention_service_v2.py @@ -14,7 +14,7 @@ OciConfig, TableFeatureSpec, ) -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.srn import Domain, SchemaSRN @@ -60,8 +60,8 @@ def _make_hook_def(name: str = "detect_pockets") -> HookDefinition: ) -def _make_source_def() -> SourceDefinition: - return SourceDefinition( +def _make_ingester_def() -> IngesterDefinition: + return IngesterDefinition( image="osa-sources/rcsb-pdb:latest", digest="sha256:abc123", config={"email": "test@example.com", "batch_size": 100}, @@ -129,31 +129,31 @@ async def test_convention_references_created_schema_srn(self): assert result.schema_srn == schema_srn 
@pytest.mark.asyncio - async def test_convention_saves_source_definition(self): + async def test_convention_saves_ingester_definition(self): service = _make_service() - source = _make_source_def() + ingester = _make_ingester_def() result = await service.create_convention( - title="With Source", + title="With Ingester", version="1.0.0", schema=_make_field_defs(), file_requirements=_make_file_reqs(), - source=source, + ingester=ingester, ) - assert result.source is not None - assert result.source.image == "osa-sources/rcsb-pdb:latest" - assert result.source.digest == "sha256:abc123" - assert result.source.config == {"email": "test@example.com", "batch_size": 100} + assert result.ingester is not None + assert result.ingester.image == "osa-sources/rcsb-pdb:latest" + assert result.ingester.digest == "sha256:abc123" + assert result.ingester.config == {"email": "test@example.com", "batch_size": 100} @pytest.mark.asyncio - async def test_convention_source_defaults_to_none(self): + async def test_convention_ingester_defaults_to_none(self): service = _make_service() result = await service.create_convention( - title="No Source", + title="No Ingester", version="1.0.0", schema=_make_field_defs(), file_requirements=_make_file_reqs(), ) - assert result.source is None + assert result.ingester is None @pytest.mark.asyncio async def test_convention_with_hooks_emits_hooks_in_event(self): @@ -182,7 +182,7 @@ async def test_create_convention_emits_convention_registered(self): version="1.0.0", schema=_make_field_defs(), file_requirements=_make_file_reqs(), - source=_make_source_def(), + ingester=_make_ingester_def(), ) outbox.append.assert_called_once() emitted = outbox.append.call_args[0][0] diff --git a/server/tests/unit/domain/deposition/test_create_deposition_from_source.py b/server/tests/unit/domain/deposition/test_create_deposition_from_source.py deleted file mode 100644 index 44e1e29..0000000 --- a/server/tests/unit/domain/deposition/test_create_deposition_from_source.py +++ 
/dev/null @@ -1,93 +0,0 @@ -"""Unit tests for CreateDepositionFromSource event handler. - -Tests for User Story 3: Cross-domain decoupling. -""" - -from unittest.mock import AsyncMock, MagicMock -from uuid import uuid4 - -import pytest - -from osa.domain.deposition.handler.create_deposition_from_source import ( - CreateDepositionFromSource, -) -from osa.domain.deposition.model.aggregate import Deposition -from osa.domain.shared.event import EventId -from osa.domain.shared.model.srn import ConventionSRN, DepositionSRN -from osa.domain.source.event.source_record_ready import SourceRecordReady - - -def _make_conv_srn() -> ConventionSRN: - return ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0") - - -def _make_dep_srn() -> DepositionSRN: - return DepositionSRN.parse("urn:osa:localhost:dep:test-dep") - - -def _make_event() -> SourceRecordReady: - return SourceRecordReady( - id=EventId(uuid4()), - convention_srn=_make_conv_srn(), - metadata={"pdb_id": "4HHB", "title": "Hemoglobin"}, - file_paths=["4HHB/structure.cif"], - source_id="4HHB", - staging_dir="/tmp/staging/run-123", - ) - - -class TestCreateDepositionFromSource: - @pytest.mark.asyncio - async def test_creates_deposition_and_submits(self): - """Handler creates deposition, updates metadata, moves files, and submits.""" - dep = MagicMock(spec=Deposition) - dep.srn = _make_dep_srn() - - deposition_service = AsyncMock() - deposition_service.create.return_value = dep - file_storage = AsyncMock() - - handler = CreateDepositionFromSource( - deposition_service=deposition_service, - file_storage=file_storage, - ) - event = _make_event() - await handler.handle(event) - - # Creates deposition - deposition_service.create.assert_called_once() - create_kwargs = deposition_service.create.call_args[1] - assert create_kwargs["convention_srn"] == event.convention_srn - - # Updates metadata - deposition_service.update_metadata.assert_called_once_with( - srn=dep.srn, - metadata=event.metadata, - ) - - # Moves files - 
file_storage.move_source_files_to_deposition.assert_called_once() - - # Submits - deposition_service.submit.assert_called_once_with(srn=dep.srn) - - @pytest.mark.asyncio - async def test_uses_system_user_id(self): - """Handler creates deposition with SYSTEM_USER_ID.""" - from osa.domain.auth.model.value import SYSTEM_USER_ID - - dep = MagicMock(spec=Deposition) - dep.srn = _make_dep_srn() - - deposition_service = AsyncMock() - deposition_service.create.return_value = dep - file_storage = AsyncMock() - - handler = CreateDepositionFromSource( - deposition_service=deposition_service, - file_storage=file_storage, - ) - await handler.handle(_make_event()) - - create_kwargs = deposition_service.create.call_args[1] - assert create_kwargs["owner_id"] == SYSTEM_USER_ID diff --git a/server/tests/unit/domain/feature/test_insert_record_features.py b/server/tests/unit/domain/feature/test_insert_record_features.py index 114404a..dbd489e 100644 --- a/server/tests/unit/domain/feature/test_insert_record_features.py +++ b/server/tests/unit/domain/feature/test_insert_record_features.py @@ -9,7 +9,7 @@ from osa.domain.feature.service.feature import FeatureService from osa.domain.record.event.record_published import RecordPublished from osa.domain.shared.event import EventId -from osa.domain.shared.model.source import DepositionSource, HarvestSource +from osa.domain.shared.model.source import DepositionSource, IngestSource from osa.domain.shared.model.srn import ( ConventionSRN, RecordSRN, @@ -198,34 +198,34 @@ async def test_no_features_is_noop(self): feature_store.insert_features.assert_not_called() -class TestInsertRecordFeaturesHarvestSource: - """US2: InsertRecordFeatures works identically for harvest-sourced records.""" +class TestInsertRecordFeaturesIngestSource: + """US2: InsertRecordFeatures works identically for ingest-sourced records.""" @pytest.mark.asyncio - async def test_harvest_source_uses_source_fields(self): + async def test_ingest_source_uses_source_fields(self): 
"""Handler uses source type and id from event regardless of source type.""" feature_service = AsyncMock() storage = MagicMock() - storage.get_hook_output_root.return_value = "/fake/harvest/dir" + storage.get_hook_output_root.return_value = "/fake/ingest/dir" handler = _make_handler(feature_service=feature_service, feature_storage=storage) event = RecordPublished( id=EventId(uuid4()), record_srn=_make_record_srn(), - source=HarvestSource( + source=IngestSource( id="run-123-pdb-456", - harvest_run_srn="urn:osa:localhost:val:run123", + ingest_run_srn="urn:osa:localhost:val:run123", upstream_source="pdb", ), - metadata={"title": "Harvested"}, + metadata={"title": "Ingested"}, convention_srn=_make_conv_srn(), expected_features=["pocket_detect"], ) await handler.handle(event) - storage.get_hook_output_root.assert_called_once_with("harvest", "run-123-pdb-456") + storage.get_hook_output_root.assert_called_once_with("ingest", "run-123-pdb-456") feature_service.insert_features_for_record.assert_called_once_with( - hook_output_dir="/fake/harvest/dir", + hook_output_dir="/fake/ingest/dir", record_srn=str(_make_record_srn()), expected_features=["pocket_detect"], ) diff --git a/server/tests/unit/domain/ingest/__init__.py b/server/tests/unit/domain/ingest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/domain/ingest/test_ingest_run.py b/server/tests/unit/domain/ingest/test_ingest_run.py new file mode 100644 index 0000000..55f5395 --- /dev/null +++ b/server/tests/unit/domain/ingest/test_ingest_run.py @@ -0,0 +1,144 @@ +"""T015/T017: Unit tests for IngestRun aggregate — status transitions, completion, counters.""" + +from datetime import UTC, datetime + +import pytest + +from osa.domain.ingest.model.ingest_run import IngestRun, IngestStatus +from osa.domain.shared.error import InvalidStateError + + +def _make_run(**overrides) -> IngestRun: + defaults = { + "srn": "urn:osa:localhost:ing:test-run", + "convention_srn": 
"urn:osa:localhost:conv:test-conv@1.0.0", + "status": IngestStatus.PENDING, + "started_at": datetime.now(UTC), + } + defaults.update(overrides) + return IngestRun(**defaults) + + +class TestStatusTransitions: + def test_pending_to_running(self) -> None: + run = _make_run() + run.mark_running() + assert run.status == IngestStatus.RUNNING + + def test_running_to_completed(self) -> None: + run = _make_run(status=IngestStatus.RUNNING) + run.transition_to(IngestStatus.COMPLETED) + assert run.status == IngestStatus.COMPLETED + + def test_running_to_failed(self) -> None: + run = _make_run(status=IngestStatus.RUNNING) + run.mark_failed(datetime.now(UTC)) + assert run.status == IngestStatus.FAILED + assert run.completed_at is not None + + def test_pending_to_failed(self) -> None: + run = _make_run() + run.mark_failed(datetime.now(UTC)) + assert run.status == IngestStatus.FAILED + + def test_completed_to_running_rejected(self) -> None: + run = _make_run(status=IngestStatus.COMPLETED) + with pytest.raises(InvalidStateError, match="Cannot transition"): + run.transition_to(IngestStatus.RUNNING) + + def test_failed_to_running_rejected(self) -> None: + run = _make_run(status=IngestStatus.FAILED) + with pytest.raises(InvalidStateError, match="Cannot transition"): + run.transition_to(IngestStatus.RUNNING) + + def test_completed_to_completed_rejected(self) -> None: + run = _make_run(status=IngestStatus.COMPLETED) + with pytest.raises(InvalidStateError): + run.transition_to(IngestStatus.COMPLETED) + + +class TestCompletionCondition: + def test_not_complete_when_source_not_finished(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + ingestion_finished=False, + batches_ingested=3, + batches_completed=3, + ) + assert not run.is_complete + + def test_not_complete_when_batches_pending(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + ingestion_finished=True, + batches_ingested=3, + batches_completed=2, + ) + assert not run.is_complete + + def 
test_complete_when_all_batches_done(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + ingestion_finished=True, + batches_ingested=3, + batches_completed=3, + ) + assert run.is_complete + + def test_check_completion_transitions_status(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + ingestion_finished=True, + batches_ingested=2, + batches_completed=2, + ) + now = datetime.now(UTC) + completed = run.check_completion(now) + assert completed is True + assert run.status == IngestStatus.COMPLETED + assert run.completed_at == now + + def test_check_completion_noop_when_not_complete(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + ingestion_finished=True, + batches_ingested=3, + batches_completed=2, + ) + completed = run.check_completion(datetime.now(UTC)) + assert completed is False + assert run.status == IngestStatus.RUNNING + + +class TestCounterIncrements: + def test_increment_batches_ingested(self) -> None: + run = _make_run(status=IngestStatus.RUNNING) + run.increment_batches_ingested() + assert run.batches_ingested == 1 + + def test_record_batch_completed(self) -> None: + run = _make_run(status=IngestStatus.RUNNING) + run.record_batch_completed(published_count=42) + assert run.batches_completed == 1 + assert run.published_count == 42 + + def test_multiple_batch_completions(self) -> None: + run = _make_run(status=IngestStatus.RUNNING) + run.record_batch_completed(published_count=100) + run.record_batch_completed(published_count=50) + assert run.batches_completed == 2 + assert run.published_count == 150 + + def test_mark_ingestion_finished(self) -> None: + run = _make_run(status=IngestStatus.RUNNING) + assert not run.ingestion_finished + run.mark_ingestion_finished() + assert run.ingestion_finished + + def test_batch_size_default(self) -> None: + run = _make_run() + assert run.batch_size == 1000 + + def test_custom_batch_size(self) -> None: + run = _make_run(batch_size=500) + assert run.batch_size == 500 diff --git 
a/server/tests/unit/domain/ingest/test_ingest_service.py b/server/tests/unit/domain/ingest/test_ingest_service.py new file mode 100644 index 0000000..4742ec6 --- /dev/null +++ b/server/tests/unit/domain/ingest/test_ingest_service.py @@ -0,0 +1,112 @@ +"""T022: Unit tests for IngestService.start_ingest.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from osa.domain.ingest.model.ingest_run import IngestStatus +from osa.domain.ingest.service.ingest import IngestService +from osa.domain.shared.error import ConflictError, NotFoundError +from osa.domain.shared.model.source import IngesterDefinition +from osa.domain.shared.model.srn import Domain + + +def _make_convention(*, has_ingester: bool = True): + conv = MagicMock() + conv.srn = "urn:osa:localhost:conv:test-conv@1.0.0" + conv.ingester = ( + IngesterDefinition( + image="ghcr.io/example/ingester:v1", + digest="sha256:abc123", + ) + if has_ingester + else None + ) + return conv + + +def _make_service( + *, + convention=None, + running_ingest=None, + convention_not_found: bool = False, +) -> IngestService: + ingest_repo = AsyncMock() + ingest_repo.get_running_for_convention.return_value = running_ingest + ingest_repo.save = AsyncMock() + + convention_service = AsyncMock() + if convention_not_found: + convention_service.get_convention.side_effect = NotFoundError("Convention not found") + else: + convention_service.get_convention.return_value = convention or _make_convention() + + outbox = AsyncMock() + + return IngestService( + ingest_repo=ingest_repo, + convention_service=convention_service, + outbox=outbox, + node_domain=Domain("localhost"), + ) + + +class TestStartIngest: + @pytest.mark.asyncio + async def test_creates_pending_ingest(self) -> None: + service = _make_service() + run = await service.start_ingest( + convention_srn="urn:osa:localhost:conv:test-conv@1.0.0", + ) + assert run.status == IngestStatus.PENDING + assert run.convention_srn == "urn:osa:localhost:conv:test-conv@1.0.0" + 
assert run.batch_size == 1000 + + @pytest.mark.asyncio + async def test_saves_and_emits_event(self) -> None: + service = _make_service() + run = await service.start_ingest( + convention_srn="urn:osa:localhost:conv:test-conv@1.0.0", + ) + service.ingest_repo.save.assert_called_once() + service.outbox.append.assert_called_once() + + # Verify the event is IngestStarted + event = service.outbox.append.call_args[0][0] + assert event.__class__.__name__ == "IngestStarted" + assert event.ingest_run_srn == run.srn + assert event.convention_srn == run.convention_srn + + @pytest.mark.asyncio + async def test_custom_batch_size(self) -> None: + service = _make_service() + run = await service.start_ingest( + convention_srn="urn:osa:localhost:conv:test-conv@1.0.0", + batch_size=500, + ) + assert run.batch_size == 500 + + @pytest.mark.asyncio + async def test_rejects_convention_not_found(self) -> None: + service = _make_service(convention_not_found=True) + with pytest.raises(NotFoundError): + await service.start_ingest( + convention_srn="urn:osa:localhost:conv:nonexistent@1.0.0", + ) + + @pytest.mark.asyncio + async def test_rejects_no_ingester_configured(self) -> None: + service = _make_service(convention=_make_convention(has_ingester=False)) + with pytest.raises(NotFoundError, match="No ingester configured"): + await service.start_ingest( + convention_srn="urn:osa:localhost:conv:test-conv@1.0.0", + ) + + @pytest.mark.asyncio + async def test_rejects_ingest_already_running(self) -> None: + existing = MagicMock() + service = _make_service(running_ingest=existing) + with pytest.raises(ConflictError, match="already running"): + await service.start_ingest( + convention_srn="urn:osa:localhost:conv:test-conv@1.0.0", + ) diff --git a/server/tests/unit/domain/ingest/test_ingester_record.py b/server/tests/unit/domain/ingest/test_ingester_record.py new file mode 100644 index 0000000..59ca361 --- /dev/null +++ b/server/tests/unit/domain/ingest/test_ingester_record.py @@ -0,0 +1,82 @@ 
+"""Tests for IngesterRecord model — from_jsonl parsing and IngesterFileRef.""" + +import json +from pathlib import Path + + +def test_from_jsonl_happy_path(tmp_path: Path): + from osa.domain.ingest.model.ingester_record import IngesterRecord + + records_file = tmp_path / "records.jsonl" + records_file.write_text( + json.dumps( + { + "source_id": "rec1", + "metadata": {"title": "Test"}, + "files": [{"name": "f.cif", "relative_path": "rec1/f.cif", "size_mb": 10.5}], + } + ) + + "\n" + + json.dumps({"source_id": "rec2", "metadata": {"title": "Test 2"}}) + + "\n" + ) + + records = IngesterRecord.from_jsonl(records_file) + assert len(records) == 2 + assert records[0].source_id == "rec1" + assert records[0].metadata == {"title": "Test"} + assert len(records[0].files) == 1 + assert records[0].files[0].name == "f.cif" + assert records[0].files[0].size_mb == 10.5 + assert records[1].source_id == "rec2" + assert records[1].files == [] + + +def test_from_jsonl_malformed_lines_skipped(tmp_path: Path): + from osa.domain.ingest.model.ingester_record import IngesterRecord + + records_file = tmp_path / "records.jsonl" + records_file.write_text( + "NOT VALID JSON\n" + json.dumps({"source_id": "good", "metadata": {}}) + "\n" + "{broken\n" + ) + + records = IngesterRecord.from_jsonl(records_file) + assert len(records) == 1 + assert records[0].source_id == "good" + + +def test_from_jsonl_empty_file(tmp_path: Path): + from osa.domain.ingest.model.ingester_record import IngesterRecord + + records_file = tmp_path / "records.jsonl" + records_file.write_text("") + records = IngesterRecord.from_jsonl(records_file) + assert records == [] + + +def test_from_jsonl_nonexistent_file(tmp_path: Path): + from osa.domain.ingest.model.ingester_record import IngesterRecord + + records = IngesterRecord.from_jsonl(tmp_path / "does_not_exist.jsonl") + assert records == [] + + +def test_total_file_mb_property(): + from osa.domain.ingest.model.ingester_record import IngesterFileRef, IngesterRecord + + 
record = IngesterRecord( + source_id="rec1", + metadata={}, + files=[ + IngesterFileRef(name="a.cif", relative_path="rec1/a.cif", size_mb=10.0), + IngesterFileRef(name="b.cif", relative_path="rec1/b.cif", size_mb=28.5), + ], + ) + assert record.total_file_mb == 38.5 + + +def test_total_file_mb_empty(): + from osa.domain.ingest.model.ingester_record import IngesterRecord + + record = IngesterRecord(source_id="rec1", metadata={}) + assert record.total_file_mb == 0 diff --git a/server/tests/unit/domain/record/test_record_service.py b/server/tests/unit/domain/record/test_record_service.py index 2400d82..9dff7c1 100644 --- a/server/tests/unit/domain/record/test_record_service.py +++ b/server/tests/unit/domain/record/test_record_service.py @@ -11,7 +11,7 @@ from osa.domain.record.service.record import RecordService from osa.domain.shared.model.source import ( DepositionSource, - HarvestSource, + IngestSource, ) from osa.domain.shared.model.srn import ConventionSRN, DepositionSRN, Domain, LocalId from osa.domain.shared.outbox import Outbox @@ -124,24 +124,24 @@ async def test_publish_record_creates_version_1( assert record.srn.version.root == 1 -class TestRecordServiceHarvestSource: - """US2: Verify harvest-sourced records publish correctly.""" +class TestRecordServiceIngestSource: + """US2: Verify ingest-sourced records publish correctly.""" @pytest.mark.asyncio - async def test_publish_with_harvest_source( + async def test_publish_with_ingest_source( self, mock_record_repo: RecordRepository, mock_outbox: Outbox, node_domain: Domain, ): - """HarvestSource draft produces correct Record + RecordPublished event.""" + """IngestSource draft produces correct Record + RecordPublished event.""" draft = RecordDraft( - source=HarvestSource( + source=IngestSource( id="run-123-pdb-456", - harvest_run_srn="urn:osa:localhost:val:run123", + ingest_run_srn="urn:osa:localhost:val:run123", upstream_source="pdb", ), - metadata={"title": "Harvested Protein"}, + metadata={"title": "Ingested 
Protein"}, convention_srn=_make_conv_srn(), expected_features=["pocket_detect"], ) @@ -155,12 +155,12 @@ async def test_publish_with_harvest_source( record = await service.publish_record(draft) - assert record.source.type == "harvest" + assert record.source.type == "ingest" assert record.source.upstream_source == "pdb" assert record.convention_srn == _make_conv_srn() mock_record_repo.save.assert_called_once() event = mock_outbox.append.call_args[0][0] assert isinstance(event, RecordPublished) - assert event.source.type == "harvest" + assert event.source.type == "ingest" assert event.expected_features == ["pocket_detect"] diff --git a/server/tests/unit/domain/shared/test_hook_models.py b/server/tests/unit/domain/shared/test_hook_models.py index 249d06f..cf669b0 100644 --- a/server/tests/unit/domain/shared/test_hook_models.py +++ b/server/tests/unit/domain/shared/test_hook_models.py @@ -92,7 +92,7 @@ def test_oci_limits_defaults(): limits = OciLimits() assert limits.timeout_seconds == 300 - assert limits.memory == "512m" + assert limits.memory == "1g" assert limits.cpu == "0.5" @@ -176,7 +176,7 @@ def test_hook_definition_default_limits(): feature=TableFeatureSpec(cardinality="one", columns=[]), ) assert hook_def.runtime.limits.timeout_seconds == 300 - assert hook_def.runtime.limits.memory == "512m" + assert hook_def.runtime.limits.memory == "1g" def test_hook_definition_serialization_roundtrip(): @@ -211,6 +211,60 @@ def test_hook_definition_serialization_roundtrip(): assert restored.feature.columns[1].required is False +class TestMemoryDoubling: + """Tests for HookDefinition.with_memory() and with_doubled_memory().""" + + def _make_hook(self, memory: str = "1g"): + from osa.domain.shared.model.hook import ( + HookDefinition, + OciConfig, + OciLimits, + TableFeatureSpec, + ) + + return HookDefinition( + name="detect_pockets", + runtime=OciConfig( + image="img:v1", + digest="sha256:abc", + limits=OciLimits(memory=memory), + ), + 
feature=TableFeatureSpec(cardinality="one", columns=[]), + ) + + def test_hook_definition_with_memory(self): + hook = self._make_hook("1g") + updated = hook.with_memory("2g") + assert updated.runtime.limits.memory == "2g" + # original unchanged (frozen) + assert hook.runtime.limits.memory == "1g" + + def test_hook_definition_with_doubled_memory_1g(self): + hook = self._make_hook("1g") + doubled = hook.with_doubled_memory() + assert doubled.runtime.limits.memory == "2g" + + def test_hook_definition_with_doubled_memory_512m(self): + hook = self._make_hook("512m") + doubled = hook.with_doubled_memory() + assert doubled.runtime.limits.memory == "1g" + + def test_hook_definition_with_doubled_memory_768m(self): + hook = self._make_hook("768m") + doubled = hook.with_doubled_memory() + assert doubled.runtime.limits.memory == "1536m" + + def test_hook_definition_with_doubled_memory_preserves_other_fields(self): + hook = self._make_hook("1g") + doubled = hook.with_doubled_memory() + assert doubled.name == hook.name + assert doubled.runtime.image == hook.runtime.image + assert doubled.runtime.digest == hook.runtime.digest + assert doubled.runtime.limits.timeout_seconds == hook.runtime.limits.timeout_seconds + assert doubled.runtime.limits.cpu == hook.runtime.limits.cpu + assert doubled.feature == hook.feature + + class TestNameValidation: """Hook and column names must be safe PG identifiers.""" diff --git a/server/tests/unit/domain/shared/test_record_source.py b/server/tests/unit/domain/shared/test_record_source.py index ff71708..c1fe20d 100644 --- a/server/tests/unit/domain/shared/test_record_source.py +++ b/server/tests/unit/domain/shared/test_record_source.py @@ -5,7 +5,7 @@ from osa.domain.shared.model.source import ( DepositionSource, - HarvestSource, + IngestSource, RecordSource, ) @@ -27,31 +27,31 @@ def test_serialization_roundtrip(self): assert restored == src -class TestHarvestSource: - def test_type_is_harvest(self): - src = HarvestSource( +class TestIngestSource: 
+ def test_type_is_ingest(self): + src = IngestSource( id="run-123-source-456", - harvest_run_srn="urn:osa:localhost:val:run123", + ingest_run_srn="urn:osa:localhost:val:run123", upstream_source="pdb", ) - assert src.type == "harvest" + assert src.type == "ingest" - def test_requires_harvest_run_srn(self): + def test_requires_ingest_run_srn(self): with pytest.raises(ValidationError): - HarvestSource(id="run-123", upstream_source="pdb") + IngestSource(id="run-123", upstream_source="pdb") def test_requires_upstream_source(self): with pytest.raises(ValidationError): - HarvestSource(id="run-123", harvest_run_srn="urn:osa:localhost:val:run123") + IngestSource(id="run-123", ingest_run_srn="urn:osa:localhost:val:run123") def test_serialization_roundtrip(self): - src = HarvestSource( + src = IngestSource( id="run-123-source-456", - harvest_run_srn="urn:osa:localhost:val:run123", + ingest_run_srn="urn:osa:localhost:val:run123", upstream_source="pdb", ) data = src.model_dump() - restored = HarvestSource.model_validate(data) + restored = IngestSource.model_validate(data) assert restored == src @@ -61,16 +61,16 @@ def test_deserializes_deposition(self): src = adapter.validate_python({"type": "deposition", "id": "dep-abc"}) assert isinstance(src, DepositionSource) - def test_deserializes_harvest(self): + def test_deserializes_ingest(self): data = { - "type": "harvest", + "type": "ingest", "id": "run-123", - "harvest_run_srn": "urn:osa:localhost:val:run1", + "ingest_run_srn": "urn:osa:localhost:val:run1", "upstream_source": "geo", } adapter = TypeAdapter(RecordSource) src = adapter.validate_python(data) - assert isinstance(src, HarvestSource) + assert isinstance(src, IngestSource) assert src.upstream_source == "geo" def test_rejects_unknown_type(self): @@ -81,12 +81,12 @@ def test_rejects_unknown_type(self): def test_json_roundtrip(self): """Serialize to JSON and back via the union type.""" adapter = TypeAdapter(RecordSource) - src = HarvestSource( + src = IngestSource( 
id="run-1", - harvest_run_srn="urn:osa:localhost:val:run1", + ingest_run_srn="urn:osa:localhost:val:run1", upstream_source="pdb", ) json_str = adapter.dump_json(src) restored = adapter.validate_json(json_str) - assert isinstance(restored, HarvestSource) + assert isinstance(restored, IngestSource) assert restored == src diff --git a/server/tests/unit/domain/source/test_source_record_ready.py b/server/tests/unit/domain/source/test_source_record_ready.py deleted file mode 100644 index 296d820..0000000 --- a/server/tests/unit/domain/source/test_source_record_ready.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Unit tests for SourceRecordReady event. - -Tests for User Story 3: Cross-domain decoupling — source domain. -""" - -from uuid import uuid4 - -from osa.domain.shared.event import EventId -from osa.domain.shared.model.srn import ConventionSRN -from osa.domain.source.event.source_record_ready import SourceRecordReady - - -def _make_conv_srn() -> ConventionSRN: - return ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0") - - -class TestSourceRecordReady: - def test_creation_with_all_fields(self): - """SourceRecordReady carries all required fields.""" - event = SourceRecordReady( - id=EventId(uuid4()), - convention_srn=_make_conv_srn(), - metadata={"pdb_id": "4HHB", "title": "Hemoglobin"}, - file_paths=["4HHB/structure.cif"], - source_id="4HHB", - staging_dir="/tmp/staging/run-123", - ) - - assert event.convention_srn == _make_conv_srn() - assert event.metadata == {"pdb_id": "4HHB", "title": "Hemoglobin"} - assert event.file_paths == ["4HHB/structure.cif"] - assert event.source_id == "4HHB" - assert event.staging_dir == "/tmp/staging/run-123" - - def test_serialization_roundtrip(self): - """SourceRecordReady serializes and deserializes correctly.""" - event = SourceRecordReady( - id=EventId(uuid4()), - convention_srn=_make_conv_srn(), - metadata={"pdb_id": "1CRN"}, - file_paths=["1CRN/data.cif", "1CRN/meta.json"], - source_id="1CRN", - staging_dir="/tmp/staging/run-456", 
- ) - - data = event.model_dump() - restored = SourceRecordReady.model_validate(data) - - assert restored.convention_srn == event.convention_srn - assert restored.metadata == event.metadata - assert restored.file_paths == event.file_paths - assert restored.source_id == event.source_id - assert restored.staging_dir == event.staging_dir - - def test_registered_in_event_registry(self): - """SourceRecordReady should be auto-registered in Event._registry.""" - from osa.domain.shared.event import Event - - assert "SourceRecordReady" in Event._registry diff --git a/server/tests/unit/domain/source/test_source_service.py b/server/tests/unit/domain/source/test_source_service.py deleted file mode 100644 index 46301dc..0000000 --- a/server/tests/unit/domain/source/test_source_service.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Unit tests for SourceService with OCI container model. - -Updated for cross-domain decoupling: SourceService no longer depends on -DepositionService, ConventionRepository, or FileStoragePort. Instead it -uses SourceStoragePort and emits SourceRecordReady events per record. 
-""" - -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from osa.domain.shared.model.source import SourceDefinition -from osa.domain.shared.model.srn import ConventionSRN -from osa.domain.source.event.source_record_ready import SourceRecordReady -from osa.domain.source.event.source_requested import SourceRequested -from osa.domain.source.event.source_run_completed import SourceRunCompleted -from osa.domain.source.port.source_runner import SourceOutput -from osa.domain.source.service.source import SourceService - - -def _make_conv_srn() -> ConventionSRN: - return ConventionSRN.parse("urn:osa:localhost:conv:test-conv-12345678@1.0.0") - - -def _make_source_def() -> SourceDefinition: - return SourceDefinition( - image="osa-sources/test:latest", - digest="sha256:abc123", - config={"batch_size": 100}, - ) - - -@pytest.fixture -def mock_outbox() -> AsyncMock: - return AsyncMock() - - -@pytest.fixture -def mock_source_storage() -> MagicMock: - storage = MagicMock() - storage.get_source_staging_dir.return_value = Path("/tmp/staging") - storage.get_source_output_dir.return_value = Path("/tmp/output") - return storage - - -@pytest.fixture -def mock_source_runner() -> AsyncMock: - runner = AsyncMock() - runner.run.return_value = SourceOutput( - records=[ - {"source_id": "4HHB", "metadata": {"pdb_id": "4HHB", "title": "Hemoglobin"}}, - {"source_id": "1CRN", "metadata": {"pdb_id": "1CRN", "title": "Crambin"}}, - ], - session=None, - files_dir=Path("/tmp/staging"), - ) - return runner - - -class TestSourceService: - @pytest.mark.asyncio - async def test_run_source_emits_per_record_events( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - result = await service.run_source( - convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - assert result.record_count == 2 - - # 2 
SourceRecordReady + 1 SourceRunCompleted = 3 events - assert mock_outbox.append.call_count == 3 - first = mock_outbox.append.call_args_list[0][0][0] - second = mock_outbox.append.call_args_list[1][0][0] - assert isinstance(first, SourceRecordReady) - assert isinstance(second, SourceRecordReady) - assert first.source_id == "4HHB" - assert second.source_id == "1CRN" - - @pytest.mark.asyncio - async def test_run_source_carries_staging_dir( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - await service.run_source( - convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - event = mock_outbox.append.call_args_list[0][0][0] - assert isinstance(event, SourceRecordReady) - assert event.staging_dir == str(Path("/tmp/staging")) - - @pytest.mark.asyncio - async def test_run_source_emits_completion_event( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - await service.run_source( - convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - last = mock_outbox.append.call_args_list[-1][0][0] - assert isinstance(last, SourceRunCompleted) - assert last.record_count == 2 - assert last.convention_srn == _make_conv_srn() - assert last.is_final_chunk is True - - @pytest.mark.asyncio - async def test_run_source_emits_continuation_when_session( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - mock_source_runner.run.return_value = SourceOutput( - records=[{"source_id": "4HHB", "metadata": {"pdb_id": "4HHB"}}], - session={"cursor": "abc"}, - files_dir=Path("/tmp/staging"), - ) - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - await service.run_source( - convention_srn=_make_conv_srn(), - 
source=_make_source_def(), - ) - # 1 SourceRecordReady + 1 SourceRequested continuation + 1 SourceRunCompleted = 3 events - assert mock_outbox.append.call_count == 3 - continuation = mock_outbox.append.call_args_list[1][0][0] - assert isinstance(continuation, SourceRequested) - assert continuation.session == {"cursor": "abc"} - - @pytest.mark.asyncio - async def test_run_source_final_when_session_but_zero_records( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - """Source returns session but zero records -> treated as final chunk.""" - mock_source_runner.run.return_value = SourceOutput( - records=[], - session={"cursor": "x"}, - files_dir=Path("/tmp/staging"), - ) - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - await service.run_source( - convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - # Only SourceRunCompleted, no continuation - assert mock_outbox.append.call_count == 1 - event = mock_outbox.append.call_args_list[0][0][0] - assert isinstance(event, SourceRunCompleted) - assert event.is_final_chunk is True diff --git a/server/tests/unit/domain/source/test_source_service_decoupled.py b/server/tests/unit/domain/source/test_source_service_decoupled.py deleted file mode 100644 index 8e55af8..0000000 --- a/server/tests/unit/domain/source/test_source_service_decoupled.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Unit tests for decoupled SourceService. - -Tests for User Story 3: Cross-domain decoupling. -Verifies SourceService emits SourceRecordReady per record -instead of calling DepositionService directly. 
-""" - -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from osa.domain.shared.model.source import SourceDefinition -from osa.domain.shared.model.srn import ConventionSRN -from osa.domain.source.event.source_record_ready import SourceRecordReady -from osa.domain.source.event.source_run_completed import SourceRunCompleted -from osa.domain.source.port.source_runner import SourceOutput -from osa.domain.source.service.source import SourceService - - -def _make_conv_srn() -> ConventionSRN: - return ConventionSRN.parse("urn:osa:localhost:conv:test-conv-12345678@1.0.0") - - -def _make_source_def() -> SourceDefinition: - return SourceDefinition( - image="osa-sources/test:latest", - digest="sha256:abc123", - config={"batch_size": 100}, - ) - - -@pytest.fixture -def mock_outbox() -> AsyncMock: - return AsyncMock() - - -@pytest.fixture -def mock_source_storage() -> MagicMock: - storage = MagicMock() - storage.get_source_staging_dir.return_value = Path("/tmp/staging") - storage.get_source_output_dir.return_value = Path("/tmp/output") - return storage - - -@pytest.fixture -def mock_source_runner() -> AsyncMock: - runner = AsyncMock() - runner.run.return_value = SourceOutput( - records=[ - {"source_id": "4HHB", "metadata": {"pdb_id": "4HHB", "title": "Hemoglobin"}}, - {"source_id": "1CRN", "metadata": {"pdb_id": "1CRN", "title": "Crambin"}}, - ], - session=None, - files_dir=Path("/tmp/staging"), - ) - return runner - - -class TestDecoupledSourceService: - @pytest.mark.asyncio - async def test_emits_source_record_ready_per_record( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - """SourceService emits SourceRecordReady for each record.""" - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - await service.run_source( - convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - - # 2 SourceRecordReady + 1 SourceRunCompleted = 3 events 
- assert mock_outbox.append.call_count == 3 - first = mock_outbox.append.call_args_list[0][0][0] - second = mock_outbox.append.call_args_list[1][0][0] - assert isinstance(first, SourceRecordReady) - assert isinstance(second, SourceRecordReady) - assert first.source_id == "4HHB" - assert second.source_id == "1CRN" - - @pytest.mark.asyncio - async def test_source_record_ready_carries_staging_dir( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - """SourceRecordReady carries the staging_dir path.""" - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - await service.run_source( - convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - - event = mock_outbox.append.call_args_list[0][0][0] - assert isinstance(event, SourceRecordReady) - assert event.staging_dir == str(Path("/tmp/staging")) - - @pytest.mark.asyncio - async def test_emits_completion_event( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - """Still emits SourceRunCompleted after all records.""" - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - await service.run_source( - convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - - last = mock_outbox.append.call_args_list[-1][0][0] - assert isinstance(last, SourceRunCompleted) - assert last.record_count == 2 - - @pytest.mark.asyncio - async def test_no_deposition_service_dependency(self): - """SourceService no longer depends on DepositionService.""" - import inspect - - sig = inspect.signature(SourceService.__init__) - param_names = list(sig.parameters.keys()) - assert "deposition_service" not in param_names - assert "convention_repo" not in param_names diff --git a/server/tests/unit/domain/source/test_trigger_initial_source_run.py b/server/tests/unit/domain/source/test_trigger_initial_source_run.py deleted file mode 100644 index 5791e62..0000000 --- 
a/server/tests/unit/domain/source/test_trigger_initial_source_run.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Unit tests for TriggerInitialSourceRun event handler. - -Tests for User Story 2: Convention Initialization Chain. -Verifies handler consumes ConventionReady (not ConventionRegistered). -""" - -from datetime import UTC, datetime -from unittest.mock import AsyncMock -from uuid import uuid4 - -import pytest - -from osa.domain.deposition.model.convention import Convention -from osa.domain.deposition.model.value import FileRequirements -from osa.domain.feature.event.convention_ready import ConventionReady -from osa.domain.shared.event import EventId -from osa.domain.shared.model.source import InitialRunConfig, SourceDefinition -from osa.domain.shared.model.srn import ConventionSRN, SchemaSRN -from osa.domain.source.event.source_requested import SourceRequested -from osa.domain.source.handler.trigger_initial_source_run import TriggerInitialSourceRun - - -def _make_conv_srn() -> ConventionSRN: - return ConventionSRN.parse("urn:osa:localhost:conv:test-deploy-conv@1.0.0") - - -def _make_convention(source: SourceDefinition | None = None) -> Convention: - return Convention( - srn=_make_conv_srn(), - title="Test Convention", - schema_srn=SchemaSRN.parse("urn:osa:localhost:schema:test-schema12345678@1.0.0"), - file_requirements=FileRequirements( - accepted_types=[".csv"], - min_count=1, - max_count=3, - max_file_size=1_000_000, - ), - source=source, - created_at=datetime.now(UTC), - ) - - -def _make_event() -> ConventionReady: - return ConventionReady( - id=EventId(uuid4()), - convention_srn=_make_conv_srn(), - ) - - -class TestTriggerInitialSourceRun: - @pytest.mark.asyncio - async def test_emits_source_requested_when_initial_run_configured(self): - """Emits SourceRequested when convention has initial_run configured.""" - source = SourceDefinition( - image="osa-sources/test:latest", - digest="sha256:abc123", - initial_run=InitialRunConfig(limit=500), - ) - convention = 
_make_convention(source=source) - - convention_service = AsyncMock() - convention_service.get_convention.return_value = convention - outbox = AsyncMock() - - handler = TriggerInitialSourceRun( - convention_service=convention_service, - outbox=outbox, - ) - await handler.handle(_make_event()) - - outbox.append.assert_called_once() - emitted = outbox.append.call_args[0][0] - assert isinstance(emitted, SourceRequested) - assert emitted.convention_srn == convention.srn - assert emitted.limit == 500 - - @pytest.mark.asyncio - async def test_no_event_when_source_has_no_initial_run(self): - """No SourceRequested when initial_run is None.""" - source = SourceDefinition( - image="osa-sources/test:latest", - digest="sha256:abc123", - initial_run=None, - ) - convention = _make_convention(source=source) - - convention_service = AsyncMock() - convention_service.get_convention.return_value = convention - outbox = AsyncMock() - - handler = TriggerInitialSourceRun( - convention_service=convention_service, - outbox=outbox, - ) - await handler.handle(_make_event()) - - outbox.append.assert_not_called() - - @pytest.mark.asyncio - async def test_no_event_when_convention_has_no_source(self): - """No SourceRequested when convention has no source.""" - convention = _make_convention(source=None) - - convention_service = AsyncMock() - convention_service.get_convention.return_value = convention - outbox = AsyncMock() - - handler = TriggerInitialSourceRun( - convention_service=convention_service, - outbox=outbox, - ) - await handler.handle(_make_event()) - - outbox.append.assert_not_called() - - def test_handler_event_type_is_convention_ready(self): - """TriggerInitialSourceRun.__event_type__ should be ConventionReady.""" - assert TriggerInitialSourceRun.__event_type__ is ConventionReady diff --git a/server/tests/unit/domain/source/test_trigger_source_on_deploy.py b/server/tests/unit/domain/source/test_trigger_source_on_deploy.py deleted file mode 100644 index 63b16b3..0000000 --- 
a/server/tests/unit/domain/source/test_trigger_source_on_deploy.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Tests for TriggerSourceOnDeploy event handler.""" - -from datetime import UTC, datetime -from unittest.mock import AsyncMock -from uuid import uuid4 - -import pytest - -from osa.domain.deposition.event.convention_registered import ConventionRegistered -from osa.domain.deposition.model.convention import Convention -from osa.domain.deposition.model.value import FileRequirements -from osa.domain.shared.event import EventId -from osa.domain.shared.model.source import InitialRunConfig, SourceDefinition -from osa.domain.shared.model.srn import ConventionSRN, SchemaSRN -from osa.domain.source.event.source_requested import SourceRequested -from osa.domain.source.handler.trigger_source_on_deploy import TriggerSourceOnDeploy - - -def _make_conv_srn() -> ConventionSRN: - return ConventionSRN.parse("urn:osa:localhost:conv:test-deploy-conv@1.0.0") - - -def _make_schema_srn() -> SchemaSRN: - return SchemaSRN.parse("urn:osa:localhost:schema:test-schema12345678@1.0.0") - - -def _make_file_reqs() -> FileRequirements: - return FileRequirements( - accepted_types=[".csv"], - min_count=1, - max_count=3, - max_file_size=1_000_000, - ) - - -def _make_convention( - source: SourceDefinition | None = None, -) -> Convention: - return Convention( - srn=_make_conv_srn(), - title="Test Convention", - schema_srn=_make_schema_srn(), - file_requirements=_make_file_reqs(), - source=source, - created_at=datetime.now(UTC), - ) - - -def _make_event() -> ConventionRegistered: - return ConventionRegistered( - id=EventId(uuid4()), - convention_srn=_make_conv_srn(), - ) - - -class TestTriggerSourceOnDeploy: - @pytest.mark.asyncio - async def test_emits_source_requested_when_initial_run_configured(self): - source = SourceDefinition( - image="osa-sources/test:latest", - digest="sha256:abc123", - initial_run=InitialRunConfig(limit=500), - ) - convention = _make_convention(source=source) - - 
convention_service = AsyncMock() - convention_service.get_convention.return_value = convention - outbox = AsyncMock() - - handler = TriggerSourceOnDeploy( - convention_service=convention_service, - outbox=outbox, - ) - await handler.handle(_make_event()) - - outbox.append.assert_called_once() - emitted = outbox.append.call_args[0][0] - assert isinstance(emitted, SourceRequested) - assert emitted.convention_srn == convention.srn - assert emitted.limit == 500 - - @pytest.mark.asyncio - async def test_no_event_when_source_has_no_initial_run(self): - source = SourceDefinition( - image="osa-sources/test:latest", - digest="sha256:abc123", - initial_run=None, - ) - convention = _make_convention(source=source) - - convention_service = AsyncMock() - convention_service.get_convention.return_value = convention - outbox = AsyncMock() - - handler = TriggerSourceOnDeploy( - convention_service=convention_service, - outbox=outbox, - ) - await handler.handle(_make_event()) - - outbox.append.assert_not_called() - - @pytest.mark.asyncio - async def test_no_event_when_convention_has_no_source(self): - convention = _make_convention(source=None) - - convention_service = AsyncMock() - convention_service.get_convention.return_value = convention - outbox = AsyncMock() - - handler = TriggerSourceOnDeploy( - convention_service=convention_service, - outbox=outbox, - ) - await handler.handle(_make_event()) - - outbox.append.assert_not_called() diff --git a/server/tests/unit/domain/validation/test_hook_result.py b/server/tests/unit/domain/validation/test_hook_result.py index 8d2ba3d..8465ff5 100644 --- a/server/tests/unit/domain/validation/test_hook_result.py +++ b/server/tests/unit/domain/validation/test_hook_result.py @@ -105,6 +105,44 @@ def test_hook_result_default_progress_empty(): assert result.progress == [] +def test_hook_status_oom_value(): + from osa.domain.validation.model.hook_result import HookStatus + + assert HookStatus.OOM == "oom" + assert HookStatus.OOM.value == "oom" + + +def 
test_hook_result_oom_killed_true(): + from osa.domain.validation.model.hook_result import HookResult, HookStatus + + result = HookResult( + hook_name="detect_pockets", + status=HookStatus.OOM, + error_message="Hook killed by OOM (limit: 1g)", + duration_seconds=30.0, + ) + assert result.oom_killed is True + + +def test_hook_result_oom_killed_false(): + from osa.domain.validation.model.hook_result import HookResult, HookStatus + + result = HookResult( + hook_name="detect_pockets", + status=HookStatus.FAILED, + error_message="Some other error", + duration_seconds=10.0, + ) + assert result.oom_killed is False + + passed = HookResult( + hook_name="detect_pockets", + status=HookStatus.PASSED, + duration_seconds=5.0, + ) + assert passed.oom_killed is False + + def test_hook_result_serialization_roundtrip(): from osa.domain.validation.model.hook_result import ( HookResult, diff --git a/server/tests/unit/domain/validation/test_hook_runner.py b/server/tests/unit/domain/validation/test_hook_runner.py index fa1e1fa..5b7052c 100644 --- a/server/tests/unit/domain/validation/test_hook_runner.py +++ b/server/tests/unit/domain/validation/test_hook_runner.py @@ -6,55 +6,57 @@ from osa.domain.shared.model.hook import HookDefinition from osa.domain.validation.model.hook_result import HookResult, HookStatus +from osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.port.hook_runner import HookInputs, HookRunner class TestHookInputs: def test_minimal_construction(self): inputs = HookInputs( - record_json={"srn": "urn:osa:localhost:rec:123"}, + records=[HookRecord(id="rec1", metadata={})], run_id="localhost_test123", ) - assert inputs.record_json == {"srn": "urn:osa:localhost:rec:123"} - assert inputs.files_dir is None + assert inputs.records == [HookRecord(id="rec1", metadata={})] + assert inputs.files_dirs == {} assert inputs.config is None - def test_with_files_dir(self): - files = Path("/tmp/files") + def test_with_files_dirs(self): inputs = 
HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="rec1", metadata={})], run_id="localhost_test123", - files_dir=files, + files_dirs={"rec1": Path("/tmp/files")}, ) - assert inputs.files_dir == files + assert inputs.files_dirs == {"rec1": Path("/tmp/files")} def test_with_config(self): inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="rec1", metadata={})], run_id="localhost_test123", config={"r_min": 3.0, "threshold": 0.5}, ) assert inputs.config == {"r_min": 3.0, "threshold": 0.5} def test_full_construction(self): - files = Path("/tmp/data/files") inputs = HookInputs( - record_json={"srn": "urn:osa:localhost:rec:456", "name": "test"}, + records=[ + HookRecord(id="rec1", metadata={"name": "test"}), + HookRecord(id="rec2", metadata={"name": "test2"}), + ], run_id="localhost_test456", - files_dir=files, + files_dirs={"rec1": Path("/tmp/data/files/rec1")}, config={"key": "value"}, ) - assert inputs.record_json["name"] == "test" - assert inputs.files_dir == files + assert len(inputs.records) == 2 + assert inputs.records[0].metadata["name"] == "test" assert inputs.config == {"key": "value"} def test_is_frozen(self): inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="rec1", metadata={})], run_id="localhost_test123", ) with pytest.raises(AttributeError): - inputs.record_json = {} # type: ignore[misc] + inputs.records = [] # type: ignore[misc] def test_is_dataclass(self): """HookInputs is a frozen dataclass.""" diff --git a/server/tests/unit/domain/validation/test_hook_service.py b/server/tests/unit/domain/validation/test_hook_service.py new file mode 100644 index 0000000..db45e46 --- /dev/null +++ b/server/tests/unit/domain/validation/test_hook_service.py @@ -0,0 +1,567 @@ +"""Tests for HookService — OOM retry with checkpointing.""" + +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from osa.domain.shared.model.hook import ( + ColumnDef, + 
HookDefinition, + OciConfig, + OciLimits, + TableFeatureSpec, +) +from osa.domain.validation.model.batch_outcome import ( + BatchRecordOutcome, + HookRecordId, + OutcomeStatus, +) +from osa.domain.validation.model.hook_input import HookRecord +from osa.domain.validation.model.hook_result import HookResult, HookStatus +from osa.domain.validation.port.hook_runner import HookInputs + + +def _make_hook(name: str = "detect_pockets", memory: str = "1g") -> HookDefinition: + return HookDefinition( + name=name, + runtime=OciConfig( + image="img:v1", + digest="sha256:abc", + limits=OciLimits(memory=memory), + ), + feature=TableFeatureSpec( + cardinality="one", + columns=[ColumnDef(name="score", json_type="number", required=True)], + ), + ) + + +def _inputs(records: list[HookRecord]) -> HookInputs: + return HookInputs(records=records, run_id="test-run") + + +def _make_records(count: int = 3) -> list[HookRecord]: + return [HookRecord(id=f"rec{i}", metadata={"title": f"Record {i}"}) for i in range(count)] + + +def _passed_result(hook_name: str = "detect_pockets", duration: float = 5.0) -> HookResult: + return HookResult(hook_name=hook_name, status=HookStatus.PASSED, duration_seconds=duration) + + +def _oom_result(hook_name: str = "detect_pockets", duration: float = 30.0) -> HookResult: + return HookResult( + hook_name=hook_name, + status=HookStatus.OOM, + error_message="Hook killed by OOM", + duration_seconds=duration, + ) + + +def _failed_result(hook_name: str = "detect_pockets", duration: float = 10.0) -> HookResult: + return HookResult( + hook_name=hook_name, + status=HookStatus.FAILED, + error_message="Some error", + duration_seconds=duration, + ) + + +class FakeHookStorage: + """Fake HookStoragePort for testing — stores checkpoints and outcomes in memory.""" + + def __init__(self) -> None: + self.checkpoints: dict[str, dict[HookRecordId, BatchRecordOutcome]] = {} + self.written_outcomes: dict[str, dict[HookRecordId, BatchRecordOutcome]] = {} + self._batch_outcomes: 
dict[str, dict[HookRecordId, BatchRecordOutcome]] = {} + + def get_hook_output_dir(self, deposition_srn: Any, hook_name: str) -> Path: + return Path(f"/fake/hooks/{hook_name}") + + def get_files_dir(self, deposition_id: Any) -> Path: + return Path("/fake/files") + + def write_checkpoint( + self, work_dir: Path, outcomes: dict[HookRecordId, BatchRecordOutcome] + ) -> None: + self.checkpoints[str(work_dir)] = dict(outcomes) + + def write_batch_outcomes( + self, work_dir: Path, outcomes: dict[HookRecordId, BatchRecordOutcome] + ) -> None: + self.written_outcomes[str(work_dir)] = dict(outcomes) + + async def read_batch_outcomes( + self, output_dir: str, hook_name: str + ) -> dict[HookRecordId, BatchRecordOutcome]: + key = f"{output_dir}/{hook_name}" + return self._batch_outcomes.get(key, {}) + + def read_checkpoint(self, work_dir: Path) -> dict[HookRecordId, BatchRecordOutcome]: + return self.checkpoints.get(str(work_dir), {}) + + +class TestHookServiceNoOOM: + """T015: No OOM — hook runs once, correct output.""" + + @pytest.mark.asyncio + async def test_no_oom_runs_once(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + records = _make_records(2) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + + runner = AsyncMock() + runner.run.return_value = _passed_result() + storage = FakeHookStorage() + + # Simulate runner writing outcomes (features.jsonl) + # After run, HookService reads output dir for outcomes + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + import json + + features_file = output_dir / "features.jsonl" + features_file.write_text( + "\n".join(json.dumps({"id": r.id, "features": [{"score": 0.9}]}) for r in records) + + "\n" + ) + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs(records), work_dir) + + assert result.status == HookStatus.PASSED + runner.run.assert_called_once() + + +class TestHookServiceOOMRetry: + 
"""T016: OOM retry doubles memory.""" + + @pytest.mark.asyncio + async def test_oom_retry_doubles_memory(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook(memory="1g") + records = _make_records(2) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + + import json + + call_count = 0 + + async def mock_run(h, inputs, wd): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call: write partial output then OOM + features_file = output_dir / "features.jsonl" + features_file.write_text( + json.dumps({"id": records[0].id, "features": [{"score": 0.5}]}) + "\n" + ) + return _oom_result() + else: + # Second call: succeed with remaining + features_file = output_dir / "features.jsonl" + # Append the second record + with features_file.open("a") as f: + f.write(json.dumps({"id": records[1].id, "features": [{"score": 0.8}]}) + "\n") + return _passed_result() + + runner = AsyncMock() + runner.run.side_effect = mock_run + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs(records), work_dir) + + assert result.status == HookStatus.PASSED + assert runner.run.call_count == 2 + # Second call should have doubled memory + second_call_hook = runner.run.call_args_list[1][0][0] + assert second_call_hook.runtime.limits.memory == "2g" + + +class TestHookServiceOOMExhaustion: + """T017: OOM exhaustion marks remaining records as errored.""" + + @pytest.mark.asyncio + async def test_oom_exhaustion_marks_errored(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook(memory="1g") + records = _make_records(1) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + + runner = AsyncMock() + runner.run.return_value = _oom_result() + storage = 
FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs(records), work_dir) + + assert result.status == HookStatus.OOM + # Should have retried MAX_OOM_RETRIES times + assert runner.run.call_count == 4 # 1 initial + 3 retries + + # Check outcomes written with error + assert str(work_dir) in storage.written_outcomes + outcomes = storage.written_outcomes[str(work_dir)] + assert HookRecordId("rec0") in outcomes + assert outcomes[HookRecordId("rec0")].status == OutcomeStatus.ERRORED + assert "OOM" in (outcomes[HookRecordId("rec0")].error or "") + + +class TestHookServiceNonOOMFailure: + """T018: Non-OOM failure does NOT trigger retry.""" + + @pytest.mark.asyncio + async def test_non_oom_failure_no_retry(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + records = _make_records(1) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + (work_dir / "output").mkdir(parents=True) + + runner = AsyncMock() + runner.run.return_value = _failed_result() + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs(records), work_dir) + + assert result.status == HookStatus.FAILED + runner.run.assert_called_once() + + +class TestHookServiceFinalize: + """T019: Finalize writes canonical files.""" + + @pytest.mark.asyncio + async def test_finalize_writes_canonical_files(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + records = _make_records(2) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + + import json + + features_file = output_dir / "features.jsonl" + features_file.write_text( + "\n".join(json.dumps({"id": r.id, "features": [{"score": 0.9}]}) for r in records) + + "\n" + ) + + runner = AsyncMock() + runner.run.return_value = 
_passed_result() + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + await service.run_hook(hook, _inputs(records), work_dir) + + assert str(work_dir) in storage.written_outcomes + outcomes = storage.written_outcomes[str(work_dir)] + assert len(outcomes) == 2 + for r in records: + assert HookRecordId(r.id) in outcomes + assert outcomes[HookRecordId(r.id)].status == OutcomeStatus.PASSED + + +class TestHookServiceEmptyRecords: + """T020: Empty records list — no container launched.""" + + @pytest.mark.asyncio + async def test_empty_records_noop(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + + runner = AsyncMock() + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs([]), work_dir) + + assert result.status == HookStatus.PASSED + runner.run.assert_not_called() + + +class TestHookServiceMultiHook: + """T021: Multi-hook — second OOMs, first not re-run.""" + + @pytest.mark.asyncio + async def test_multi_hook_second_ooms(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook1 = _make_hook(name="hook_one") + hook2 = _make_hook(name="hook_two", memory="512m") + records = _make_records(1) + + work_dir1 = tmp_path / "hook_one" + work_dir1.mkdir() + (work_dir1 / "output").mkdir(parents=True) + work_dir2 = tmp_path / "hook_two" + work_dir2.mkdir() + (work_dir2 / "output").mkdir(parents=True) + + import json + + # Hook 1 succeeds + (work_dir1 / "output" / "features.jsonl").write_text( + json.dumps({"id": "rec0", "features": [{"score": 0.9}]}) + "\n" + ) + + runner = AsyncMock() + storage = FakeHookStorage() + + call_index = 0 + + async def side_effect(h, inputs, wd): + nonlocal call_index + call_index += 1 + if h.name == "hook_one": + return _passed_result(hook_name="hook_one") + else: + 
return _oom_result(hook_name="hook_two") + + runner.run.side_effect = side_effect + + service = HookService(hook_runner=runner, hook_storage=storage) + + # Run hook 1 — should pass + r1 = await service.run_hook(hook1, _inputs(records), work_dir1) + assert r1.status == HookStatus.PASSED + + # Run hook 2 — should OOM and exhaust retries + r2 = await service.run_hook(hook2, _inputs(records), work_dir2) + assert r2.status == HookStatus.OOM + + # Hook 1 was called once, hook 2 was called 4 times (1 + 3 retries) + hook1_calls = [c for c in runner.run.call_args_list if c[0][0].name == "hook_one"] + hook2_calls = [c for c in runner.run.call_args_list if c[0][0].name == "hook_two"] + assert len(hook1_calls) == 1 + assert len(hook2_calls) == 4 + + +class TestHookServiceCheckpointRecovery: + """T022: Crash recovery from checkpoint — skips completed records.""" + + @pytest.mark.asyncio + async def test_checkpoint_crash_recovery(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + records = _make_records(3) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + + # Pre-populate checkpoint: rec0 already done + import json + + checkpoint_file = work_dir / "_checkpoint.jsonl" + checkpoint_file.write_text( + json.dumps( + { + "record_id": "rec0", + "status": "passed", + "features": [{"score": 0.5}], + } + ) + + "\n" + ) + + runner = AsyncMock() + storage = FakeHookStorage() + + async def mock_run(h, inputs, wd): + # Should only receive rec1 and rec2, not rec0 + input_ids = [r.id for r in inputs.records] + assert "rec0" not in input_ids + # Write output for remaining + features_file = output_dir / "features.jsonl" + with features_file.open("a") as f: + for r in inputs.records: + f.write(json.dumps({"id": r.id, "features": [{"score": 0.9}]}) + "\n") + return _passed_result() + + runner.run.side_effect = mock_run + + service = HookService(hook_runner=runner, 
hook_storage=storage) + result = await service.run_hook(hook, _inputs(records), work_dir) + + assert result.status == HookStatus.PASSED + runner.run.assert_called_once() + # Final outcomes should contain all 3 records + outcomes = storage.written_outcomes[str(work_dir)] + assert len(outcomes) == 3 + + +class TestHookServiceCheckpointAllComplete: + """T023: All records in checkpoint — hook never called.""" + + @pytest.mark.asyncio + async def test_checkpoint_all_complete_skips_hook(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + records = _make_records(2) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + + import json + + checkpoint_file = work_dir / "_checkpoint.jsonl" + lines = [] + for r in records: + lines.append( + json.dumps({"record_id": r.id, "status": "passed", "features": [{"score": 0.9}]}) + ) + checkpoint_file.write_text("\n".join(lines) + "\n") + + runner = AsyncMock() + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs(records), work_dir) + + assert result.status == HookStatus.PASSED + runner.run.assert_not_called() + + +class TestHookServiceSorting: + """T034-T035: Records sorted by size_hint_mb ascending.""" + + @pytest.mark.asyncio + async def test_records_sorted_by_file_size(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + # Create records with different sizes — large first to test reordering + records = [ + HookRecord(id="large", metadata={}, size_hint_mb=100.0), + HookRecord(id="small", metadata={}, size_hint_mb=1.0), + HookRecord(id="medium", metadata={}, size_hint_mb=50.0), + ] + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + + import json + + captured_order: list[str] = [] + + async def mock_run(h, inputs, wd): + for r in inputs.records: + 
captured_order.append(r.id) + features_file = output_dir / "features.jsonl" + with features_file.open("w") as f: + for r in inputs.records: + f.write(json.dumps({"id": r.id, "features": [{"score": 0.9}]}) + "\n") + return _passed_result() + + runner = AsyncMock() + runner.run.side_effect = mock_run + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + await service.run_hook(hook, _inputs(records), work_dir) + + assert captured_order == ["small", "medium", "large"] + + @pytest.mark.asyncio + async def test_sorting_skipped_when_no_sizes(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + # All records have default size_hint_mb=0 — original order preserved + records = [ + HookRecord(id="a", metadata={}), + HookRecord(id="b", metadata={}), + HookRecord(id="c", metadata={}), + ] + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + + import json + + captured_order: list[str] = [] + + async def mock_run(h, inputs, wd): + for r in inputs.records: + captured_order.append(r.id) + features_file = output_dir / "features.jsonl" + with features_file.open("w") as f: + for r in inputs.records: + f.write(json.dumps({"id": r.id, "features": [{"score": 0.9}]}) + "\n") + return _passed_result() + + runner = AsyncMock() + runner.run.side_effect = mock_run + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + await service.run_hook(hook, _inputs(records), work_dir) + + assert captured_order == ["a", "b", "c"] + + +class TestHookServiceCorruptedCheckpoint: + """T024: Corrupted checkpoint treated as empty — all records reprocessed.""" + + @pytest.mark.asyncio + async def test_corrupted_checkpoint_treated_as_empty(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + records = _make_records(2) + work_dir = tmp_path / 
"hook_out" + work_dir.mkdir() + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + + # Write corrupted checkpoint + checkpoint_file = work_dir / "_checkpoint.jsonl" + checkpoint_file.write_text("NOT VALID JSON\n{also broken\n") + + import json + + runner = AsyncMock() + + async def mock_run(h, inputs, wd): + # Should receive ALL records since checkpoint is corrupted + assert len(inputs.records) == 2 + features_file = output_dir / "features.jsonl" + with features_file.open("w") as f: + for r in inputs.records: + f.write(json.dumps({"id": r.id, "features": [{"score": 0.9}]}) + "\n") + return _passed_result() + + runner.run.side_effect = mock_run + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs(records), work_dir) + + assert result.status == HookStatus.PASSED + runner.run.assert_called_once() diff --git a/server/tests/unit/domain/validation/test_validation_service.py b/server/tests/unit/domain/validation/test_validation_service.py index 1ce570e..cd2ffd7 100644 --- a/server/tests/unit/domain/validation/test_validation_service.py +++ b/server/tests/unit/domain/validation/test_validation_service.py @@ -14,6 +14,7 @@ from osa.domain.shared.model.srn import DepositionSRN, Domain from osa.domain.validation.model import RunStatus from osa.domain.validation.model.hook_result import HookResult, HookStatus +from osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.port.hook_runner import HookInputs from osa.domain.validation.service.validation import ValidationService @@ -63,9 +64,9 @@ def _make_service( def _make_inputs() -> HookInputs: return HookInputs( - record_json={"srn": "urn:osa:localhost:dep:test123", "metadata": {"name": "test"}}, + records=[HookRecord(id="urn:osa:localhost:dep:test123", metadata={"name": "test"})], run_id="localhost_test123", - files_dir=Path("/tmp/staging/files"), + 
files_dirs={"urn:osa:localhost:dep:test123": Path("/tmp/staging/files")}, ) @@ -195,3 +196,56 @@ async def run_hook(hook, inputs, output_dir): assert call_order == ["hook_a", "hook_b"] assert len(results) == 2 + + @pytest.mark.asyncio + async def test_validation_service_halts_on_oom(self): + """REGRESSION: OOM with exhausted retries should halt pipeline as FAILED.""" + hook_runner = AsyncMock() + hook_runner.run.return_value = _make_hook_result(status=HookStatus.OOM) + service = _make_service(hook_runner=hook_runner) + run = await service.create_run(inputs=_make_inputs()) + + run, results = await service.run_hooks( + run=run, + deposition_srn=_make_dep_srn(), + inputs=_make_inputs(), + hooks=[_make_hook_definition()], + ) + + assert run.status == RunStatus.FAILED + + @pytest.mark.asyncio + async def test_validation_service_retries_on_oom(self): + """OOM should be retried via HookService; PASSED on retry → COMPLETED.""" + call_count = 0 + + async def run_hook(hook, inputs, output_dir): + nonlocal call_count + call_count += 1 + if call_count == 1: + return HookResult( + hook_name=hook.name, + status=HookStatus.OOM, + error_message="OOM", + duration_seconds=30.0, + ) + return HookResult( + hook_name=hook.name, + status=HookStatus.PASSED, + duration_seconds=5.0, + ) + + hook_runner = AsyncMock() + hook_runner.run.side_effect = run_hook + service = _make_service(hook_runner=hook_runner) + run = await service.create_run(inputs=_make_inputs()) + + run, results = await service.run_hooks( + run=run, + deposition_srn=_make_dep_srn(), + inputs=_make_inputs(), + hooks=[_make_hook_definition()], + ) + + assert run.status == RunStatus.COMPLETED + assert call_count == 2 diff --git a/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py b/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py index 14a9b47..9e63c87 100644 --- a/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py +++ b/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py @@ -15,6 +15,7 
@@ TableFeatureSpec, ) from osa.domain.validation.model.hook_result import HookStatus +from osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.port.hook_runner import HookInputs from osa.infrastructure.k8s.runner import K8sHookRunner @@ -460,7 +461,7 @@ async def test_successful_run(self, tmp_path: Path): b'{"step":"Check","status":"completed","message":"OK"}\n' ) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, core_api, @@ -509,7 +510,7 @@ async def test_timeout_deadline_exceeded(self, tmp_path: Path): hook = _make_hook() work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" work_dir.mkdir(parents=True) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, @@ -574,7 +575,7 @@ async def test_oom_exit_137(self, tmp_path: Path): hook = _make_hook() work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" work_dir.mkdir(parents=True) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, @@ -584,7 +585,7 @@ async def test_oom_exit_137(self, tmp_path: Path): work_dir, ) - assert result.status == HookStatus.FAILED + assert result.status == HookStatus.OOM assert "oom" in result.error_message.lower() @pytest.mark.asyncio @@ -633,7 +634,7 @@ async def test_nonzero_exit(self, tmp_path: Path): hook = _make_hook() work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" work_dir.mkdir(parents=True) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result 
= await runner._run_job( batch_api, @@ -687,7 +688,7 @@ async def test_orphan_running_job_attaches(self, tmp_path: Path): work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" output_dir = work_dir / "output" output_dir.mkdir(parents=True) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, @@ -723,7 +724,7 @@ async def test_orphan_completed_job_reads_output(self, tmp_path: Path): work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" output_dir = work_dir / "output" output_dir.mkdir(parents=True) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, @@ -778,7 +779,7 @@ async def test_orphan_failed_job_creates_new(self, tmp_path: Path): work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" output_dir = work_dir / "output" output_dir.mkdir(parents=True) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, @@ -843,7 +844,7 @@ async def test_rejection_via_progress(self, tmp_path: Path): runner._s3.get_object.return_value = ( b'{"step":"Validate","status":"rejected","message":"Missing atoms"}\n' ) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, @@ -904,7 +905,7 @@ async def test_run_uses_run_id_from_inputs(self, tmp_path: Path): hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="my-real-run-id", ) diff --git 
a/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py b/server/tests/unit/infrastructure/k8s/test_k8s_ingester_runner.py similarity index 84% rename from server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py rename to server/tests/unit/infrastructure/k8s/test_k8s_ingester_runner.py index 39f19ae..398905e 100644 --- a/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py +++ b/server/tests/unit/infrastructure/k8s/test_k8s_ingester_runner.py @@ -1,4 +1,4 @@ -"""Unit tests for K8sSourceRunner — Job spec differences, source lifecycle.""" +"""Unit tests for K8sIngesterRunner — Job spec differences, ingester lifecycle.""" from pathlib import Path from typing import Any @@ -8,27 +8,27 @@ from osa.config import K8sConfig from osa.domain.shared.error import ExternalServiceError -from osa.domain.shared.model.source import SourceDefinition, SourceLimits +from osa.domain.shared.model.source import IngesterDefinition, IngesterLimits from osa.domain.shared.model.srn import ConventionSRN -from osa.domain.source.port.source_runner import SourceInputs -from osa.infrastructure.k8s.source_runner import K8sSourceRunner +from osa.domain.shared.port.ingester_runner import IngesterInputs +from osa.infrastructure.k8s.ingester_runner import K8sIngesterRunner _CONV_SRN = ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0") -def _make_source( - image: str = "ghcr.io/example/source:v1", +def _make_ingester( + image: str = "ghcr.io/example/ingester:v1", digest: str = "sha256:abc123", timeout: int = 3600, memory: str = "4g", cpu: str = "2.0", config: dict[str, Any] | None = None, -) -> SourceDefinition: - return SourceDefinition( +) -> IngesterDefinition: + return IngesterDefinition( image=image, digest=digest, config=config, - limits=SourceLimits(timeout_seconds=timeout, memory=memory, cpu=cpu), + limits=IngesterLimits(timeout_seconds=timeout, memory=memory, cpu=cpu), ) @@ -53,10 +53,10 @@ def _make_s3_mock() -> AsyncMock: return s3 -def _make_runner(config: 
K8sConfig | None = None) -> K8sSourceRunner: +def _make_runner(config: K8sConfig | None = None) -> K8sIngesterRunner: api_client = MagicMock() s3 = _make_s3_mock() - return K8sSourceRunner(api_client=api_client, config=config or _make_config(), s3=s3) + return K8sIngesterRunner(api_client=api_client, config=config or _make_config(), s3=s3) # --------------------------------------------------------------------------- @@ -68,9 +68,9 @@ class TestSourceJobSpec: def test_network_enabled(self): """Source Jobs have normal DNS policy (network access).""" runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), ) @@ -80,9 +80,9 @@ def test_network_enabled(self): def test_writable_rootfs(self): """Source containers do not have readOnlyRootFilesystem.""" runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), ) @@ -92,23 +92,23 @@ def test_writable_rootfs(self): def test_higher_defaults(self): """Source Jobs use higher defaults (3600s, 4g).""" runner = _make_runner() - source = _make_source(timeout=3600, memory="4g") + ingester = _make_ingester(timeout=3600, memory="4g") spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), ) resources = spec.spec.template.spec.containers[0].resources assert resources.limits["memory"] == "4Gi" - # activeDeadlineSeconds = scheduling_timeout + source timeout + # activeDeadlineSeconds = scheduling_timeout + ingester timeout assert spec.spec.active_deadline_seconds == 120 + 3600 def test_three_volume_mounts(self): 
"""Source Jobs have input, output, and files mounts.""" runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), ) @@ -121,9 +121,9 @@ def test_three_volume_mounts(self): def test_files_mount_writable(self): """Source files mount is writable.""" runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), ) @@ -133,12 +133,12 @@ def test_files_mount_writable(self): def test_env_vars(self): runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), - inputs=SourceInputs(convention_srn=_CONV_SRN, limit=100, offset=50), + inputs=IngesterInputs(convention_srn=_CONV_SRN, limit=100, offset=50), ) env = spec.spec.template.spec.containers[0].env env_dict = {e.name: e.value for e in env} @@ -152,47 +152,47 @@ def test_since_env_var(self): from datetime import datetime, UTC runner = _make_runner() - source = _make_source() + ingester = _make_ingester() since = datetime(2026, 1, 1, tzinfo=UTC) spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), - inputs=SourceInputs(convention_srn=_CONV_SRN, since=since), + inputs=IngesterInputs(convention_srn=_CONV_SRN, since=since), ) env = spec.spec.template.spec.containers[0].env env_dict = {e.name: e.value for e in env} assert "OSA_SINCE" in env_dict - def test_source_role_label(self): + def 
test_ingester_role_label(self): runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), ) labels = spec.spec.template.metadata.labels - assert labels["osa.io/role"] == "source" + assert labels["osa.io/role"] == "ingester" def test_human_readable_name(self): runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:conv1@1.0.0"), ) name = spec.metadata.name - assert name.startswith("osa-source-") + assert name.startswith("osa-ingester-") assert len(name) <= 63 def test_convention_srn_in_labels(self): runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:conv1@1.0.0"), @@ -210,7 +210,7 @@ class TestSourceLifecycle: @pytest.mark.asyncio async def test_successful_run_with_records(self, tmp_path: Path): config = _make_config(data_mount_path=str(tmp_path)) - runner = K8sSourceRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) + runner = K8sIngesterRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) batch_api = AsyncMock() core_api = AsyncMock() @@ -239,7 +239,7 @@ async def test_successful_run_with_records(self, tmp_path: Path): completed_job.status.failed = None batch_api.read_namespaced_job.return_value = completed_job - source = _make_source() + ingester = _make_ingester() work_dir = tmp_path / "sources" / 
"localhost_conv1" / "staging" / "run1" files_dir = work_dir / "files" @@ -258,11 +258,11 @@ async def s3_get(key: str) -> bytes: runner._s3.get_object.side_effect = s3_get - inputs = SourceInputs(convention_srn=_CONV_SRN) + inputs = IngesterInputs(convention_srn=_CONV_SRN) result = await runner._run_job( batch_api, core_api, - source, + ingester, inputs, work_dir, files_dir, @@ -276,7 +276,7 @@ async def s3_get(key: str) -> bytes: @pytest.mark.asyncio async def test_timeout_raises_external_service_error(self, tmp_path: Path): config = _make_config(data_mount_path=str(tmp_path)) - runner = K8sSourceRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) + runner = K8sIngesterRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) batch_api = AsyncMock() core_api = AsyncMock() @@ -304,18 +304,18 @@ async def test_timeout_raises_external_service_error(self, tmp_path: Path): failed_job.status.failed = 1 batch_api.read_namespaced_job.return_value = failed_job - source = _make_source() + ingester = _make_ingester() work_dir = tmp_path / "sources" / "localhost_conv1" / "staging" / "run1" work_dir.mkdir(parents=True) files_dir = work_dir / "files" files_dir.mkdir(parents=True) - inputs = SourceInputs(convention_srn=_CONV_SRN) + inputs = IngesterInputs(convention_srn=_CONV_SRN) with pytest.raises(ExternalServiceError, match="[Tt]imed out|[Dd]eadline"): await runner._run_job( batch_api, core_api, - source, + ingester, inputs, work_dir, files_dir, @@ -324,7 +324,7 @@ async def test_timeout_raises_external_service_error(self, tmp_path: Path): @pytest.mark.asyncio async def test_oom_raises_external_service_error(self, tmp_path: Path): config = _make_config(data_mount_path=str(tmp_path)) - runner = K8sSourceRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) + runner = K8sIngesterRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) batch_api = AsyncMock() core_api = AsyncMock() @@ -364,18 +364,18 @@ async def 
test_oom_raises_external_service_error(self, tmp_path: Path): core_api.list_namespaced_pod.side_effect = [pod_list, oom_pod_list] - source = _make_source() + ingester = _make_ingester() work_dir = tmp_path / "sources" / "localhost_conv1" / "staging" / "run1" work_dir.mkdir(parents=True) files_dir = work_dir / "files" files_dir.mkdir(parents=True) - inputs = SourceInputs(convention_srn=_CONV_SRN) + inputs = IngesterInputs(convention_srn=_CONV_SRN) with pytest.raises(ExternalServiceError, match="[Oo]OM"): await runner._run_job( batch_api, core_api, - source, + ingester, inputs, work_dir, files_dir, @@ -383,7 +383,7 @@ async def test_oom_raises_external_service_error(self, tmp_path: Path): # --------------------------------------------------------------------------- -# Identity threading from SourceInputs +# Identity threading from IngesterInputs # --------------------------------------------------------------------------- @@ -395,7 +395,7 @@ async def test_run_uses_convention_srn_from_inputs(self, tmp_path: Path): from unittest.mock import patch config = _make_config(data_mount_path=str(tmp_path)) - runner = K8sSourceRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) + runner = K8sIngesterRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) batch_api = AsyncMock() core_api = AsyncMock() @@ -421,14 +421,14 @@ async def test_run_uses_convention_srn_from_inputs(self, tmp_path: Path): completed_job.status.failed = None batch_api.read_namespaced_job.return_value = completed_job - source = _make_source() + ingester = _make_ingester() work_dir = tmp_path / "sources" / "run1" output_dir = work_dir / "output" output_dir.mkdir(parents=True) files_dir = work_dir / "files" files_dir.mkdir(parents=True) - inputs = SourceInputs( + inputs = IngesterInputs( convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:my-conv@1.0.0") ) @@ -436,7 +436,7 @@ async def test_run_uses_convention_srn_from_inputs(self, tmp_path: Path): 
patch("kubernetes_asyncio.client.BatchV1Api", return_value=batch_api), patch("kubernetes_asyncio.client.CoreV1Api", return_value=core_api), ): - await runner.run(source, inputs, files_dir, work_dir) + await runner.run(ingester, inputs, files_dir, work_dir) # Verify convention_srn from inputs ends up in the Job labels call_args = batch_api.create_namespaced_job.call_args diff --git a/server/tests/unit/infrastructure/oci/__init__.py b/server/tests/unit/infrastructure/oci/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/infrastructure/test_hook_output_parsing.py b/server/tests/unit/infrastructure/test_hook_output_parsing.py new file mode 100644 index 0000000..f1d740e --- /dev/null +++ b/server/tests/unit/infrastructure/test_hook_output_parsing.py @@ -0,0 +1,132 @@ +"""T004: Unit tests for JSONL batch output parsing via FilesystemStorageAdapter.""" + +import json +from pathlib import Path + +import pytest + +from osa.infrastructure.persistence.adapter.storage import FilesystemStorageAdapter + + +def _write_jsonl(path: Path, lines: list[dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w") as f: + for line in lines: + f.write(json.dumps(line) + "\n") + + +def _hook_output_dir(base: Path, hook_name: str = "validate_dna") -> Path: + """Create the expected directory structure: {base}/hooks/{hook_name}/output/""" + d = base / "hooks" / hook_name / "output" + d.mkdir(parents=True, exist_ok=True) + return d + + +@pytest.fixture +def adapter(tmp_path: Path) -> FilesystemStorageAdapter: + return FilesystemStorageAdapter(str(tmp_path)) + + +HOOK = "validate_dna" + + +class TestReadBatchOutcomes: + """Parse features.jsonl, rejections.jsonl, errors.jsonl via storage adapter.""" + + @pytest.mark.anyio + async def test_single_line_features( + self, adapter: FilesystemStorageAdapter, tmp_path: Path + ) -> None: + output = _hook_output_dir(tmp_path) + _write_jsonl(output / "features.jsonl", [{"id": "rec1", 
"features": [{"score": 0.9}]}]) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert len(outcomes) == 1 + assert outcomes["rec1"].status == "passed" + assert outcomes["rec1"].features == [{"score": 0.9}] + + @pytest.mark.anyio + async def test_multi_line_features( + self, adapter: FilesystemStorageAdapter, tmp_path: Path + ) -> None: + output = _hook_output_dir(tmp_path) + _write_jsonl( + output / "features.jsonl", + [ + {"id": "rec1", "features": [{"score": 0.9}]}, + {"id": "rec2", "features": [{"score": 0.7}]}, + ], + ) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert len(outcomes) == 2 + assert outcomes["rec1"].status == "passed" + assert outcomes["rec2"].status == "passed" + + @pytest.mark.anyio + async def test_rejections(self, adapter: FilesystemStorageAdapter, tmp_path: Path) -> None: + output = _hook_output_dir(tmp_path) + _write_jsonl( + output / "rejections.jsonl", [{"id": "rec3", "reason": "Missing required field"}] + ) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert outcomes["rec3"].status == "rejected" + assert outcomes["rec3"].reason == "Missing required field" + + @pytest.mark.anyio + async def test_errors(self, adapter: FilesystemStorageAdapter, tmp_path: Path) -> None: + output = _hook_output_dir(tmp_path) + _write_jsonl(output / "errors.jsonl", [{"id": "rec4", "error": "OOM", "retryable": True}]) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert outcomes["rec4"].status == "errored" + assert outcomes["rec4"].error == "OOM" + + @pytest.mark.anyio + async def test_mixed_outcomes(self, adapter: FilesystemStorageAdapter, tmp_path: Path) -> None: + output = _hook_output_dir(tmp_path) + _write_jsonl(output / "features.jsonl", [{"id": "a", "features": []}]) + _write_jsonl(output / "rejections.jsonl", [{"id": "b", "reason": "bad"}]) + _write_jsonl(output / "errors.jsonl", [{"id": "c", "error": "fail"}]) + outcomes = await 
adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert len(outcomes) == 3 + assert outcomes["a"].status == "passed" + assert outcomes["b"].status == "rejected" + assert outcomes["c"].status == "errored" + + @pytest.mark.anyio + async def test_empty_directory(self, adapter: FilesystemStorageAdapter, tmp_path: Path) -> None: + _hook_output_dir(tmp_path) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert outcomes == {} + + @pytest.mark.anyio + async def test_malformed_json_line_skipped( + self, adapter: FilesystemStorageAdapter, tmp_path: Path + ) -> None: + output = _hook_output_dir(tmp_path) + (output / "features.jsonl").write_text( + '{"id": "ok", "features": []}\n' + "not valid json\n" + '{"id": "also_ok", "features": [{"x": 1}]}\n' + ) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert len(outcomes) == 2 + assert "ok" in outcomes + assert "also_ok" in outcomes + + @pytest.mark.anyio + async def test_missing_id_field_skipped( + self, adapter: FilesystemStorageAdapter, tmp_path: Path + ) -> None: + output = _hook_output_dir(tmp_path) + (output / "features.jsonl").write_text( + '{"id": "ok", "features": []}\n{"features": [{"x": 1}]}\n' + ) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert len(outcomes) == 1 + assert "ok" in outcomes + + @pytest.mark.anyio + async def test_nonexistent_output_dir( + self, adapter: FilesystemStorageAdapter, tmp_path: Path + ) -> None: + outcomes = await adapter.read_batch_outcomes(str(tmp_path / "nonexistent"), HOOK) + assert outcomes == {} diff --git a/server/tests/unit/infrastructure/test_oci_hook_runner.py b/server/tests/unit/infrastructure/test_oci_hook_runner.py index ae43290..35352bb 100644 --- a/server/tests/unit/infrastructure/test_oci_hook_runner.py +++ b/server/tests/unit/infrastructure/test_oci_hook_runner.py @@ -13,6 +13,7 @@ TableFeatureSpec, ) from osa.domain.validation.model.hook_result import HookStatus, ProgressEntry +from 
osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.port.hook_runner import HookInputs from osa.infrastructure.oci.runner import OciHookRunner from osa.infrastructure.runner_utils import ( @@ -199,7 +200,7 @@ async def test_successful_hook_returns_passed(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", ) @@ -225,7 +226,7 @@ async def test_nonzero_exit_returns_failed(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", ) @@ -238,7 +239,7 @@ async def test_nonzero_exit_returns_failed(self, tmp_path: Path): assert "exit" in (result.error_message or "").lower() @pytest.mark.asyncio - async def test_oom_killed_returns_failed(self, tmp_path: Path): + async def test_oom_killed_returns_oom(self, tmp_path: Path): docker = AsyncMock() container = AsyncMock() docker.containers.create.return_value = container @@ -248,7 +249,7 @@ async def test_oom_killed_returns_failed(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", ) @@ -257,7 +258,7 @@ async def test_oom_killed_returns_failed(self, tmp_path: Path): result = await runner.run(hook, inputs, output_dir) - assert result.status == HookStatus.FAILED + assert result.status == HookStatus.OOM assert "oom" in (result.error_message or "").lower() @pytest.mark.asyncio @@ -278,7 +279,7 @@ async def hang(): runner = OciHookRunner(docker=docker) hook = _make_hook(timeout=1) # 1 second timeout inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", ) @@ -305,7 +306,7 @@ async def 
test_rejection_via_progress(self, tmp_path: Path): work_dir.mkdir() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", ) @@ -335,7 +336,7 @@ async def test_security_hardening(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook(memory="4g", cpu="4.0") inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", ) @@ -369,7 +370,7 @@ async def test_env_vars_set(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", ) @@ -399,9 +400,9 @@ async def test_nested_bind_mounts(self, tmp_path: Path): files_dir = tmp_path / "files" files_dir.mkdir() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", - files_dir=files_dir, + files_dirs={"test": files_dir}, ) work_dir = tmp_path / "hook_work" @@ -413,7 +414,7 @@ async def test_nested_bind_mounts(self, tmp_path: Path): config = call_args[0][0] if call_args[0] else call_args[1].get("config", {}) binds = config["HostConfig"]["Binds"] - # Should have 3 binds: input:ro, output:rw, files:ro + # Should have 3 binds: input:ro, output:rw, files/{id}:ro assert len(binds) == 3 # input/ and output/ are sibling dirs under work_dir @@ -421,7 +422,7 @@ async def test_nested_bind_mounts(self, tmp_path: Path): out_bind = [b for b in binds if b.endswith(":/osa/out:rw")][0] assert str(work_dir / "input") in in_bind assert str(work_dir / "output") in out_bind - assert any(b.endswith(":/osa/in/files:ro") for b in binds) + assert any(":/osa/files/test:ro" in b for b in binds) @pytest.mark.asyncio async def test_no_files_bind_when_no_files_dir(self, tmp_path: Path): @@ -435,9 +436,8 @@ async def test_no_files_bind_when_no_files_dir(self, tmp_path: Path): runner = OciHookRunner(docker=docker) 
hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", - files_dir=None, ) output_dir = tmp_path / "output" @@ -449,7 +449,10 @@ async def test_no_files_bind_when_no_files_dir(self, tmp_path: Path): config = call_args[0][0] if call_args[0] else call_args[1].get("config", {}) binds = config["HostConfig"]["Binds"] - assert len(binds) == 2 # staging + output only + # staging + output + empty files base dir + assert len(binds) == 3 + # No per-record file mounts + assert not any(":/osa/files/" in b and ":ro" in b for b in binds if b.count("/") > 3) @pytest.mark.asyncio async def test_container_deleted_on_failure(self, tmp_path: Path): @@ -464,7 +467,7 @@ async def test_container_deleted_on_failure(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", )