From 64fdf6e2a35121cc5a25a6fe481a6424e5c3dc7d Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Wed, 25 Mar 2026 20:52:49 +0000 Subject: [PATCH 1/9] feat: add ingest domain for bulk ingestion pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces the ingest domain that replaces the source domain with a batch-oriented pipeline: ingester containers pull records in batches, hook containers validate/enrich them via a unified JSONL contract, and passing records are bulk-published with features. Key changes: - Unified hook contract: hooks process records.jsonl (batch of 1 for depositions, N for ingests) with JSONL outputs (features/rejections/errors). HookInputs uses typed HookRecord instead of raw dicts. - Ingest pipeline: IngestRun aggregate with event-driven handlers (RunIngester → RunHooks → PublishBatch) orchestrated via outbox workers. Atomic SQL counter increments for concurrent batch completion. - RecordService.bulk_publish() with ON CONFLICT DO NOTHING for duplicate detection across ingest runs. - Batch output parsing via FeatureStoragePort (not standalone functions) with both filesystem and S3 implementations. - Renamed source runner → ingester runner (source_runner.py → ingester_runner.py, SourceDefinition → IngesterDefinition) to avoid collision with RecordSource provenance types. - Deleted source domain entirely — ingest supersedes it. - API: POST /api/v1/ingestions to start an ingest run. 
Closes #104 --- .../migrations/versions/add_harvest_runs.py | 92 +++++++++ server/osa/application/api/rest/app.py | 2 + .../application/api/v1/routes/ingestions.py | 20 ++ server/osa/application/di.py | 4 +- .../deposition/command/create_convention.py | 6 +- .../osa/domain/deposition/handler/__init__.py | 5 +- .../handler/create_deposition_from_source.py | 50 ----- .../osa/domain/deposition/model/convention.py | 4 +- .../domain/deposition/query/get_convention.py | 6 +- .../domain/deposition/service/convention.py | 6 +- server/osa/domain/feature/handler/__init__.py | 3 +- .../feature/handler/insert_batch_features.py | 43 ++++ server/osa/domain/feature/port/storage.py | 14 ++ .../{source/port => ingest}/__init__.py | 0 .../domain/ingest/command}/__init__.py | 0 .../osa/domain/ingest/command/start_ingest.py | 43 ++++ server/osa/domain/ingest/event/__init__.py | 17 ++ server/osa/domain/ingest/event/events.py | 59 ++++++ server/osa/domain/ingest/handler/__init__.py | 7 + .../domain/ingest/handler/publish_batch.py | 187 ++++++++++++++++++ server/osa/domain/ingest/handler/run_hooks.py | 112 +++++++++++ .../osa/domain/ingest/handler/run_ingester.py | 129 ++++++++++++ server/osa/domain/ingest/model/__init__.py | 0 server/osa/domain/ingest/model/ingest_run.py | 85 ++++++++ server/osa/domain/ingest/port/__init__.py | 0 server/osa/domain/ingest/port/repository.py | 50 +++++ server/osa/domain/ingest/service/__init__.py | 0 server/osa/domain/ingest/service/ingest.py | 79 ++++++++ server/osa/domain/record/port/repository.py | 5 + server/osa/domain/record/service/record.py | 31 +++ server/osa/domain/shared/model/source.py | 32 +-- .../port/ingester_runner.py} | 28 +-- server/osa/domain/source/__init__.py | 1 - server/osa/domain/source/event/__init__.py | 7 - .../source/event/source_record_ready.py | 21 -- .../domain/source/event/source_requested.py | 28 --- .../source/event/source_run_completed.py | 17 -- server/osa/domain/source/handler/__init__.py | 6 - 
.../domain/source/handler/pull_from_source.py | 34 ---- .../handler/trigger_initial_source_run.py | 39 ---- .../handler/trigger_source_on_deploy.py | 35 ---- server/osa/domain/source/model/__init__.py | 1 - server/osa/domain/source/port/storage.py | 32 --- server/osa/domain/source/schedule/__init__.py | 5 - .../domain/source/schedule/source_schedule.py | 60 ------ server/osa/domain/source/service/__init__.py | 5 - server/osa/domain/source/service/source.py | 154 --------------- .../domain/validation/model/batch_outcome.py | 20 ++ .../osa/domain/validation/model/hook_input.py | 15 ++ .../osa/domain/validation/port/hook_runner.py | 14 +- .../domain/validation/service/validation.py | 13 +- server/osa/infrastructure/event/di.py | 23 +-- server/osa/infrastructure/event/worker.py | 31 +-- server/osa/infrastructure/ingest/__init__.py | 0 server/osa/infrastructure/ingest/di.py | 39 ++++ server/osa/infrastructure/k8s/di.py | 18 +- .../{source_runner.py => ingester_runner.py} | 30 +-- server/osa/infrastructure/k8s/runner.py | 12 +- server/osa/infrastructure/oci/di.py | 8 +- .../{source_runner.py => ingester_runner.py} | 22 +-- server/osa/infrastructure/oci/runner.py | 34 +++- .../persistence/adapter/storage.py | 43 ++++ server/osa/infrastructure/persistence/di.py | 5 - .../persistence/repository/convention.py | 6 +- .../persistence/repository/ingest.py | 127 ++++++++++++ .../persistence/repository/record.py | 27 ++- .../osa/infrastructure/persistence/tables.py | 24 ++- server/osa/infrastructure/s3/storage.py | 46 +++++ server/osa/infrastructure/source/__init__.py | 1 - server/osa/infrastructure/source/di.py | 27 --- .../persistence/test_convention_repo.py | 44 ++--- server/tests/unit/application/__init__.py | 0 server/tests/unit/application/api/__init__.py | 0 .../tests/unit/application/api/v1/__init__.py | 0 .../unit/application/test_app_factory.py | 38 ++-- server/tests/unit/domain/__init__.py | 0 .../deposition/test_convention_service_v2.py | 30 +-- 
.../test_create_deposition_from_source.py | 93 --------- .../feature/test_insert_record_features.py | 20 +- server/tests/unit/domain/ingest/__init__.py | 0 .../unit/domain/ingest/test_ingest_run.py | 144 ++++++++++++++ .../unit/domain/ingest/test_ingest_service.py | 112 +++++++++++ .../unit/domain/record/test_record_service.py | 20 +- .../unit/domain/shared/test_record_source.py | 38 ++-- .../domain/source/test_source_record_ready.py | 59 ------ .../unit/domain/source/test_source_service.py | 169 ---------------- .../source/test_source_service_decoupled.py | 131 ------------ .../source/test_trigger_initial_source_run.py | 118 ----------- .../source/test_trigger_source_on_deploy.py | 117 ----------- .../domain/validation/test_hook_runner.py | 34 ++-- .../validation/test_validation_service.py | 5 +- .../k8s/test_k8s_hook_runner.py | 19 +- ..._runner.py => test_k8s_ingester_runner.py} | 40 ++-- .../tests/unit/infrastructure/oci/__init__.py | 0 .../test_hook_output_parsing.py | 132 +++++++++++++ .../infrastructure/test_oci_hook_runner.py | 33 ++-- 96 files changed, 2024 insertions(+), 1521 deletions(-) create mode 100644 server/migrations/versions/add_harvest_runs.py create mode 100644 server/osa/application/api/v1/routes/ingestions.py delete mode 100644 server/osa/domain/deposition/handler/create_deposition_from_source.py create mode 100644 server/osa/domain/feature/handler/insert_batch_features.py rename server/osa/domain/{source/port => ingest}/__init__.py (100%) rename server/{tests/unit/domain/source => osa/domain/ingest/command}/__init__.py (100%) create mode 100644 server/osa/domain/ingest/command/start_ingest.py create mode 100644 server/osa/domain/ingest/event/__init__.py create mode 100644 server/osa/domain/ingest/event/events.py create mode 100644 server/osa/domain/ingest/handler/__init__.py create mode 100644 server/osa/domain/ingest/handler/publish_batch.py create mode 100644 server/osa/domain/ingest/handler/run_hooks.py create mode 100644 
server/osa/domain/ingest/handler/run_ingester.py create mode 100644 server/osa/domain/ingest/model/__init__.py create mode 100644 server/osa/domain/ingest/model/ingest_run.py create mode 100644 server/osa/domain/ingest/port/__init__.py create mode 100644 server/osa/domain/ingest/port/repository.py create mode 100644 server/osa/domain/ingest/service/__init__.py create mode 100644 server/osa/domain/ingest/service/ingest.py rename server/osa/domain/{source/port/source_runner.py => shared/port/ingester_runner.py} (51%) delete mode 100644 server/osa/domain/source/__init__.py delete mode 100644 server/osa/domain/source/event/__init__.py delete mode 100644 server/osa/domain/source/event/source_record_ready.py delete mode 100644 server/osa/domain/source/event/source_requested.py delete mode 100644 server/osa/domain/source/event/source_run_completed.py delete mode 100644 server/osa/domain/source/handler/__init__.py delete mode 100644 server/osa/domain/source/handler/pull_from_source.py delete mode 100644 server/osa/domain/source/handler/trigger_initial_source_run.py delete mode 100644 server/osa/domain/source/handler/trigger_source_on_deploy.py delete mode 100644 server/osa/domain/source/model/__init__.py delete mode 100644 server/osa/domain/source/port/storage.py delete mode 100644 server/osa/domain/source/schedule/__init__.py delete mode 100644 server/osa/domain/source/schedule/source_schedule.py delete mode 100644 server/osa/domain/source/service/__init__.py delete mode 100644 server/osa/domain/source/service/source.py create mode 100644 server/osa/domain/validation/model/batch_outcome.py create mode 100644 server/osa/domain/validation/model/hook_input.py create mode 100644 server/osa/infrastructure/ingest/__init__.py create mode 100644 server/osa/infrastructure/ingest/di.py rename server/osa/infrastructure/k8s/{source_runner.py => ingester_runner.py} (96%) rename server/osa/infrastructure/oci/{source_runner.py => ingester_runner.py} (93%) create mode 100644 
server/osa/infrastructure/persistence/repository/ingest.py delete mode 100644 server/osa/infrastructure/source/__init__.py delete mode 100644 server/osa/infrastructure/source/di.py create mode 100644 server/tests/unit/application/__init__.py create mode 100644 server/tests/unit/application/api/__init__.py create mode 100644 server/tests/unit/application/api/v1/__init__.py create mode 100644 server/tests/unit/domain/__init__.py delete mode 100644 server/tests/unit/domain/deposition/test_create_deposition_from_source.py create mode 100644 server/tests/unit/domain/ingest/__init__.py create mode 100644 server/tests/unit/domain/ingest/test_ingest_run.py create mode 100644 server/tests/unit/domain/ingest/test_ingest_service.py delete mode 100644 server/tests/unit/domain/source/test_source_record_ready.py delete mode 100644 server/tests/unit/domain/source/test_source_service.py delete mode 100644 server/tests/unit/domain/source/test_source_service_decoupled.py delete mode 100644 server/tests/unit/domain/source/test_trigger_initial_source_run.py delete mode 100644 server/tests/unit/domain/source/test_trigger_source_on_deploy.py rename server/tests/unit/infrastructure/k8s/{test_k8s_source_runner.py => test_k8s_ingester_runner.py} (91%) create mode 100644 server/tests/unit/infrastructure/oci/__init__.py create mode 100644 server/tests/unit/infrastructure/test_hook_output_parsing.py diff --git a/server/migrations/versions/add_harvest_runs.py b/server/migrations/versions/add_harvest_runs.py new file mode 100644 index 0000000..9c3564e --- /dev/null +++ b/server/migrations/versions/add_harvest_runs.py @@ -0,0 +1,92 @@ +"""add_ingest_runs + +Add ingest_runs table for bulk ingestion tracking. + +Revision ID: add_harvest_runs +Revises: source_agnostic_records +Create Date: 2026-03-25 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "add_harvest_runs" +down_revision: Union[str, Sequence[str], None] = "source_agnostic_records" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "ingest_runs", + sa.Column("srn", sa.String(), primary_key=True), + sa.Column( + "convention_srn", + sa.String(), + sa.ForeignKey("conventions.srn"), + nullable=False, + ), + sa.Column( + "status", + sa.String(32), + nullable=False, + server_default=sa.text("'pending'"), + ), + sa.Column( + "source_finished", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + sa.Column( + "batches_sourced", + sa.Integer(), + nullable=False, + server_default=sa.text("0"), + ), + sa.Column( + "batches_completed", + sa.Integer(), + nullable=False, + server_default=sa.text("0"), + ), + sa.Column( + "published_count", + sa.Integer(), + nullable=False, + server_default=sa.text("0"), + ), + sa.Column( + "batch_size", + sa.Integer(), + nullable=False, + server_default=sa.text("1000"), + ), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.CheckConstraint( + "status IN ('pending', 'running', 'completed', 'failed')", + name="ingest_runs_status_check", + ), + ) + + op.create_index( + "idx_ingest_runs_convention", + "ingest_runs", + ["convention_srn"], + ) + op.create_index( + "idx_ingest_runs_status", + "ingest_runs", + ["status"], + ) + + +def downgrade() -> None: + op.drop_index("idx_ingest_runs_status", table_name="ingest_runs") + op.drop_index("idx_ingest_runs_convention", table_name="ingest_runs") + op.drop_table("ingest_runs") diff --git a/server/osa/application/api/rest/app.py b/server/osa/application/api/rest/app.py index 4cbd273..3cfbc18 100644 --- a/server/osa/application/api/rest/app.py +++ b/server/osa/application/api/rest/app.py @@ -16,6 +16,7 @@ depositions, discovery, events, + ingestions, 
health, ontologies, records, @@ -117,6 +118,7 @@ def create_app( app_instance.include_router(schemas.router, prefix="/api/v1") app_instance.include_router(conventions.router, prefix="/api/v1") app_instance.include_router(depositions.router, prefix="/api/v1") + app_instance.include_router(ingestions.router, prefix="/api/v1") app_instance.include_router(validation.router, prefix="/api/v1") app_instance.include_router(discovery.router, prefix="/api/v1") diff --git a/server/osa/application/api/v1/routes/ingestions.py b/server/osa/application/api/v1/routes/ingestions.py new file mode 100644 index 0000000..86a6f50 --- /dev/null +++ b/server/osa/application/api/v1/routes/ingestions.py @@ -0,0 +1,20 @@ +"""Ingest REST routes.""" + +from dishka.integrations.fastapi import DishkaRoute, FromDishka +from fastapi import APIRouter + +from osa.domain.ingest.command.start_ingest import ( + IngestRunCreated, + StartIngest, + StartIngestHandler, +) + +router = APIRouter(prefix="/ingestions", tags=["Ingestions"], route_class=DishkaRoute) + + +@router.post("", response_model=IngestRunCreated, status_code=201) +async def start_ingest( + body: StartIngest, + handler: FromDishka[StartIngestHandler], +) -> IngestRunCreated: + return await handler.run(body) diff --git a/server/osa/application/di.py b/server/osa/application/di.py index 1a96763..f7e7ff3 100644 --- a/server/osa/application/di.py +++ b/server/osa/application/di.py @@ -17,7 +17,7 @@ from osa.infrastructure.index.di import IndexProvider from osa.infrastructure.k8s.di import RunnerProvider from osa.infrastructure.persistence import PersistenceProvider -from osa.infrastructure.source.di import SourceProvider +from osa.infrastructure.ingest.di import IngestProvider from osa.util.di.scope import Scope from osa.util.paths import OSAPaths @@ -44,7 +44,7 @@ def create_container( PersistenceProvider(), RunnerProvider(), IndexProvider(), - SourceProvider(), + IngestProvider(), EventProvider(extra_handlers=extra_handlers), HttpProvider(), 
DepositionProvider(), diff --git a/server/osa/domain/deposition/command/create_convention.py b/server/osa/domain/deposition/command/create_convention.py index 2e35dc8..c50059b 100644 --- a/server/osa/domain/deposition/command/create_convention.py +++ b/server/osa/domain/deposition/command/create_convention.py @@ -8,7 +8,7 @@ from osa.domain.shared.authorization.gate import public from osa.domain.shared.command import Command, CommandHandler, Result from osa.domain.shared.model.hook import HookDefinition -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.srn import ConventionSRN, SchemaSRN @@ -21,7 +21,7 @@ class CreateConvention(Command): file_requirements: FileRequirements description: str | None = None hooks: list[HookDefinition] = [] - source: SourceDefinition | None = None + ingester: IngesterDefinition | None = None class ConventionCreated(Result): @@ -44,7 +44,7 @@ async def run(self, cmd: CreateConvention) -> ConventionCreated: file_requirements=cmd.file_requirements, description=cmd.description, hooks=cmd.hooks, - source=cmd.source, + ingester=cmd.ingester, ) return ConventionCreated( srn=convention.srn, diff --git a/server/osa/domain/deposition/handler/__init__.py b/server/osa/domain/deposition/handler/__init__.py index b1218b2..16ec113 100644 --- a/server/osa/domain/deposition/handler/__init__.py +++ b/server/osa/domain/deposition/handler/__init__.py @@ -1,8 +1,5 @@ """Deposition domain event handlers.""" -from osa.domain.deposition.handler.create_deposition_from_source import ( - CreateDepositionFromSource, -) from osa.domain.deposition.handler.return_to_draft import ReturnToDraft -__all__ = ["CreateDepositionFromSource", "ReturnToDraft"] +__all__ = ["ReturnToDraft"] diff --git a/server/osa/domain/deposition/handler/create_deposition_from_source.py b/server/osa/domain/deposition/handler/create_deposition_from_source.py deleted file mode 100644 index 
769e34a..0000000 --- a/server/osa/domain/deposition/handler/create_deposition_from_source.py +++ /dev/null @@ -1,50 +0,0 @@ -"""CreateDepositionFromSource — creates a deposition from a source record.""" - -import logging -from pathlib import Path - -from osa.domain.auth.model.value import SYSTEM_USER_ID -from osa.domain.deposition.port.storage import FileStoragePort -from osa.domain.deposition.service.deposition import DepositionService -from osa.domain.shared.event import EventHandler -from osa.domain.source.event.source_record_ready import SourceRecordReady - -logger = logging.getLogger(__name__) - - -class CreateDepositionFromSource(EventHandler[SourceRecordReady]): - """Creates a deposition when a source record is ready. - - Replaces the direct DepositionService calls that used to live - in SourceService — now the source domain only emits events and - the deposition domain reacts. - """ - - deposition_service: DepositionService - file_storage: FileStoragePort - - async def handle(self, event: SourceRecordReady) -> None: - """Create deposition, set metadata, move files, and submit.""" - dep = await self.deposition_service.create( - convention_srn=event.convention_srn, - owner_id=SYSTEM_USER_ID, - ) - - await self.deposition_service.update_metadata( - srn=dep.srn, - metadata=event.metadata, - ) - - await self.file_storage.move_source_files_to_deposition( - staging_dir=Path(event.staging_dir), - source_id=event.source_id, - deposition_srn=dep.srn, - ) - - await self.deposition_service.submit(srn=dep.srn) - - logger.info( - "Created deposition %s from source record %s", - dep.srn, - event.source_id, - ) diff --git a/server/osa/domain/deposition/model/convention.py b/server/osa/domain/deposition/model/convention.py index c1a2508..63c1e77 100644 --- a/server/osa/domain/deposition/model/convention.py +++ b/server/osa/domain/deposition/model/convention.py @@ -3,7 +3,7 @@ from osa.domain.deposition.model.value import FileRequirements from 
osa.domain.shared.model.aggregate import Aggregate from osa.domain.shared.model.hook import HookDefinition -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.srn import ConventionSRN, SchemaSRN @@ -16,5 +16,5 @@ class Convention(Aggregate): schema_srn: SchemaSRN file_requirements: FileRequirements hooks: list[HookDefinition] = [] - source: SourceDefinition | None = None + ingester: IngesterDefinition | None = None created_at: datetime diff --git a/server/osa/domain/deposition/query/get_convention.py b/server/osa/domain/deposition/query/get_convention.py index 27fcb4a..b39e467 100644 --- a/server/osa/domain/deposition/query/get_convention.py +++ b/server/osa/domain/deposition/query/get_convention.py @@ -4,7 +4,7 @@ from osa.domain.deposition.service.convention import ConventionService from osa.domain.shared.authorization.gate import public from osa.domain.shared.model.hook import HookDefinition -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.srn import ConventionSRN, SchemaSRN from osa.domain.shared.query import Query, QueryHandler, Result @@ -20,7 +20,7 @@ class ConventionDetail(Result): schema_srn: SchemaSRN file_requirements: FileRequirements hooks: list[HookDefinition] - source: SourceDefinition | None = None + ingester: IngesterDefinition | None = None created_at: datetime @@ -37,6 +37,6 @@ async def run(self, cmd: GetConvention) -> ConventionDetail: schema_srn=conv.schema_srn, file_requirements=conv.file_requirements, hooks=conv.hooks, - source=conv.source, + ingester=conv.ingester, created_at=conv.created_at, ) diff --git a/server/osa/domain/deposition/service/convention.py b/server/osa/domain/deposition/service/convention.py index 3c2951d..79492e9 100644 --- a/server/osa/domain/deposition/service/convention.py +++ 
b/server/osa/domain/deposition/service/convention.py @@ -10,7 +10,7 @@ from osa.domain.shared.error import NotFoundError from osa.domain.shared.event import EventId from osa.domain.shared.model.hook import HookDefinition -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.srn import ConventionSRN, Domain, LocalId, Semver from osa.domain.shared.outbox import Outbox from osa.domain.shared.service import Service @@ -30,7 +30,7 @@ async def create_convention( file_requirements: FileRequirements, description: str | None = None, hooks: list[HookDefinition] | None = None, - source: SourceDefinition | None = None, + ingester: IngesterDefinition | None = None, ) -> Convention: """Create a convention with an inline schema. @@ -59,7 +59,7 @@ async def create_convention( schema_srn=created_schema.srn, file_requirements=file_requirements, hooks=hooks or [], - source=source, + ingester=ingester, created_at=datetime.now(UTC), ) diff --git a/server/osa/domain/feature/handler/__init__.py b/server/osa/domain/feature/handler/__init__.py index b9db31d..542e540 100644 --- a/server/osa/domain/feature/handler/__init__.py +++ b/server/osa/domain/feature/handler/__init__.py @@ -1,6 +1,7 @@ """Feature domain event handlers.""" from osa.domain.feature.handler.create_feature_tables import CreateFeatureTables +from osa.domain.feature.handler.insert_batch_features import InsertBatchFeatures from osa.domain.feature.handler.insert_record_features import InsertRecordFeatures -__all__ = ["CreateFeatureTables", "InsertRecordFeatures"] +__all__ = ["CreateFeatureTables", "InsertBatchFeatures", "InsertRecordFeatures"] diff --git a/server/osa/domain/feature/handler/insert_batch_features.py b/server/osa/domain/feature/handler/insert_batch_features.py new file mode 100644 index 0000000..e8785ea --- /dev/null +++ b/server/osa/domain/feature/handler/insert_batch_features.py @@ -0,0 +1,43 @@ 
+"""InsertBatchFeatures — bulk feature insertion for ingest batches.""" + +import logging + +from osa.domain.feature.port.storage import FeatureStoragePort +from osa.domain.feature.service.feature import FeatureService +from osa.domain.ingest.event.events import IngestBatchPublished +from osa.domain.shared.event import EventHandler + +logger = logging.getLogger(__name__) + + +class InsertBatchFeatures(EventHandler[IngestBatchPublished]): + """Reads hook outputs for an ingest batch and inserts features in bulk. + + Handles IngestBatchPublished (batch-level event) rather than + per-record RecordPublished. Shares core insertion logic with + InsertRecordFeatures via FeatureService. + """ + + feature_service: FeatureService + feature_storage: FeatureStoragePort + + async def handle(self, event: IngestBatchPublished) -> None: + if not event.expected_features or not event.published_srns: + return + + hook_output_root = self.feature_storage.get_hook_output_root("ingest", event.ingest_run_srn) + + # Read batch outcomes for each hook and insert features + for record_srn in event.published_srns: + await self.feature_service.insert_features_for_record( + hook_output_dir=hook_output_root, + record_srn=record_srn, + expected_features=event.expected_features, + ) + + logger.info( + "Inserted features for %d records in batch %d of %s", + len(event.published_srns), + event.batch_index, + event.ingest_run_srn, + ) diff --git a/server/osa/domain/feature/port/storage.py b/server/osa/domain/feature/port/storage.py index 4590de5..4856c17 100644 --- a/server/osa/domain/feature/port/storage.py +++ b/server/osa/domain/feature/port/storage.py @@ -4,6 +4,7 @@ from typing import Any, Protocol from osa.domain.shared.port import Port +from osa.domain.validation.model.batch_outcome import BatchRecordOutcome class FeatureStoragePort(Port, Protocol): @@ -29,3 +30,16 @@ async def read_hook_features( async def hook_features_exist(self, hook_output_dir: str, feature_name: str) -> bool: """Check 
whether features.json exists in a hook's output directory.""" ... + + @abstractmethod + async def read_batch_outcomes( + self, output_dir: str, hook_name: str + ) -> dict[str, BatchRecordOutcome]: + """Read JSONL batch outputs (features/rejections/errors) for a hook. + + Parses features.jsonl, rejections.jsonl, and errors.jsonl from the + hook's output directory. Each record appears in exactly one file. + + Returns a dict keyed by record ID. + """ + ... diff --git a/server/osa/domain/source/port/__init__.py b/server/osa/domain/ingest/__init__.py similarity index 100% rename from server/osa/domain/source/port/__init__.py rename to server/osa/domain/ingest/__init__.py diff --git a/server/tests/unit/domain/source/__init__.py b/server/osa/domain/ingest/command/__init__.py similarity index 100% rename from server/tests/unit/domain/source/__init__.py rename to server/osa/domain/ingest/command/__init__.py diff --git a/server/osa/domain/ingest/command/start_ingest.py b/server/osa/domain/ingest/command/start_ingest.py new file mode 100644 index 0000000..bef2fca --- /dev/null +++ b/server/osa/domain/ingest/command/start_ingest.py @@ -0,0 +1,43 @@ +"""StartIngest command — initiates a bulk ingestion run for a convention.""" + +from osa.domain.auth.model.role import Role +from osa.domain.shared.authorization.gate import at_least +from osa.domain.shared.command import Command, CommandHandler, Result + + +class StartIngest(Command): + """Start an ingest run for a convention.""" + + convention_srn: str + batch_size: int = 1000 + + +class IngestRunCreated(Result): + """Result of starting an ingest run.""" + + srn: str + convention_srn: str + status: str + started_at: str + + +class StartIngestHandler(CommandHandler[StartIngest, IngestRunCreated]): + """Thin command handler — delegates to IngestService.""" + + __auth__ = at_least(Role.ADMIN) + + from osa.domain.ingest.service.ingest import IngestService + + service: IngestService + + async def run(self, cmd: StartIngest) -> 
IngestRunCreated: + ingest_run = await self.service.start_ingest( + convention_srn=cmd.convention_srn, + batch_size=cmd.batch_size, + ) + return IngestRunCreated( + srn=ingest_run.srn, + convention_srn=ingest_run.convention_srn, + status=ingest_run.status, + started_at=ingest_run.started_at.isoformat(), + ) diff --git a/server/osa/domain/ingest/event/__init__.py b/server/osa/domain/ingest/event/__init__.py new file mode 100644 index 0000000..c8b3101 --- /dev/null +++ b/server/osa/domain/ingest/event/__init__.py @@ -0,0 +1,17 @@ +"""Ingest domain events.""" + +from osa.domain.ingest.event.events import ( + HookBatchCompleted, + IngestBatchPublished, + IngestCompleted, + IngestStarted, + IngesterBatchReady, +) + +__all__ = [ + "IngestStarted", + "IngesterBatchReady", + "HookBatchCompleted", + "IngestBatchPublished", + "IngestCompleted", +] diff --git a/server/osa/domain/ingest/event/events.py b/server/osa/domain/ingest/event/events.py new file mode 100644 index 0000000..f47f3fb --- /dev/null +++ b/server/osa/domain/ingest/event/events.py @@ -0,0 +1,59 @@ +"""Ingest domain events — payloads carry path references, not inline data (AD-1).""" + +from osa.domain.shared.event import Event, EventId + + +class IngestStarted(Event): + """Emitted when an ingest run is created. Triggers first ingester pull.""" + + id: EventId + ingest_run_srn: str + convention_srn: str + batch_size: int + + +class IngesterBatchReady(Event): + """Emitted when an ingester container produces a batch of records. + + Batch data is on disk at the path derived from {ingest_run_srn, batch_index}. + """ + + id: EventId + ingest_run_srn: str + batch_index: int + has_more: bool + + +class HookBatchCompleted(Event): + """Emitted when hook processing completes for a batch. + + Outcomes (features/rejections/errors) are on disk at the batch output path. 
+ """ + + id: EventId + ingest_run_srn: str + batch_index: int + + +class IngestBatchPublished(Event): + """Emitted when records from a batch are bulk-published. + + Triggers InsertBatchFeatures for feature insertion. + Batch-level event — NOT per-record (AD-3). + """ + + id: EventId + ingest_run_srn: str + convention_srn: str + batch_index: int + published_srns: list[str] + published_count: int + expected_features: list[str] + + +class IngestCompleted(Event): + """Emitted when all batches are processed and the ingest run is complete.""" + + id: EventId + ingest_run_srn: str + total_published: int diff --git a/server/osa/domain/ingest/handler/__init__.py b/server/osa/domain/ingest/handler/__init__.py new file mode 100644 index 0000000..b294305 --- /dev/null +++ b/server/osa/domain/ingest/handler/__init__.py @@ -0,0 +1,7 @@ +"""Ingest domain event handlers.""" + +from osa.domain.ingest.handler.publish_batch import PublishBatch +from osa.domain.ingest.handler.run_hooks import RunHooks +from osa.domain.ingest.handler.run_ingester import RunIngester + +__all__ = ["RunIngester", "RunHooks", "PublishBatch"] diff --git a/server/osa/domain/ingest/handler/publish_batch.py b/server/osa/domain/ingest/handler/publish_batch.py new file mode 100644 index 0000000..3d6637c --- /dev/null +++ b/server/osa/domain/ingest/handler/publish_batch.py @@ -0,0 +1,187 @@ +"""PublishBatch — reads hook outputs, bulk-publishes passing records.""" + +import json +import logging +from datetime import UTC, datetime +from pathlib import Path +from uuid import uuid4 + +from osa.domain.deposition.service.convention import ConventionService +from osa.domain.ingest.event.events import ( + HookBatchCompleted, + IngestBatchPublished, + IngestCompleted, +) +from osa.domain.ingest.model.ingest_run import IngestStatus +from osa.domain.ingest.port.repository import IngestRunRepository +from osa.domain.record.model.draft import RecordDraft +from osa.domain.record.service import RecordService +from 
class PublishBatch(EventHandler[HookBatchCompleted]):
    """Publish the records of one ingest batch that passed the hooks.

    Triggered by ``HookBatchCompleted``. The handler reads the batch's source
    records from the batch directory, keeps the ones reported as passed by the
    hook output, bulk-publishes them, emits ``IngestBatchPublished`` when at
    least one record was inserted, increments the run's completion counters
    and, once the run's completion condition holds, marks the run completed
    and emits ``IngestCompleted``.
    """

    ingest_repo: IngestRunRepository
    convention_service: ConventionService
    record_service: RecordService
    feature_storage: FeatureStoragePort
    outbox: Outbox
    paths: OSAPaths

    async def handle(self, event: HookBatchCompleted) -> None:
        """Publish one hook-processed batch and update the ingest run.

        Args:
            event: Identifies the ingest run and the batch index to publish.

        Raises:
            NotFoundError: If the ingest run referenced by the event is missing.
        """
        ingest_run = await self.ingest_repo.get(event.ingest_run_srn)
        if ingest_run is None:
            raise NotFoundError(f"Ingest run not found: {event.ingest_run_srn}")

        # Parse once; the same SRN is needed for the lookup and for every draft.
        convention_srn = ConventionSRN.parse(ingest_run.convention_srn)
        convention = await self.convention_service.get_convention(convention_srn)

        # Read source records from batch dir
        ingest_dir = self.paths.data_dir / "ingests" / _safe_srn(event.ingest_run_srn)
        batch_dir = ingest_dir / "batches" / str(event.batch_index)
        source_records = _read_source_records(batch_dir / "source" / "records.jsonl")

        # Every hook of the convention is expected to contribute features.
        expected_features = [h.name for h in convention.hooks]

        # Determine which records passed all hooks
        passed_records = _get_passed_records(
            source_records=source_records,
            batch_dir=batch_dir,
            hooks=expected_features,
            feature_storage=self.feature_storage,
        )

        # Number of records actually inserted by this batch. This stays 0 when
        # nothing passed, and may be smaller than len(passed_records) because
        # bulk_publish skips records that already exist.
        published_count = 0

        if not passed_records:
            logger.info("No passing records in batch %d", event.batch_index)
        else:
            # Construct RecordDrafts
            drafts: list[RecordDraft] = []
            for record in passed_records:
                source_id = record.get("source_id", record.get("id", ""))
                drafts.append(
                    RecordDraft(
                        source=IngestSource(
                            id=f"{ingest_run.convention_srn}:{source_id}",
                            ingest_run_srn=ingest_run.srn,
                            upstream_source=source_id,
                        ),
                        metadata=record.get("metadata", {}),
                        convention_srn=convention_srn,
                        expected_features=expected_features,
                    )
                )

            # Bulk publish
            published = await self.record_service.bulk_publish(drafts)
            published_count = len(published)

            logger.info(
                "Published %d records from batch %d of %s",
                published_count,
                event.batch_index,
                event.ingest_run_srn,
            )

            # Emit IngestBatchPublished for feature insertion
            if published_count > 0:
                await self.outbox.append(
                    IngestBatchPublished(
                        id=EventId(uuid4()),
                        ingest_run_srn=event.ingest_run_srn,
                        convention_srn=ingest_run.convention_srn,
                        batch_index=event.batch_index,
                        published_srns=[str(r.srn) for r in published],
                        published_count=published_count,
                        expected_features=expected_features,
                    )
                )

        # Update counters atomically with the number of records that were
        # really inserted, so the run total matches what was published.
        updated = await self.ingest_repo.increment_completed(
            event.ingest_run_srn,
            published_count=published_count,
        )

        # Check completion condition
        if updated.is_complete and updated.status == IngestStatus.RUNNING:
            updated.check_completion(datetime.now(UTC))
            await self.ingest_repo.save(updated)

            await self.outbox.append(
                IngestCompleted(
                    id=EventId(uuid4()),
                    ingest_run_srn=event.ingest_run_srn,
                    total_published=updated.published_count,
                )
            )
            logger.info(
                "Ingest completed: %s (total published: %d)",
                event.ingest_run_srn,
                updated.published_count,
            )
def _read_source_records(records_file: Path) -> list[dict]:
    """Read source records from a JSONL file.

    Blank lines are ignored. Lines that are not valid JSON, or that decode to
    something other than a JSON object, are skipped with a warning. A missing
    file yields an empty list.
    """
    if not records_file.exists():
        return []

    records: list[dict] = []
    # Context manager so the handle is closed even if parsing raises.
    with records_file.open(encoding="utf-8") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                logger.warning("Skipping malformed source record line")
                continue
            # Callers use dict access on every record; a bare list/str/number
            # line would crash them later, so treat it as malformed too.
            if not isinstance(record, dict):
                logger.warning("Skipping malformed source record line")
                continue
            records.append(record)
    return records


def _get_passed_records(
    source_records: list[dict],
    batch_dir: Path,
    hooks: list[str],
    feature_storage: "FeatureStoragePort",
) -> list[dict]:
    """Determine which records passed all hooks by reading features.jsonl.

    Args:
        source_records: Records of the batch, each identified by ``source_id``
            (falling back to ``id``).
        batch_dir: Directory of the batch; hook output is read from
            ``hooks/<hook>/output/features.jsonl`` beneath it.
        hooks: Hook names in execution order. With no hooks every source
            record passes.
        feature_storage: Currently unused; kept so the call signature stays
            stable.

    Returns:
        The source records whose identifier appears in the last hook's
        ``features.jsonl``, in their original order. An empty list when that
        file does not exist.
    """
    if not hooks:
        return source_records

    # For simplicity, read the last hook's features.jsonl to get passed IDs
    # (hooks run sequentially, so the last hook's output has the final set)
    features_file = batch_dir / "hooks" / hooks[-1] / "output" / "features.jsonl"
    if not features_file.exists():
        return []

    passed_ids: set[str] = set()
    with features_file.open(encoding="utf-8") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                logger.warning("Skipping malformed features.jsonl line")
                continue
            # Only JSON objects can carry an "id"; anything else is malformed.
            if not isinstance(data, dict):
                logger.warning("Skipping malformed features.jsonl line")
                continue
            record_id = data.get("id")
            if record_id:
                passed_ids.add(record_id)

    return [r for r in source_records if r.get("source_id", r.get("id", "")) in passed_ids]


def _safe_srn(srn: str) -> str:
    """Convert an SRN to a filesystem-safe string (``:`` and ``@`` become ``_``)."""
    return srn.replace(":", "_").replace("@", "_")
class RunHooks(EventHandler[IngesterBatchReady]):
    """Runs hook containers on an ingester batch and emits HookBatchCompleted.

    The batch's ``records.jsonl`` is read from the batch source directory,
    every hook of the run's convention is executed sequentially against the
    same inputs, and a ``HookBatchCompleted`` event is appended afterwards.
    """

    __claim_timeout__ = 3600.0  # Hook runs can be long

    ingest_repo: IngestRunRepository
    convention_service: ConventionService
    hook_runner: HookRunner
    outbox: Outbox
    paths: OSAPaths

    async def handle(self, event: IngesterBatchReady) -> None:
        """Run all convention hooks for one batch.

        Args:
            event: Identifies the ingest run and the batch index to process.

        Raises:
            NotFoundError: If the ingest run referenced by the event is missing.
        """
        ingest_run = await self.ingest_repo.get(event.ingest_run_srn)
        if ingest_run is None:
            raise NotFoundError(f"Ingest run not found: {event.ingest_run_srn}")

        convention = await self.convention_service.get_convention(
            ConventionSRN.parse(ingest_run.convention_srn)
        )

        # Read records from batch source dir
        ingest_dir = self.paths.data_dir / "ingests" / _safe_srn(event.ingest_run_srn)
        batch_dir = ingest_dir / "batches" / str(event.batch_index)
        source_dir = batch_dir / "source"
        records_file = source_dir / "records.jsonl"

        records: list[dict] = []
        if records_file.exists():
            # Context manager so the handle is released before the
            # (potentially long) hook runs start.
            with records_file.open(encoding="utf-8") as handle:
                for raw_line in handle:
                    line = raw_line.strip()
                    if not line:
                        continue
                    try:
                        records.append(json.loads(line))
                    except json.JSONDecodeError:
                        logger.warning(
                            "Skipping malformed record line in batch %d", event.batch_index
                        )

        if not records:
            logger.warning("No records in batch %d for %s", event.batch_index, event.ingest_run_srn)

        # Build files_dirs from source files
        files_base = source_dir / "files"
        files_dirs: dict[str, Path] = {}
        if files_base.exists():
            for record in records:
                record_id = record.get("source_id", record.get("id", ""))
                record_files = files_base / str(record_id)
                if record_files.exists():
                    files_dirs[str(record_id)] = record_files

        # Run each hook sequentially. The inputs do not depend on the hook, so
        # they are built once instead of once per hook; nothing is built when
        # the convention has no hooks.
        if convention.hooks:
            inputs = HookInputs(
                records=[
                    HookRecord(
                        id=r.get("source_id", r.get("id", "")),
                        metadata=r.get("metadata", {}),
                    )
                    for r in records
                ],
                run_id=f"{_safe_srn(event.ingest_run_srn)}_batch{event.batch_index}",
                files_dirs=files_dirs,
                config=None,
            )
            for hook in convention.hooks:
                hook_output_dir = batch_dir / "hooks" / hook.name
                hook_output_dir.mkdir(parents=True, exist_ok=True)
                await self.hook_runner.run(hook, inputs, hook_output_dir)

        # Emit HookBatchCompleted
        await self.outbox.append(
            HookBatchCompleted(
                id=EventId(uuid4()),
                ingest_run_srn=event.ingest_run_srn,
                batch_index=event.batch_index,
            )
        )

        logger.info(
            "Hooks completed for batch %d of %s (%d records)",
            event.batch_index,
            event.ingest_run_srn,
            len(records),
        )


def _safe_srn(srn: str) -> str:
    """Convert SRN to filesystem-safe string."""
    return srn.replace(":", "_").replace("@", "_")
class RunIngester(EventHandler[IngestStarted]):
    """Runs ingester container and emits IngesterBatchReady per batch.

    Each handled ``IngestStarted`` event pulls one batch. When the ingester
    reports more data, the handler appends a further ``IngestStarted`` event
    to the outbox so the next batch is pulled by a later invocation.
    """

    __claim_timeout__ = 3600.0  # Ingester runs can be long

    ingest_repo: IngestRunRepository
    convention_service: ConventionService
    ingester_runner: IngesterRunner
    outbox: Outbox
    paths: OSAPaths

    async def handle(self, event: IngestStarted) -> None:
        """Pull one batch from the ingester and hand it to the hook stage.

        Args:
            event: Identifies the ingest run and its convention.

        Raises:
            NotFoundError: If the ingest run does not exist or the convention
                has no ingester.
        """
        ingest_run = await self.ingest_repo.get(event.ingest_run_srn)
        if ingest_run is None:
            raise NotFoundError(f"Ingest run not found: {event.ingest_run_srn}")

        # Transition to RUNNING on first ingester pull
        if ingest_run.status == IngestStatus.PENDING:
            ingest_run.mark_running()
            await self.ingest_repo.save(ingest_run)

        convention = await self.convention_service.get_convention(
            ConventionSRN.parse(event.convention_srn)
        )
        if convention.ingester is None:
            raise NotFoundError(f"No ingester for convention {event.convention_srn}")

        # Determine batch index from current batches_sourced
        batch_index = ingest_run.batches_sourced

        # Prepare scratch directory
        ingest_dir = self.paths.data_dir / "ingests" / self._safe_srn(event.ingest_run_srn)
        batch_dir = ingest_dir / "batches" / str(batch_index) / "source"
        batch_dir.mkdir(parents=True, exist_ok=True)

        # Load session state for continuation
        session_file = ingest_dir / "session.json"
        session = None
        if session_file.exists():
            session = json.loads(session_file.read_text())

        # Run ingester container
        inputs = IngesterInputs(
            convention_srn=convention.srn,
            config=convention.ingester.config,
            limit=ingest_run.batch_size,
            session=session,
        )
        files_dir = batch_dir / "files"
        files_dir.mkdir(parents=True, exist_ok=True)

        output = await self.ingester_runner.run(
            source=convention.ingester,
            inputs=inputs,
            files_dir=files_dir,
            work_dir=batch_dir,
        )

        # Write records.jsonl to batch source dir
        records_file = batch_dir / "records.jsonl"
        with records_file.open("w") as f:
            for record in output.records:
                f.write(json.dumps(record) + "\n")

        # Save session for continuation. Use the same "is not None" test as
        # has_more below: an empty-but-present session must still be written,
        # otherwise the continuation would reload the previous batch's state.
        if output.session is not None:
            session_file.write_text(json.dumps(output.session))

        has_more = output.session is not None and len(output.records) > 0

        # Update counters atomically
        await self.ingest_repo.increment_batches_sourced(
            event.ingest_run_srn,
            set_source_finished=not has_more,
        )

        # Emit batch ready event
        await self.outbox.append(
            IngesterBatchReady(
                id=EventId(uuid4()),
                ingest_run_srn=event.ingest_run_srn,
                batch_index=batch_index,
                has_more=has_more,
            )
        )

        logger.info(
            "Ingester batch %d ready for %s (%d records, has_more=%s)",
            batch_index,
            event.ingest_run_srn,
            len(output.records),
            has_more,
        )

        # Emit continuation event for next batch
        if has_more:
            await self.outbox.append(
                IngestStarted(
                    id=EventId(uuid4()),
                    ingest_run_srn=event.ingest_run_srn,
                    convention_srn=event.convention_srn,
                    batch_size=ingest_run.batch_size,
                )
            )

    @staticmethod
    def _safe_srn(srn: str) -> str:
        """Convert SRN to filesystem-safe string."""
        return srn.replace(":", "_").replace("@", "_")
class IngestRun(Aggregate):
    """Run-level summary of one bulk ingestion execution.

    Only status and counters are kept here; per-record detail lives in the
    batch output directories on disk, which serve as the audit trail. The
    repository applies counter updates with atomic SQL increments.
    """

    srn: str
    convention_srn: str
    status: IngestStatus = IngestStatus.PENDING
    source_finished: bool = False
    batches_sourced: int = 0
    batches_completed: int = 0
    published_count: int = 0
    batch_size: int = 1000
    started_at: datetime
    completed_at: datetime | None = None

    def transition_to(self, new_status: IngestStatus) -> None:
        """Move to ``new_status``; raise InvalidStateError if not allowed."""
        allowed = _VALID_TRANSITIONS[self.status]
        if new_status in allowed:
            self.status = new_status
            return
        raise InvalidStateError(f"Cannot transition from {self.status} to {new_status}")

    def mark_running(self) -> None:
        """Transition the run to RUNNING."""
        self.transition_to(IngestStatus.RUNNING)

    def mark_failed(self, completed_at: datetime) -> None:
        """Transition the run to FAILED and stamp its end time."""
        self.transition_to(IngestStatus.FAILED)
        self.completed_at = completed_at

    def mark_source_finished(self) -> None:
        """Flag that the source has no further batches."""
        self.source_finished = True

    def increment_batches_sourced(self) -> None:
        """Count one more batch pulled from the source (in-memory only)."""
        self.batches_sourced = self.batches_sourced + 1

    def record_batch_completed(self, published_count: int) -> None:
        """Count one finished batch and add its published records.

        This only touches the in-memory aggregate; in production the counters
        are advanced through atomic SQL increments.
        """
        self.batches_completed = self.batches_completed + 1
        self.published_count = self.published_count + published_count

    @property
    def is_complete(self) -> bool:
        """True once the source is finished and every sourced batch completed."""
        if not self.source_finished:
            return False
        return self.batches_sourced == self.batches_completed

    def check_completion(self, completed_at: datetime) -> bool:
        """Complete the run if its completion condition is met.

        Returns True when this call moved the run to COMPLETED, else False.
        """
        if not self.is_complete or self.status != IngestStatus.RUNNING:
            return False
        self.transition_to(IngestStatus.COMPLETED)
        self.completed_at = completed_at
        return True
class IngestRunRepository(Port, Protocol):
    """Persistence interface for IngestRun aggregates.

    Counter updates (batches_completed, published_count) use atomic SQL
    increments in the concrete implementation to avoid lost updates under
    concurrent PublishBatch workers. The two ``increment_*`` methods return
    the run as stored after the update, so callers can evaluate the
    completion condition on authoritative values.
    """

    @abstractmethod
    async def save(self, ingest_run: IngestRun) -> None:
        """Persist an ingest run (insert or update)."""
        ...

    @abstractmethod
    async def get(self, srn: str) -> IngestRun | None:
        """Get an ingest run by SRN, or None when no such run exists."""
        ...

    @abstractmethod
    async def get_running_for_convention(self, convention_srn: str) -> IngestRun | None:
        """Get a running ingest run for a convention, if any.

        NOTE(review): whether PENDING runs count as "running" is decided by
        the concrete implementation — confirm there.
        """
        ...

    @abstractmethod
    async def increment_batches_sourced(
        self, srn: str, *, set_source_finished: bool = False
    ) -> IngestRun:
        """Atomically increment batches_sourced and optionally set source_finished.

        Args:
            srn: SRN of the ingest run to update.
            set_source_finished: When True, also mark the source as finished.

        Returns the updated IngestRun with DB-authoritative counter values.
        """
        ...

    @abstractmethod
    async def increment_completed(self, srn: str, published_count: int) -> IngestRun:
        """Atomically increment batches_completed and published_count.

        Args:
            srn: SRN of the ingest run to update.
            published_count: Amount to add to the run's published_count.

        Returns the updated IngestRun with DB-authoritative counter values
        for completion condition checking.
        """
        ...
class IngestService(Service):
    """Creates ingest runs and kicks off their lifecycle."""

    ingest_repo: IngestRunRepository
    convention_service: ConventionService
    outbox: Outbox
    node_domain: Domain

    async def start_ingest(
        self,
        convention_srn: str,
        batch_size: int = 1000,
    ) -> IngestRun:
        """Start a new ingest run for the given convention.

        The convention must exist and have an ingester configured, and no
        other ingest may already be running for it. On success the run is
        saved in PENDING state and an ``IngestStarted`` event is appended to
        the outbox.

        Raises:
            NotFoundError: If the convention has no ingester configured.
            ConflictError: If an ingest is already running for the convention.
        """
        convention = await self.convention_service.get_convention(
            ConventionSRN.parse(convention_srn)
        )
        if convention.ingester is None:
            raise NotFoundError(
                f"No ingester configured for convention {convention_srn}",
                code="no_ingester_configured",
            )

        running = await self.ingest_repo.get_running_for_convention(convention_srn)
        if running is not None:
            raise ConflictError(
                f"Ingest already running for convention {convention_srn}",
                code="ingest_already_running",
            )

        run_srn = f"urn:osa:{self.node_domain.root}:ing:{uuid4()}"
        ingest_run = IngestRun(
            srn=run_srn,
            convention_srn=convention_srn,
            status=IngestStatus.PENDING,
            batch_size=batch_size,
            started_at=datetime.now(UTC),
        )
        await self.ingest_repo.save(ingest_run)

        # The event is what triggers the first ingester pull.
        started_event = IngestStarted(
            id=EventId(uuid4()),
            ingest_run_srn=run_srn,
            convention_srn=convention_srn,
            batch_size=batch_size,
        )
        await self.outbox.append(started_event)

        logger.info("Ingest started: %s for convention %s", run_srn, convention_srn)
        return ingest_run
diff --git a/server/osa/domain/record/service/record.py b/server/osa/domain/record/service/record.py index d6d3cc4..844045b 100644 --- a/server/osa/domain/record/service/record.py +++ b/server/osa/domain/record/service/record.py @@ -49,6 +49,37 @@ async def get(self, srn: RecordSRN) -> Record: raise NotFoundError(f"Record not found: {srn}") return record + async def bulk_publish(self, drafts: list[RecordDraft]) -> list[Record]: + """Bulk-publish records from an ingest batch. + + Uses save_many() for multi-row INSERT with ON CONFLICT DO NOTHING. + Does NOT emit per-record RecordPublished events — the caller emits + a single IngestBatchPublished event instead (AD-3). + """ + if not drafts: + return [] + + records: list[Record] = [] + for draft in drafts: + record_srn = RecordSRN( + domain=self.node_domain, + id=LocalId(str(uuid4())), + version=RecordVersion(1), + ) + records.append( + Record( + srn=record_srn, + source=draft.source, + convention_srn=draft.convention_srn, + metadata=draft.metadata, + published_at=datetime.now(UTC), + ) + ) + + published = await self.record_repo.save_many(records) + logger.info("Bulk-published %d records (of %d drafts)", len(published), len(drafts)) + return published + async def publish_record(self, draft: RecordDraft) -> Record: """Create and persist a Record from a draft.""" logger.info(f"Creating record from {draft.source.type} source: {draft.source.id}") diff --git a/server/osa/domain/shared/model/source.py b/server/osa/domain/shared/model/source.py index e61fd8e..1cd0d6c 100644 --- a/server/osa/domain/shared/model/source.py +++ b/server/osa/domain/shared/model/source.py @@ -1,4 +1,4 @@ -"""Shared source domain models used across deposition and source domains.""" +"""Shared source domain models used across deposition and ingest domains.""" from typing import Annotated, Any, Literal, Union @@ -7,23 +7,23 @@ from osa.domain.shared.model.value import ValueObject -class SourceLimits(ValueObject): - """Resource limits for source 
class IngesterDefinition(ValueObject):
    """Complete specification for an ingester: image reference + config + limits.

    Besides the image and its settings, the definition optionally carries a
    cron schedule for periodic runs and a configuration for the first run.
    """

    # Container image reference of the ingester.
    image: str
    # Digest belonging to the image — presumably pins the exact image build;
    # confirm against the runner implementation.
    digest: str
    # Optional free-form configuration mapping for the ingester.
    config: dict[str, Any] | None = None
    # Resource limits for the container; a default IngesterLimits when omitted.
    limits: IngesterLimits = Field(default_factory=IngesterLimits)
    # Cron schedule for periodic ingester runs; None means no schedule is set.
    schedule: IngesterScheduleConfig | None = None
    # Configuration for the first ingester run; None means none is configured.
    initial_run: InitialRunConfig | None = None
class IngesterRunner(Protocol):
    """Protocol for executing ingester containers."""

    async def run(
        self,
        source: IngesterDefinition,
        inputs: IngesterInputs,
        files_dir: Path,
        work_dir: Path,
    ) -> IngesterOutput:
        """Execute one ingester container run.

        Args:
            source: Definition of the ingester to execute.
            inputs: Inputs for this run.
            files_dir: Directory for data files written by the ingester
                (presumably the same path reported back as
                ``IngesterOutput.files_dir`` — confirm in the implementation).
            work_dir: Working directory for the run.

        Returns:
            The records, optional continuation session and files directory
            produced by the run.
        """
        ...
diff --git a/server/osa/domain/source/__init__.py b/server/osa/domain/source/__init__.py deleted file mode 100644 index a001339..0000000 --- a/server/osa/domain/source/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Source domain - configuration and orchestration for data sources.""" diff --git a/server/osa/domain/source/event/__init__.py b/server/osa/domain/source/event/__init__.py deleted file mode 100644 index 9d82d2f..0000000 --- a/server/osa/domain/source/event/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Source domain events.""" - -from osa.domain.source.event.source_record_ready import SourceRecordReady -from osa.domain.source.event.source_requested import SourceRequested -from osa.domain.source.event.source_run_completed import SourceRunCompleted - -__all__ = ["SourceRecordReady", "SourceRequested", "SourceRunCompleted"] diff --git a/server/osa/domain/source/event/source_record_ready.py b/server/osa/domain/source/event/source_record_ready.py deleted file mode 100644 index 5d19b0c..0000000 --- a/server/osa/domain/source/event/source_record_ready.py +++ /dev/null @@ -1,21 +0,0 @@ -"""SourceRecordReady event — emitted per record produced by a source container.""" - -from typing import Any - -from osa.domain.shared.event import Event, EventId -from osa.domain.shared.model.srn import ConventionSRN - - -class SourceRecordReady(Event): - """Emitted for each record produced by a source run. - - Replaces direct DepositionService calls in SourceService. - Consumed by CreateDepositionFromSource in the deposition domain. 
- """ - - id: EventId - convention_srn: ConventionSRN - metadata: dict[str, Any] - file_paths: list[str] - source_id: str - staging_dir: str diff --git a/server/osa/domain/source/event/source_requested.py b/server/osa/domain/source/event/source_requested.py deleted file mode 100644 index dc213dd..0000000 --- a/server/osa/domain/source/event/source_requested.py +++ /dev/null @@ -1,28 +0,0 @@ -"""SourceRequested event - triggers pulling from a data source.""" - -from datetime import datetime -from typing import Any - -from osa.domain.shared.event import Event, EventId -from osa.domain.shared.model.srn import ConventionSRN - - -class SourceRequested(Event): - """Emitted when pulling should start for a source. - - The convention SRN identifies which convention (and its SourceDefinition) - to run. The server loads the convention to get the source image/config. - - For chunked processing: - - `offset`: Starting position for this chunk (0 for first chunk) - - `chunk_size`: Number of records to process per chunk - - `session`: Opaque pagination state for efficient continuation - """ - - id: EventId - convention_srn: ConventionSRN - since: datetime | None = None - limit: int | None = None - offset: int = 0 - chunk_size: int = 1000 - session: dict[str, Any] | None = None diff --git a/server/osa/domain/source/event/source_run_completed.py b/server/osa/domain/source/event/source_run_completed.py deleted file mode 100644 index 0b93d93..0000000 --- a/server/osa/domain/source/event/source_run_completed.py +++ /dev/null @@ -1,17 +0,0 @@ -"""SourceRunCompleted event - emitted after a source run finishes.""" - -from datetime import datetime - -from osa.domain.shared.event import Event, EventId -from osa.domain.shared.model.srn import ConventionSRN - - -class SourceRunCompleted(Event): - """Emitted after a source run completes.""" - - id: EventId - convention_srn: ConventionSRN - started_at: datetime - completed_at: datetime - record_count: int - is_final_chunk: bool = True diff 
--git a/server/osa/domain/source/handler/__init__.py b/server/osa/domain/source/handler/__init__.py deleted file mode 100644 index 3cc4466..0000000 --- a/server/osa/domain/source/handler/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Source domain event handlers.""" - -from osa.domain.source.handler.pull_from_source import PullFromSource -from osa.domain.source.handler.trigger_initial_source_run import TriggerInitialSourceRun - -__all__ = ["PullFromSource", "TriggerInitialSourceRun"] diff --git a/server/osa/domain/source/handler/pull_from_source.py b/server/osa/domain/source/handler/pull_from_source.py deleted file mode 100644 index f611c57..0000000 --- a/server/osa/domain/source/handler/pull_from_source.py +++ /dev/null @@ -1,34 +0,0 @@ -"""PullFromSource - handles SourceRequested events.""" - -from osa.domain.deposition.service.convention import ConventionService -from osa.domain.shared.error import NotFoundError -from osa.domain.shared.event import EventHandler -from osa.domain.source.event.source_requested import SourceRequested -from osa.domain.source.service import SourceService - - -class PullFromSource(EventHandler[SourceRequested]): - """Runs a source container and emits per-record events. - - Looks up the convention to get the SourceDefinition, then - delegates to SourceService for container execution. 
- """ - - service: SourceService - convention_service: ConventionService - - async def handle(self, event: SourceRequested) -> None: - """Look up convention, extract source definition, and run source.""" - convention = await self.convention_service.get_convention(event.convention_srn) - if convention.source is None: - raise NotFoundError(f"No source defined for convention {event.convention_srn}") - - await self.service.run_source( - convention_srn=event.convention_srn, - source=convention.source, - since=event.since, - limit=event.limit, - offset=event.offset, - chunk_size=event.chunk_size, - session=event.session, - ) diff --git a/server/osa/domain/source/handler/trigger_initial_source_run.py b/server/osa/domain/source/handler/trigger_initial_source_run.py deleted file mode 100644 index c3a84bd..0000000 --- a/server/osa/domain/source/handler/trigger_initial_source_run.py +++ /dev/null @@ -1,39 +0,0 @@ -"""TriggerInitialSourceRun - triggers source pull when feature tables are ready.""" - -import logging -from uuid import uuid4 - -from osa.domain.deposition.service.convention import ConventionService -from osa.domain.feature.event.convention_ready import ConventionReady -from osa.domain.shared.event import EventHandler, EventId -from osa.domain.shared.outbox import Outbox -from osa.domain.source.event.source_requested import SourceRequested - -logger = logging.getLogger(__name__) - - -class TriggerInitialSourceRun(EventHandler[ConventionReady]): - """Emits SourceRequested when a convention with initial_run is ready. 
- - Part of the convention initialization chain: - ConventionRegistered → CreateFeatureTables → ConventionReady → TriggerInitialSourceRun - """ - - convention_service: ConventionService - outbox: Outbox - - async def handle(self, event: ConventionReady) -> None: - conv = await self.convention_service.get_convention(event.convention_srn) - - if conv.source is None or conv.source.initial_run is None: - return - - logger.info("Source deploy trigger: convention=%s", conv.srn) - - await self.outbox.append( - SourceRequested( - id=EventId(uuid4()), - convention_srn=conv.srn, - limit=conv.source.initial_run.limit, - ) - ) diff --git a/server/osa/domain/source/handler/trigger_source_on_deploy.py b/server/osa/domain/source/handler/trigger_source_on_deploy.py deleted file mode 100644 index 534a287..0000000 --- a/server/osa/domain/source/handler/trigger_source_on_deploy.py +++ /dev/null @@ -1,35 +0,0 @@ -"""TriggerSourceOnDeploy - triggers source pull when a convention with a source is deployed.""" - -import logging -from uuid import uuid4 - -from osa.domain.deposition.event.convention_registered import ConventionRegistered -from osa.domain.deposition.service.convention import ConventionService -from osa.domain.shared.event import EventHandler, EventId -from osa.domain.shared.outbox import Outbox -from osa.domain.source.event.source_requested import SourceRequested - -logger = logging.getLogger(__name__) - - -class TriggerSourceOnDeploy(EventHandler[ConventionRegistered]): - """Emits SourceRequested when a convention with initial_run is deployed.""" - - convention_service: ConventionService - outbox: Outbox - - async def handle(self, event: ConventionRegistered) -> None: - conv = await self.convention_service.get_convention(event.convention_srn) - - if conv.source is None or conv.source.initial_run is None: - return - - logger.info("Source deploy trigger: convention=%s", conv.srn) - - await self.outbox.append( - SourceRequested( - id=EventId(uuid4()), - 
convention_srn=conv.srn, - limit=conv.source.initial_run.limit, - ) - ) diff --git a/server/osa/domain/source/model/__init__.py b/server/osa/domain/source/model/__init__.py deleted file mode 100644 index 7dc2082..0000000 --- a/server/osa/domain/source/model/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Source domain models.""" diff --git a/server/osa/domain/source/port/storage.py b/server/osa/domain/source/port/storage.py deleted file mode 100644 index 821119c..0000000 --- a/server/osa/domain/source/port/storage.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Storage port scoped to the source domain.""" - -from abc import abstractmethod -from pathlib import Path -from typing import Protocol - -from osa.domain.shared.model.srn import ConventionSRN, DepositionSRN -from osa.domain.shared.port import Port - - -class SourceStoragePort(Port, Protocol): - """File storage operations used by the source domain.""" - - @abstractmethod - def get_source_staging_dir(self, convention_srn: ConventionSRN, run_id: str) -> Path: - """Staging dir for source-ingested files, isolated per run.""" - ... - - @abstractmethod - def get_source_output_dir(self, convention_srn: ConventionSRN, run_id: str) -> Path: - """Output dir for a source run (records.jsonl, session.json).""" - ... - - @abstractmethod - async def move_source_files_to_deposition( - self, - staging_dir: Path, - source_id: str, - deposition_srn: DepositionSRN, - ) -> None: - """Move source staging files into the deposition's canonical file location.""" - ... 
diff --git a/server/osa/domain/source/schedule/__init__.py b/server/osa/domain/source/schedule/__init__.py deleted file mode 100644 index d566437..0000000 --- a/server/osa/domain/source/schedule/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Source scheduled tasks.""" - -from osa.domain.source.schedule.source_schedule import SourceSchedule - -__all__ = ["SourceSchedule"] diff --git a/server/osa/domain/source/schedule/source_schedule.py b/server/osa/domain/source/schedule/source_schedule.py deleted file mode 100644 index 7c9c8b2..0000000 --- a/server/osa/domain/source/schedule/source_schedule.py +++ /dev/null @@ -1,60 +0,0 @@ -"""SourceSchedule - scheduled task that emits SourceRequested events.""" - -import logging -from dataclasses import dataclass -from typing import Any -from uuid import uuid4 - -from osa.domain.shared.event import EventId, Schedule -from osa.domain.shared.model.srn import ConventionSRN -from osa.domain.shared.outbox import Outbox -from osa.domain.source.event.source_requested import SourceRequested -from osa.domain.source.event.source_run_completed import SourceRunCompleted - -logger = logging.getLogger(__name__) - - -@dataclass -class SourceSchedule(Schedule): - """Scheduled task that emits SourceRequested events. - - Looks up the last completed source run to determine the `since` timestamp, - then emits a SourceRequested event to trigger a new pull. - """ - - outbox: Outbox - - async def run(self, **params: Any) -> None: - """Emit a SourceRequested event for the given convention. 
- - Params: - convention: Convention SRN string - limit: Optional limit on records to fetch - """ - convention_srn = ConventionSRN.parse(params["convention"]) - limit: int | None = params.get("limit") - - # Look up last completed run for this convention - last_run = await self.outbox.find_latest_where( - SourceRunCompleted, convention_srn=str(convention_srn) - ) - - since = None - if last_run is not None: - since = last_run.completed_at - - logger.info( - "Scheduled source run: convention=%s (since=%s, limit=%s)", - convention_srn, - since, - limit, - ) - - await self.outbox.append( - SourceRequested( - id=EventId(uuid4()), - convention_srn=convention_srn, - since=since, - limit=limit, - ) - ) diff --git a/server/osa/domain/source/service/__init__.py b/server/osa/domain/source/service/__init__.py deleted file mode 100644 index ecefd99..0000000 --- a/server/osa/domain/source/service/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Source service module.""" - -from osa.domain.source.service.source import SourceResult, SourceService - -__all__ = ["SourceService", "SourceResult"] diff --git a/server/osa/domain/source/service/source.py b/server/osa/domain/source/service/source.py deleted file mode 100644 index 03d2040..0000000 --- a/server/osa/domain/source/service/source.py +++ /dev/null @@ -1,154 +0,0 @@ -"""SourceService - orchestrates running OCI source containers.""" - -import logging -from dataclasses import dataclass -from datetime import UTC, datetime -from typing import Any -from uuid import uuid4 - -from osa.domain.shared.event import EventId -from osa.domain.shared.model.source import SourceDefinition -from osa.domain.shared.model.srn import ConventionSRN -from osa.domain.shared.outbox import Outbox -from osa.domain.shared.service import Service -from osa.domain.source.event.source_record_ready import SourceRecordReady -from osa.domain.source.event.source_requested import SourceRequested -from osa.domain.source.event.source_run_completed import 
SourceRunCompleted -from osa.domain.source.port.source_runner import SourceInputs, SourceRunner -from osa.domain.source.port.storage import SourceStoragePort - -logger = logging.getLogger(__name__) - - -@dataclass -class SourceResult: - """Result of a source run.""" - - convention_srn: ConventionSRN - record_count: int - started_at: datetime - completed_at: datetime - - -class SourceService(Service): - """Orchestrates running source containers. - - For each record produced by the source container, emits a - SourceRecordReady event. The deposition domain handles creating - depositions from these events. - """ - - source_runner: SourceRunner - source_storage: SourceStoragePort - outbox: Outbox - - async def run_source( - self, - convention_srn: ConventionSRN, - source: SourceDefinition, - since: datetime | None = None, - limit: int | None = None, - offset: int = 0, - chunk_size: int = 1000, - session: dict[str, Any] | None = None, - ) -> SourceResult: - """Run a source container and emit events for each produced record.""" - started_at = datetime.now(UTC) - run_id = str(uuid4())[:12] - - logger.info( - "Starting source run for %s (run=%s, since=%s, limit=%s, offset=%s)", - convention_srn, - run_id, - since, - limit, - offset, - ) - - # Prepare dirs - staging_dir = self.source_storage.get_source_staging_dir(convention_srn, run_id) - work_dir = self.source_storage.get_source_output_dir(convention_srn, run_id) - - # Build inputs - inputs = SourceInputs( - convention_srn=convention_srn, - config=source.config, - since=since, - limit=limit, - offset=offset, - session=session, - ) - - # Run container - output = await self.source_runner.run( - source=source, - inputs=inputs, - files_dir=staging_dir, - work_dir=work_dir, - ) - - # Emit per-record events - count = 0 - for record_data in output.records: - source_id = record_data.get("source_id", "") - metadata = record_data.get("metadata", {}) - file_paths = record_data.get("file_paths", []) - - await self.outbox.append( - 
SourceRecordReady( - id=EventId(uuid4()), - convention_srn=convention_srn, - metadata=metadata, - file_paths=file_paths, - source_id=source_id, - staging_dir=str(staging_dir), - ) - ) - count += 1 - - if count % 100 == 0: - logger.info(" Emitted %d SourceRecordReady events so far...", count) - - completed_at = datetime.now(UTC) - is_final_chunk = output.session is None or count == 0 - - logger.info( - "Source run completed: %d records (run=%s, is_final=%s)", - count, - run_id, - is_final_chunk, - ) - - # Emit continuation if session exists - if not is_final_chunk: - next_offset = offset + count - logger.info("Emitting continuation event, next_offset=%d", next_offset) - await self.outbox.append( - SourceRequested( - id=EventId(uuid4()), - convention_srn=convention_srn, - since=since, - limit=limit, - offset=next_offset, - chunk_size=chunk_size, - session=output.session, - ) - ) - - await self.outbox.append( - SourceRunCompleted( - id=EventId(uuid4()), - convention_srn=convention_srn, - started_at=started_at, - completed_at=completed_at, - record_count=count, - is_final_chunk=is_final_chunk, - ) - ) - - return SourceResult( - convention_srn=convention_srn, - record_count=count, - started_at=started_at, - completed_at=completed_at, - ) diff --git a/server/osa/domain/validation/model/batch_outcome.py b/server/osa/domain/validation/model/batch_outcome.py new file mode 100644 index 0000000..4206a02 --- /dev/null +++ b/server/osa/domain/validation/model/batch_outcome.py @@ -0,0 +1,20 @@ +"""Per-record outcome from a batch hook run.""" + +from typing import Any + +from osa.domain.shared.model.value import ValueObject + + +class BatchRecordOutcome(ValueObject): + """Per-record outcome from a batch hook execution. + + Each record in a batch ends up in exactly one of three states: + passed (with features), rejected (with reason), or errored. 
+ """ + + record_id: str + status: str # "passed", "rejected", "errored" + features: list[dict[str, Any]] = [] + reason: str | None = None + error: str | None = None + retryable: bool = False diff --git a/server/osa/domain/validation/model/hook_input.py b/server/osa/domain/validation/model/hook_input.py new file mode 100644 index 0000000..8c07075 --- /dev/null +++ b/server/osa/domain/validation/model/hook_input.py @@ -0,0 +1,15 @@ +"""Value objects for hook input data.""" + +from typing import Any + +from osa.domain.shared.model.value import ValueObject + + +class HookRecord(ValueObject): + """A single record to be processed by a hook. + + Maps to one line in records.jsonl: {"id": "...", "metadata": {...}}. + """ + + id: str + metadata: dict[str, Any] diff --git a/server/osa/domain/validation/port/hook_runner.py b/server/osa/domain/validation/port/hook_runner.py index 703596b..d1695f9 100644 --- a/server/osa/domain/validation/port/hook_runner.py +++ b/server/osa/domain/validation/port/hook_runner.py @@ -1,22 +1,28 @@ """Port for executing hooks in OCI containers.""" from abc import abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import Protocol, runtime_checkable from osa.domain.shared.model.hook import HookDefinition from osa.domain.shared.port import Port +from osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.model.hook_result import HookResult @dataclass(frozen=True) class HookInputs: - """Inputs to pass to a hook container.""" + """Inputs to pass to a hook container. - record_json: dict + Uses the unified batch contract: records is a list of HookRecord + (1 for depositions, N for harvests). + files_dirs maps record ID → directory containing that record's files. 
+ """ + + records: list[HookRecord] run_id: str - files_dir: Path | None = None + files_dirs: dict[str, Path] = field(default_factory=dict) config: dict | None = None diff --git a/server/osa/domain/validation/service/validation.py b/server/osa/domain/validation/service/validation.py index 0783d9b..c530b3c 100644 --- a/server/osa/domain/validation/service/validation.py +++ b/server/osa/domain/validation/service/validation.py @@ -18,6 +18,7 @@ RunStatus, ValidationRun, ) +from osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.model.hook_result import HookResult, HookStatus from osa.domain.validation.port.hook_runner import HookInputs, HookRunner from osa.domain.validation.port.repository import ValidationRunRepository @@ -101,14 +102,18 @@ async def validate_deposition( metadata: dict[str, Any], hooks: list[HookDefinition], ) -> tuple[ValidationRun, list[HookResult]]: - """Full validation workflow using enriched event data.""" - record_json = {"srn": str(deposition_srn), "metadata": metadata} + """Full validation workflow using enriched event data. + + Uses the unified batch contract: constructs a 1-record batch for depositions. 
+ """ + record = HookRecord(id=str(deposition_srn), metadata=metadata) run_id = f"{deposition_srn.domain.root}_{deposition_srn.id.root}" files_dir = self.hook_storage.get_files_dir(deposition_srn) + inputs = HookInputs( - record_json=record_json, + records=[record], run_id=run_id, - files_dir=files_dir, + files_dirs={str(deposition_srn): files_dir} if files_dir else {}, ) run = await self.create_run(inputs=inputs) diff --git a/server/osa/infrastructure/event/di.py b/server/osa/infrastructure/event/di.py index 286b51e..e9c6d11 100644 --- a/server/osa/infrastructure/event/di.py +++ b/server/osa/infrastructure/event/di.py @@ -6,16 +6,19 @@ from dishka import AsyncContainer, provide from osa.domain.curation.handler import AutoApproveCuration -from osa.domain.deposition.handler import CreateDepositionFromSource, ReturnToDraft -from osa.domain.feature.handler import CreateFeatureTables, InsertRecordFeatures +from osa.domain.deposition.handler import ReturnToDraft +from osa.domain.feature.handler import ( + CreateFeatureTables, + InsertBatchFeatures, + InsertRecordFeatures, +) +from osa.domain.ingest.handler import PublishBatch, RunHooks, RunIngester from osa.domain.record.handler import ConvertDepositionToRecord from osa.domain.shared.event import EventHandler from osa.domain.shared.event_log import EventLog from osa.domain.shared.model.subscription_registry import SubscriptionRegistry from osa.domain.shared.outbox import Outbox from osa.domain.shared.port.event_repository import EventRepository -from osa.domain.source.handler import PullFromSource, TriggerInitialSourceRun -from osa.domain.source.schedule import SourceSchedule from osa.domain.validation.handler import ValidateDeposition from osa.infrastructure.event.worker import WorkerPool from osa.util.di.base import Provider @@ -32,13 +35,14 @@ # Feature handlers (must run before source triggers) CreateFeatureTables, InsertRecordFeatures, - # Source handlers - TriggerInitialSourceRun, - PullFromSource, + 
InsertBatchFeatures, + # Ingest handlers + RunIngester, + RunHooks, + PublishBatch, # Validation handlers ValidateDeposition, # Deposition handlers - CreateDepositionFromSource, ReturnToDraft, # Curation handlers AutoApproveCuration, @@ -107,9 +111,6 @@ def get_outbox(self, repo: EventRepository, registry: SubscriptionRegistry) -> O def get_event_log(self, repo: EventRepository) -> EventLog: return EventLog(repo) - # UOW-scoped provider for SourceSchedule - source_schedule = provide(SourceSchedule, scope=Scope.UOW) - @provide(scope=Scope.APP) def get_handler_types(self) -> HandlerTypes: """Return all handler types (core + extra) for WorkerPool registration.""" diff --git a/server/osa/infrastructure/event/worker.py b/server/osa/infrastructure/event/worker.py index 24924f6..d629adc 100644 --- a/server/osa/infrastructure/event/worker.py +++ b/server/osa/infrastructure/event/worker.py @@ -311,36 +311,7 @@ async def start(self) -> None: async def _build_schedules_from_conventions(self) -> list[ScheduleConfig]: """Query conventions with sources and build schedule configs.""" - if self._container is None: - return [] - - from osa.domain.deposition.service.convention import ConventionService - from osa.domain.source.schedule import SourceSchedule as SourceScheduleType - - configs: list[ScheduleConfig] = [] - try: - async with self._container(scope=Scope.UOW, context={Identity: System()}) as scope: - convention_service = await scope.get(ConventionService) - conventions = await convention_service.list_conventions_with_source() - - for conv in conventions: - if conv.source is None or conv.source.schedule is None: - continue - configs.append( - ScheduleConfig( - schedule_type=SourceScheduleType, - cron=conv.source.schedule.cron, - id=f"source-{conv.srn}", - params={ - "convention": str(conv.srn), - "limit": conv.source.schedule.limit, - }, - ) - ) - except Exception as e: - logger.warning(f"Failed to build schedules from conventions: {e}") - - return configs + return [] async 
def stop(self, timeout: float = 30.0) -> None: """Stop all workers gracefully.""" diff --git a/server/osa/infrastructure/ingest/__init__.py b/server/osa/infrastructure/ingest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/osa/infrastructure/ingest/di.py b/server/osa/infrastructure/ingest/di.py new file mode 100644 index 0000000..5aef81b --- /dev/null +++ b/server/osa/infrastructure/ingest/di.py @@ -0,0 +1,39 @@ +"""Dependency injection provider for ingest domain.""" + +from dishka import provide +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.deposition.service.convention import ConventionService +from osa.domain.ingest.command.start_ingest import StartIngestHandler +from osa.domain.ingest.port.repository import IngestRunRepository +from osa.domain.ingest.service.ingest import IngestService +from osa.domain.shared.model.srn import Domain +from osa.domain.shared.outbox import Outbox +from osa.infrastructure.persistence.repository.ingest import PostgresIngestRunRepository +from osa.util.di.base import Provider +from osa.util.di.scope import Scope + + +class IngestProvider(Provider): + """Provides IngestService, IngestRunRepository, and StartIngestHandler.""" + + @provide(scope=Scope.UOW) + def get_ingest_repo(self, session: AsyncSession) -> IngestRunRepository: + return PostgresIngestRunRepository(session) + + @provide(scope=Scope.UOW) + def get_ingest_service( + self, + ingest_repo: IngestRunRepository, + convention_service: ConventionService, + outbox: Outbox, + node_domain: Domain, + ) -> IngestService: + return IngestService( + ingest_repo=ingest_repo, + convention_service=convention_service, + outbox=outbox, + node_domain=node_domain, + ) + + start_ingest_handler = provide(StartIngestHandler, scope=Scope.UOW) diff --git a/server/osa/infrastructure/k8s/di.py b/server/osa/infrastructure/k8s/di.py index ebde8d3..1b6d99b 100644 --- a/server/osa/infrastructure/k8s/di.py +++ b/server/osa/infrastructure/k8s/di.py @@ 
-13,10 +13,10 @@ from dishka import activate, provide from osa.config import Config -from osa.domain.source.port.source_runner import SourceRunner +from osa.domain.shared.port.ingester_runner import IngesterRunner from osa.domain.validation.port.hook_runner import HookRunner +from osa.infrastructure.oci.ingester_runner import OciIngesterRunner from osa.infrastructure.oci.runner import OciHookRunner -from osa.infrastructure.oci.source_runner import OciSourceRunner from osa.infrastructure.s3.client import S3Client from osa.util.di.base import Provider from osa.util.di.markers import K8S @@ -62,12 +62,12 @@ def get_hook_runner_oci( return OciHookRunner(docker=docker, host_data_dir=config.host_data_dir) @provide(scope=Scope.UOW) - def get_source_runner_oci( + def get_ingester_runner_oci( self, docker: aiodocker.Docker, config: Config, - ) -> SourceRunner: - return OciSourceRunner(docker=docker, host_data_dir=config.host_data_dir) + ) -> IngesterRunner: + return OciIngesterRunner(docker=docker, host_data_dir=config.host_data_dir) # ------------------------------------------------------------------ # K8s backend (activated when config.runner.backend == "k8s") @@ -133,12 +133,12 @@ def get_hook_runner_k8s( return K8sHookRunner(api_client=k8s_api_client, config=config.runner.k8s, s3=s3) @provide(when=K8S, scope=Scope.UOW) - def get_source_runner_k8s( + def get_ingester_runner_k8s( self, k8s_api_client: ApiClient, config: Config, s3: S3Client, - ) -> SourceRunner: - from osa.infrastructure.k8s.source_runner import K8sSourceRunner + ) -> IngesterRunner: + from osa.infrastructure.k8s.ingester_runner import K8sIngesterRunner - return K8sSourceRunner(api_client=k8s_api_client, config=config.runner.k8s, s3=s3) + return K8sIngesterRunner(api_client=k8s_api_client, config=config.runner.k8s, s3=s3) diff --git a/server/osa/infrastructure/k8s/source_runner.py b/server/osa/infrastructure/k8s/ingester_runner.py similarity index 96% rename from 
server/osa/infrastructure/k8s/source_runner.py rename to server/osa/infrastructure/k8s/ingester_runner.py index 39cb877..066920c 100644 --- a/server/osa/infrastructure/k8s/source_runner.py +++ b/server/osa/infrastructure/k8s/ingester_runner.py @@ -1,4 +1,4 @@ -"""Kubernetes Job-based source runner.""" +"""Kubernetes Job-based ingester runner.""" from __future__ import annotations @@ -11,9 +11,9 @@ from osa.config import K8sConfig from osa.domain.shared.error import ExternalServiceError, InfrastructureError -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.srn import ConventionSRN -from osa.domain.source.port.source_runner import SourceInputs, SourceOutput, SourceRunner +from osa.domain.shared.port.ingester_runner import IngesterInputs, IngesterOutput, IngesterRunner from osa.infrastructure.k8s.errors import classify_api_error from osa.infrastructure.k8s.naming import job_name, label_value, sanitize_label from osa.infrastructure.runner_utils import ( @@ -31,7 +31,7 @@ SCHEDULING_TIMEOUT = 120 -class K8sSourceRunner(SourceRunner): +class K8sIngesterRunner(IngesterRunner): """Executes sources as Kubernetes Jobs. 
Key differences from K8sHookRunner: @@ -54,11 +54,11 @@ def _s3_prefix(self, work_dir: Path, subdir: str) -> str: async def run( self, - source: SourceDefinition, - inputs: SourceInputs, + source: IngesterDefinition, + inputs: IngesterInputs, files_dir: Path, work_dir: Path, - ) -> SourceOutput: + ) -> IngesterOutput: try: from kubernetes_asyncio.client import BatchV1Api, CoreV1Api except ImportError: @@ -95,13 +95,13 @@ async def _run_job( self, batch_api: BatchV1Api, core_api: CoreV1Api, - source: SourceDefinition, - inputs: SourceInputs, + source: IngesterDefinition, + inputs: IngesterInputs, work_dir: Path, files_dir: Path, *, convention_srn: ConventionSRN | None = None, - ) -> SourceOutput: + ) -> IngesterOutput: """Core Job lifecycle for source execution.""" namespace = self._config.namespace job_name_to_watch = None @@ -169,7 +169,7 @@ async def _run_job( if job_name_to_watch: await self._cleanup_job(batch_api, job_name_to_watch, namespace) - async def _parse_source_output(self, work_dir: Path, files_dir: Path) -> SourceOutput: + async def _parse_source_output(self, work_dir: Path, files_dir: Path) -> IngesterOutput: from osa.infrastructure.runner_utils import ( parse_records_from_s3, parse_session_from_s3, @@ -178,7 +178,7 @@ async def _parse_source_output(self, work_dir: Path, files_dir: Path) -> SourceO output_prefix = self._s3_prefix(work_dir, "output") records = await parse_records_from_s3(self._s3, output_prefix) session = await parse_session_from_s3(self._s3, output_prefix) - return SourceOutput(records=records, session=session, files_dir=files_dir) + return IngesterOutput(records=records, session=session, files_dir=files_dir) async def _check_existing_job( self, @@ -208,11 +208,11 @@ async def _check_existing_job( def _build_job_spec( self, - source: SourceDefinition, + source: IngesterDefinition, *, work_dir: Path, files_dir: Path, - inputs: SourceInputs | None = None, + inputs: IngesterInputs | None = None, convention_srn: ConventionSRN | None = 
None, ) -> V1Job: from kubernetes_asyncio.client import ( @@ -420,7 +420,7 @@ async def _diagnose_and_raise( core_api: CoreV1Api, job_name: str, namespace: str, - source: SourceDefinition, + source: IngesterDefinition, failure_info: str, ) -> None: """Determine failure reason and raise appropriate error.""" diff --git a/server/osa/infrastructure/k8s/runner.py b/server/osa/infrastructure/k8s/runner.py index 457c799..b03f91c 100644 --- a/server/osa/infrastructure/k8s/runner.py +++ b/server/osa/infrastructure/k8s/runner.py @@ -71,7 +71,9 @@ async def run( # Write input files to S3 (container reads them via PVC/S3 CSI) input_prefix = self._s3_prefix(work_dir, "input") - await self._s3.put_object(f"{input_prefix}/record.json", json.dumps(inputs.record_json)) + # Write records.jsonl (unified batch contract) + records_jsonl = "\n".join(json.dumps(r.model_dump()) for r in inputs.records) + "\n" + await self._s3.put_object(f"{input_prefix}/records.jsonl", records_jsonl) if inputs.config or hook.runtime.config: config = {**hook.runtime.config, **(inputs.config or {})} await self._s3.put_object(f"{input_prefix}/config.json", json.dumps(config)) @@ -113,11 +115,16 @@ async def _run_job( job_name_to_watch = existing.split(":", 1)[1] else: # Create new Job (no existing or failed) + # For depositions (batch of 1), use the single record's files dir + files_dir = None + if inputs.files_dirs: + first_id = next(iter(inputs.files_dirs)) + files_dir = inputs.files_dirs[first_id] spec = self._build_job_spec( hook, work_dir, run_id=inputs.run_id, - files_dir=inputs.files_dir, + files_dir=files_dir, ) job_name_to_watch = spec.metadata.name @@ -262,6 +269,7 @@ def _build_job_spec( V1VolumeMount(name="tmp", mount_path="/tmp"), ] + # Mount per-record file directories (ingest: multiple, deposition: one) if files_dir: relative_files = self._relative_path(files_dir) mounts.append( diff --git a/server/osa/infrastructure/oci/di.py b/server/osa/infrastructure/oci/di.py index 5c51ceb..973a158 
100644 --- a/server/osa/infrastructure/oci/di.py +++ b/server/osa/infrastructure/oci/di.py @@ -4,10 +4,10 @@ from dishka import provide from osa.config import Config -from osa.domain.source.port.source_runner import SourceRunner +from osa.domain.shared.port.ingester_runner import IngesterRunner from osa.domain.validation.port.hook_runner import HookRunner +from osa.infrastructure.oci.ingester_runner import OciIngesterRunner from osa.infrastructure.oci.runner import OciHookRunner -from osa.infrastructure.oci.source_runner import OciSourceRunner from osa.util.di.base import Provider from osa.util.di.scope import Scope @@ -24,5 +24,5 @@ def get_hook_runner(self, docker: aiodocker.Docker, config: Config) -> HookRunne return OciHookRunner(docker=docker, host_data_dir=config.host_data_dir) @provide(scope=Scope.UOW) - def get_source_runner(self, docker: aiodocker.Docker, config: Config) -> SourceRunner: - return OciSourceRunner(docker=docker, host_data_dir=config.host_data_dir) + def get_ingester_runner(self, docker: aiodocker.Docker, config: Config) -> IngesterRunner: + return OciIngesterRunner(docker=docker, host_data_dir=config.host_data_dir) diff --git a/server/osa/infrastructure/oci/source_runner.py b/server/osa/infrastructure/oci/ingester_runner.py similarity index 93% rename from server/osa/infrastructure/oci/source_runner.py rename to server/osa/infrastructure/oci/ingester_runner.py index 2f261ce..ebdb35f 100644 --- a/server/osa/infrastructure/oci/source_runner.py +++ b/server/osa/infrastructure/oci/ingester_runner.py @@ -1,4 +1,4 @@ -"""OCI source runner using aiodocker.""" +"""OCI ingester runner using aiodocker.""" import asyncio import json @@ -11,8 +11,8 @@ import logfire from osa.domain.shared.error import ExternalServiceError -from osa.domain.shared.model.source import SourceDefinition -from osa.domain.source.port.source_runner import SourceInputs, SourceOutput, SourceRunner +from osa.domain.shared.model.source import IngesterDefinition +from 
osa.domain.shared.port.ingester_runner import IngesterInputs, IngesterOutput, IngesterRunner from osa.infrastructure.runner_utils import ( parse_memory, parse_records_file, @@ -20,7 +20,7 @@ ) -class OciSourceRunner(SourceRunner): +class OciIngesterRunner(IngesterRunner): """Executes sources in OCI containers via aiodocker. Key differences from OciHookRunner: @@ -47,11 +47,11 @@ def __init__( async def run( self, - source: SourceDefinition, - inputs: SourceInputs, + source: IngesterDefinition, + inputs: IngesterInputs, files_dir: Path, work_dir: Path, - ) -> SourceOutput: + ) -> IngesterOutput: timeout = source.limits.timeout_seconds from shutil import rmtree @@ -108,9 +108,9 @@ async def _run_container( staging_dir: Path, files_dir: Path, output_dir: Path, - source: SourceDefinition, - inputs: SourceInputs, - ) -> SourceOutput: + source: IngesterDefinition, + inputs: IngesterInputs, + ) -> IngesterOutput: container = None try: # Build env vars @@ -168,7 +168,7 @@ async def _run_container( records = parse_records_file(output_dir) session = parse_session_file(output_dir) - return SourceOutput(records=records, session=session, files_dir=files_dir) + return IngesterOutput(records=records, session=session, files_dir=files_dir) except aiodocker.DockerError as e: logfire.error("Docker error running source", error=str(e)) diff --git a/server/osa/infrastructure/oci/runner.py b/server/osa/infrastructure/oci/runner.py index 7878991..f19e8b8 100644 --- a/server/osa/infrastructure/oci/runner.py +++ b/server/osa/infrastructure/oci/runner.py @@ -54,14 +54,19 @@ async def run( container_output = work_dir / "output" container_output.mkdir(parents=True, exist_ok=True) try: - (staging_dir / "record.json").write_text(json.dumps(inputs.record_json)) - # Pre-create files mountpoint so nested bind works with ReadonlyRootfs - (staging_dir / "files").mkdir(exist_ok=True) + # Write records.jsonl (unified batch contract) + with (staging_dir / "records.jsonl").open("w") as f: + for record in 
inputs.records: + f.write(json.dumps(record.model_dump()) + "\n") if inputs.config or hook.runtime.config: config = {**hook.runtime.config, **(inputs.config or {})} (staging_dir / "config.json").write_text(json.dumps(config)) + # Create files directory structure: $OSA_FILES/{id}/ per record + files_base = staging_dir / "files" + files_base.mkdir(exist_ok=True) + start_time = time.monotonic() try: @@ -69,7 +74,12 @@ async def run( async def _resolve_and_run(): image_ref = await self._resolve_image(hook.runtime.image, hook.runtime.digest) return await self._run_container( - image_ref, staging_dir, inputs.files_dir, container_output, hook + image_ref, + staging_dir, + inputs.files_dirs, + container_output, + hook, + files_base, ) result = await asyncio.wait_for( @@ -106,19 +116,26 @@ async def _run_container( self, image_ref: str, staging_dir: Path, - files_dir: Path | None, + files_dirs: dict[str, Path], output_dir: Path, hook: HookDefinition, + files_base: Path, ) -> dict: container = None try: - # Nested bind-mounts: staging at /osa/in:ro, files at /osa/in/files:ro + # Bind mounts: staging at /osa/in:ro, files at /osa/files:ro, output at /osa/out:rw binds = [ f"{self._host_path(staging_dir)}:/osa/in:ro", f"{self._host_path(output_dir)}:/osa/out:rw", ] - if files_dir and files_dir.exists(): - binds.append(f"{self._host_path(files_dir)}:/osa/in/files:ro") + + # Mount per-record file directories under /osa/files/{id}/ + if files_dirs: + for record_id, fdir in files_dirs.items(): + if fdir and fdir.exists(): + binds.append(f"{self._host_path(fdir)}:/osa/files/{record_id}:ro") + elif files_base.exists(): + binds.append(f"{self._host_path(files_base)}:/osa/files:ro") # todo: use pydantic config = { @@ -126,6 +143,7 @@ async def _run_container( "Env": [ "OSA_IN=/osa/in", "OSA_OUT=/osa/out", + "OSA_FILES=/osa/files", f"OSA_HOOK_NAME={hook.name}", ], "User": "65534:65534", diff --git a/server/osa/infrastructure/persistence/adapter/storage.py 
async def read_batch_outcomes(
    self, output_dir: str, hook_name: str
) -> dict[str, BatchRecordOutcome]:
    """Read a hook's JSONL batch outputs from the filesystem.

    Looks in ``<output_dir>/hooks/<hook_name>/output`` for ``features.jsonl``,
    ``rejections.jsonl`` and ``errors.jsonl`` (in that order) and builds one
    ``BatchRecordOutcome`` per record ``id``. Files are streamed line by line;
    a file that does not exist is skipped. If the same ``id`` occurs more than
    once, the last line read wins.

    Blank lines are ignored. Lines that are not valid JSON, are not JSON
    objects, or carry no truthy ``id`` are logged at WARNING level and skipped.

    Args:
        output_dir: Root directory that contains the ``hooks`` directory.
        hook_name: Name of the hook whose output is read.

    Returns:
        Mapping of record id to its outcome; empty when no file yields a
        usable line.
    """
    hook_output = Path(output_dir) / "hooks" / hook_name / "output"
    outcomes: dict[str, BatchRecordOutcome] = {}

    # (file name, status stored on the outcome, JSON key -> outcome field)
    for filename, status_key, field_map in [
        ("features.jsonl", "passed", {"features": "features"}),
        ("rejections.jsonl", "rejected", {"reason": "reason"}),
        ("errors.jsonl", "errored", {"error": "error", "retryable": "retryable"}),
    ]:
        path = hook_output / filename
        if not path.exists():
            continue
        # Decode explicitly as UTF-8 rather than the platform's locale default,
        # so parsing does not depend on where the server happens to run.
        with path.open(encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    logger.warning("Skipping malformed JSON line in %s", filename)
                    continue
                # A line such as `[]` or `42` is valid JSON but has no `.get`;
                # without this guard one bad line aborts the whole batch.
                if not isinstance(data, dict):
                    logger.warning("Skipping non-object JSON line in %s", filename)
                    continue
                record_id = data.get("id")
                if not record_id:
                    logger.warning("Skipping JSONL line without 'id' in %s", filename)
                    continue
                kwargs: dict[str, Any] = {
                    "record_id": record_id,
                    "status": status_key,
                }
                # Copy only the keys that are present in the line.
                for src, dst in field_map.items():
                    if src in data:
                        kwargs[dst] = data[src]
                outcomes[record_id] = BatchRecordOutcome(**kwargs)

    return outcomes
"""PostgreSQL implementation of IngestRunRepository."""

import logging
from collections.abc import Mapping
from typing import Any

from sqlalchemy import select, update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.ext.asyncio import AsyncSession

from osa.domain.ingest.model.ingest_run import IngestRun, IngestStatus
from osa.domain.ingest.port.repository import IngestRunRepository
from osa.domain.shared.error import NotFoundError
from osa.infrastructure.persistence.tables import ingest_runs_table

logger = logging.getLogger(__name__)


class PostgresIngestRunRepository(IngestRunRepository):
    """PostgreSQL implementation with atomic counter updates."""

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def save(self, ingest_run: IngestRun) -> None:
        """Insert the ingest run, or update every column if its SRN already exists.

        NOTE(review): the update branch writes the aggregate's counter values
        verbatim, so saving a stale aggregate can overwrite counters that were
        advanced through ``increment_*`` in the meantime — confirm callers never
        do that.
        """
        values = {
            "srn": ingest_run.srn,
            "convention_srn": ingest_run.convention_srn,
            "status": ingest_run.status.value,
            "source_finished": ingest_run.source_finished,
            "batches_sourced": ingest_run.batches_sourced,
            "batches_completed": ingest_run.batches_completed,
            "published_count": ingest_run.published_count,
            "batch_size": ingest_run.batch_size,
            "started_at": ingest_run.started_at,
            "completed_at": ingest_run.completed_at,
        }
        # The conflict target is `srn`, so on conflict the key is already equal;
        # leave it out of the SET clause instead of rewriting the primary key.
        update_values = {name: value for name, value in values.items() if name != "srn"}
        stmt = (
            insert(ingest_runs_table)
            .values(**values)
            .on_conflict_do_update(
                index_elements=["srn"],
                set_=update_values,
            )
        )
        await self._session.execute(stmt)
        await self._session.flush()

    async def get(self, srn: str) -> IngestRun | None:
        """Return the ingest run with this SRN, or ``None`` if there is none."""
        stmt = select(ingest_runs_table).where(ingest_runs_table.c.srn == srn)
        result = await self._session.execute(stmt)
        row = result.mappings().first()
        if row is None:
            return None
        return _row_to_ingest_run(dict(row))

    async def get_running_for_convention(self, convention_srn: str) -> IngestRun | None:
        """Return one pending or running ingest run for the convention, if any."""
        stmt = (
            select(ingest_runs_table)
            .where(ingest_runs_table.c.convention_srn == convention_srn)
            .where(
                ingest_runs_table.c.status.in_(
                    [IngestStatus.PENDING.value, IngestStatus.RUNNING.value]
                )
            )
            .limit(1)
        )
        result = await self._session.execute(stmt)
        row = result.mappings().first()
        if row is None:
            return None
        return _row_to_ingest_run(dict(row))

    async def increment_batches_sourced(
        self, srn: str, *, set_source_finished: bool = False
    ) -> IngestRun:
        """Atomically increment batches_sourced.

        The increment is expressed as a column expression inside a single
        ``UPDATE ... RETURNING``, so the arithmetic happens in the database.

        Args:
            srn: SRN of the ingest run to update.
            set_source_finished: Also set ``source_finished`` to true in the
                same statement.

        Returns:
            The ingest run as stored after the update.

        Raises:
            NotFoundError: No ingest run with this SRN exists.
        """
        t = ingest_runs_table
        # Holds a SQL expression and, optionally, a plain bool — hence Any.
        values: dict[str, Any] = {
            "batches_sourced": t.c.batches_sourced + 1,
        }
        if set_source_finished:
            values["source_finished"] = True

        stmt = update(t).where(t.c.srn == srn).values(**values).returning(*t.c)
        result = await self._session.execute(stmt)
        await self._session.flush()
        return _row_to_ingest_run(_require_row(result.mappings().first(), srn))

    async def increment_completed(self, srn: str, published_count: int) -> IngestRun:
        """Atomically increment batches_completed and published_count.

        Args:
            srn: SRN of the ingest run to update.
            published_count: Amount added to the stored ``published_count``.

        Returns:
            The ingest run as stored after the update.

        Raises:
            NotFoundError: No ingest run with this SRN exists.
        """
        t = ingest_runs_table
        stmt = (
            update(t)
            .where(t.c.srn == srn)
            .values(
                batches_completed=t.c.batches_completed + 1,
                published_count=t.c.published_count + published_count,
            )
            .returning(*t.c)
        )
        result = await self._session.execute(stmt)
        await self._session.flush()
        return _row_to_ingest_run(_require_row(result.mappings().first(), srn))


def _require_row(row: Mapping[str, Any] | None, srn: str) -> dict[str, Any]:
    """Return the row as a dict, raising NotFoundError when the UPDATE matched nothing."""
    if row is None:
        raise NotFoundError(f"Ingest run not found: {srn}")
    return dict(row)


def _row_to_ingest_run(row: dict) -> IngestRun:
    """Build an IngestRun from an ``ingest_runs`` row mapping."""
    return IngestRun(
        srn=row["srn"],
        convention_srn=row["convention_srn"],
        status=IngestStatus(row["status"]),
        source_finished=row["source_finished"],
        batches_sourced=row["batches_sourced"],
        batches_completed=row["batches_completed"],
        published_count=row["published_count"],
        batch_size=row["batch_size"],
        started_at=row["started_at"],
        completed_at=row.get("completed_at"),
    )
async def save_many(self, records: list[Record]) -> list[Record]:
    """Insert several records in one statement, skipping existing ones.

    Issues a multi-row INSERT with ON CONFLICT DO NOTHING, using the ``type``
    and ``id`` entries of the ``source`` column as the conflict target, and
    asks the database to return the SRNs it stored.

    Returns:
        The subset of ``records`` whose SRN was returned by the INSERT, in
        their original order. An empty input returns an empty list without
        touching the database.
    """
    if not records:
        return []

    source_col = records_table.c.source
    conflict_target = [
        source_col["type"].as_string(),
        source_col["id"].as_string(),
    ]
    rows = [record_to_dict(rec) for rec in records]

    stmt = insert(records_table).values(rows)
    stmt = stmt.on_conflict_do_nothing(index_elements=conflict_target)
    stmt = stmt.returning(records_table.c.srn)

    result = await self.session.execute(stmt)
    await self.session.flush()

    # Only rows that were really inserted come back from RETURNING.
    stored_srns = set(result.scalars())
    return [rec for rec in records if str(rec.srn) in stored_srns]
"ingest_runs", + metadata, + Column("srn", String, primary_key=True), + Column("convention_srn", String, ForeignKey("conventions.srn"), nullable=False), + Column("status", String(32), nullable=False, server_default=text("'pending'")), + Column("source_finished", Boolean, nullable=False, server_default=text("false")), + Column("batches_sourced", Integer, nullable=False, server_default=text("0")), + Column("batches_completed", Integer, nullable=False, server_default=text("0")), + Column("published_count", Integer, nullable=False, server_default=text("0")), + Column("batch_size", Integer, nullable=False, server_default=text("1000")), + Column("started_at", DateTime(timezone=True), nullable=False), + Column("completed_at", DateTime(timezone=True), nullable=True), +) + +Index("idx_ingest_runs_convention", ingest_runs_table.c.convention_srn) +Index("idx_ingest_runs_status", ingest_runs_table.c.status) + + # ============================================================================ # DEVICE AUTHORIZATIONS TABLE (Authentication - OAuth Device Flow) # ============================================================================ diff --git a/server/osa/infrastructure/s3/storage.py b/server/osa/infrastructure/s3/storage.py index 2b37eda..5b0dcf0 100644 --- a/server/osa/infrastructure/s3/storage.py +++ b/server/osa/infrastructure/s3/storage.py @@ -16,6 +16,7 @@ from osa.domain.deposition.port.storage import FileStoragePort from osa.domain.shared.error import InfrastructureError, NotFoundError from osa.domain.shared.model.srn import ConventionSRN, DepositionSRN +from osa.domain.validation.model.batch_outcome import BatchRecordOutcome from osa.infrastructure.runner_utils import relative_path from osa.infrastructure.s3.client import S3Client @@ -178,6 +179,8 @@ def get_hook_output_root(self, source_type: str, source_id: str) -> str: srn = DepositionSRN.parse(source_id) safe_id = self._safe_id(srn) return f"{self._data_mount_path}/depositions/{safe_id}" + if source_type == 
async def read_batch_outcomes(
    self, output_dir: str, hook_name: str
) -> dict[str, BatchRecordOutcome]:
    """Read a hook's JSONL batch outputs from S3.

    ``output_dir`` is made relative to the data mount path and the objects
    ``features.jsonl``, ``rejections.jsonl`` and ``errors.jsonl`` under
    ``<prefix>/hooks/<hook_name>/output`` are fetched in that order. One
    ``BatchRecordOutcome`` is built per record ``id``; if the same ``id``
    occurs more than once, the last line read wins.

    An object that cannot be fetched is skipped. Blank lines are ignored.
    Lines that are not valid JSON, are not JSON objects, or carry no truthy
    ``id`` are logged at WARNING level and skipped.

    Args:
        output_dir: Hook output root as a path under the data mount.
        hook_name: Name of the hook whose output is read.

    Returns:
        Mapping of record id to its outcome; empty when no object yields a
        usable line.
    """
    prefix = relative_path(Path(output_dir), self._data_mount_path)
    hook_prefix = f"{prefix}/hooks/{hook_name}/output"
    outcomes: dict[str, BatchRecordOutcome] = {}

    # (object name, status stored on the outcome, JSON key -> outcome field)
    for filename, status_key, field_map in [
        ("features.jsonl", "passed", {"features": "features"}),
        ("rejections.jsonl", "rejected", {"reason": "reason"}),
        ("errors.jsonl", "errored", {"error": "error", "retryable": "retryable"}),
    ]:
        key = f"{hook_prefix}/{filename}"
        try:
            data_bytes = await self._s3.get_object(key)
        except NotFoundError:
            # Expected case: the hook did not write this file.
            # NOTE(review): assumes the S3 client reports a missing key as
            # NotFoundError — confirm; if it does not, the branch below still
            # skips the file and merely logs it.
            continue
        except Exception:
            # Stay best-effort, but do not let auth/network/throttling
            # failures pass for "file absent" without any trace.
            logger.warning(
                "Could not read %s; treating it as absent", key, exc_info=True
            )
            continue

        # Split on "\n" only: str.splitlines() would also break on characters
        # such as U+2028 that may appear unescaped inside JSON strings.
        for line in data_bytes.decode().split("\n"):
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                logger.warning("Skipping malformed JSON line in %s", filename)
                continue
            # A line such as `[]` or `42` is valid JSON but has no `.get`;
            # without this guard one bad line aborts the whole batch.
            if not isinstance(data, dict):
                logger.warning("Skipping non-object JSON line in %s", filename)
                continue
            record_id = data.get("id")
            if not record_id:
                logger.warning("Skipping JSONL line without 'id' in %s", filename)
                continue
            kwargs: dict[str, Any] = {
                "record_id": record_id,
                "status": status_key,
            }
            # Copy only the keys that are present in the line.
            for src, dst in field_map.items():
                if src in data:
                    kwargs[dst] = data[src]
            outcomes[record_id] = BatchRecordOutcome(**kwargs)

    return outcomes
--git a/server/osa/infrastructure/source/di.py b/server/osa/infrastructure/source/di.py deleted file mode 100644 index eda51bd..0000000 --- a/server/osa/infrastructure/source/di.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Dependency injection provider for sources.""" - -from dishka import provide - -from osa.domain.shared.outbox import Outbox -from osa.domain.source.port.source_runner import SourceRunner -from osa.domain.source.port.storage import SourceStoragePort -from osa.domain.source.service import SourceService -from osa.util.di.base import Provider -from osa.util.di.scope import Scope - - -class SourceProvider(Provider): - """Provides SourceService wired with OCI runner.""" - - @provide(scope=Scope.UOW) - def get_source_service( - self, - source_runner: SourceRunner, - source_storage: SourceStoragePort, - outbox: Outbox, - ) -> SourceService: - return SourceService( - source_runner=source_runner, - source_storage=source_storage, - outbox=outbox, - ) diff --git a/server/tests/integration/persistence/test_convention_repo.py b/server/tests/integration/persistence/test_convention_repo.py index bc3ed27..632eb79 100644 --- a/server/tests/integration/persistence/test_convention_repo.py +++ b/server/tests/integration/persistence/test_convention_repo.py @@ -15,10 +15,10 @@ TableFeatureSpec, ) from osa.domain.shared.model.source import ( + IngesterDefinition, + IngesterLimits, + IngesterScheduleConfig, InitialRunConfig, - SourceDefinition, - SourceLimits, - SourceScheduleConfig, ) from osa.domain.shared.model.srn import ConventionSRN, SchemaSRN from osa.infrastructure.persistence.repository.convention import ( @@ -32,7 +32,7 @@ def _make_convention( title: str = "Test Convention", schema_srn: str = "urn:osa:localhost:schema:test-schema-001@1.0.0", hooks: list[HookDefinition] | None = None, - source: SourceDefinition | None = None, + ingester: IngesterDefinition | None = None, ) -> Convention: return Convention( srn=ConventionSRN.parse(srn), @@ -46,7 +46,7 @@ def 
_make_convention( max_file_size=100_000_000, ), hooks=hooks or [], - source=source, + ingester=ingester, created_at=datetime.now(UTC), ) @@ -70,14 +70,14 @@ def _make_hook() -> HookDefinition: ) -def _make_source() -> SourceDefinition: - return SourceDefinition( - image="ghcr.io/example/source:latest", +def _make_ingester() -> IngesterDefinition: + return IngesterDefinition( + image="ghcr.io/example/ingester:latest", digest="sha256:def456", runner="oci", config={"api_key": "test-key"}, - limits=SourceLimits(timeout_seconds=7200, memory="8g", cpu="4.0"), - schedule=SourceScheduleConfig(cron="0 2 * * *", limit=500), + limits=IngesterLimits(timeout_seconds=7200, memory="8g", cpu="4.0"), + schedule=IngesterScheduleConfig(cron="0 2 * * *", limit=500), initial_run=InitialRunConfig(limit=100), ) @@ -88,8 +88,8 @@ async def test_save_and_get(self, pg_session: AsyncSession): """Save a convention and retrieve it — all fields should match.""" repo = PostgresConventionRepository(pg_session) hook = _make_hook() - source = _make_source() - conv = _make_convention(hooks=[hook], source=source) + ingester = _make_ingester() + conv = _make_convention(hooks=[hook], ingester=ingester) await repo.save(conv) await pg_session.commit() @@ -106,12 +106,12 @@ async def test_save_and_get(self, pg_session: AsyncSession): assert got.hooks[0].runtime.digest == hook.runtime.digest assert got.hooks[0].name == "quality_check" assert got.hooks[0].feature.columns[0].name == "score" - assert got.source is not None - assert got.source.image == source.image - assert got.source.schedule is not None - assert got.source.schedule.cron == "0 2 * * *" - assert got.source.initial_run is not None - assert got.source.initial_run.limit == 100 + assert got.ingester is not None + assert got.ingester.image == ingester.image + assert got.ingester.schedule is not None + assert got.ingester.schedule.cron == "0 2 * * *" + assert got.ingester.initial_run is not None + assert got.ingester.initial_run.limit == 100 async 
def test_get_nonexistent_returns_none(self, pg_session: AsyncSession): repo = PostgresConventionRepository(pg_session) @@ -163,14 +163,14 @@ async def test_exists_false(self, pg_session: AsyncSession): repo = PostgresConventionRepository(pg_session) assert await repo.exists(ConventionSRN.parse("urn:osa:localhost:conv:nope@1.0.0")) is False - async def test_convention_without_source(self, pg_session: AsyncSession): - """Source is optional — should be None on retrieval when not set.""" + async def test_convention_without_ingester(self, pg_session: AsyncSession): + """Ingester is optional — should be None on retrieval when not set.""" repo = PostgresConventionRepository(pg_session) - conv = _make_convention(source=None, hooks=[]) + conv = _make_convention(ingester=None, hooks=[]) await repo.save(conv) await pg_session.commit() got = await repo.get(conv.srn) assert got is not None - assert got.source is None + assert got.ingester is None assert got.hooks == [] diff --git a/server/tests/unit/application/__init__.py b/server/tests/unit/application/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/application/api/__init__.py b/server/tests/unit/application/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/application/api/v1/__init__.py b/server/tests/unit/application/api/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/application/test_app_factory.py b/server/tests/unit/application/test_app_factory.py index 6b276b4..609fb00 100644 --- a/server/tests/unit/application/test_app_factory.py +++ b/server/tests/unit/application/test_app_factory.py @@ -19,14 +19,14 @@ from osa.application.di import create_container from osa.domain.shared.event import Event, EventHandler, EventId from osa.domain.shared.model.hook import HookDefinition -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from 
osa.domain.shared.model.subscription_registry import SubscriptionRegistry -from osa.domain.source.port.source_runner import SourceInputs, SourceOutput, SourceRunner +from osa.domain.shared.port.ingester_runner import IngesterInputs, IngesterOutput, IngesterRunner from osa.domain.validation.model.hook_result import HookResult, HookStatus from osa.domain.validation.port.hook_runner import HookInputs, HookRunner from osa.infrastructure.event.di import HandlerTypes, _CORE_HANDLERS from osa.infrastructure.oci.runner import OciHookRunner -from osa.infrastructure.oci.source_runner import OciSourceRunner +from osa.infrastructure.oci.ingester_runner import OciIngesterRunner from osa.util.di.base import Provider from osa.util.di.scope import Scope @@ -49,17 +49,17 @@ async def run(self, hook: HookDefinition, inputs: HookInputs, work_dir: Path) -> return HookResult(hook_name=hook.name, status=HookStatus.PASSED, duration_seconds=0.0) -class StubSourceRunner: - """Stub SourceRunner for testing provider overrides.""" +class StubIngesterRunner: + """Stub IngesterRunner for testing provider overrides.""" async def run( self, - source: SourceDefinition, - inputs: SourceInputs, + source: IngesterDefinition, + inputs: IngesterInputs, files_dir: Path, work_dir: Path, - ) -> SourceOutput: - return SourceOutput(records=[], session=None, files_dir=files_dir) + ) -> IngesterOutput: + return IngesterOutput(records=[], session=None, files_dir=files_dir) class StubRunnerProvider(Provider): @@ -70,8 +70,8 @@ def get_hook_runner(self) -> HookRunner: return StubHookRunner() @provide(scope=Scope.UOW, override=True) - def get_source_runner(self) -> SourceRunner: - return StubSourceRunner() + def get_ingester_runner(self) -> IngesterRunner: + return StubIngesterRunner() # --------------------------------------------------------------------------- @@ -110,12 +110,12 @@ def test_runner_override(self): async def resolve(): async with container(scope=Scope.UOW) as uow: hook = await uow.get(HookRunner) 
- source = await uow.get(SourceRunner) - return hook, source + ingester = await uow.get(IngesterRunner) + return hook, ingester - hook_runner, source_runner = asyncio.run(resolve()) + hook_runner, ingester_runner = asyncio.run(resolve()) assert isinstance(hook_runner, StubHookRunner) - assert isinstance(source_runner, StubSourceRunner) + assert isinstance(ingester_runner, StubIngesterRunner) def test_default_runners_without_override(self): """Without extra providers, default OCI runners are used.""" @@ -124,12 +124,12 @@ def test_default_runners_without_override(self): async def resolve(): async with container(scope=Scope.UOW) as uow: hook = await uow.get(HookRunner) - source = await uow.get(SourceRunner) - return hook, source + ingester = await uow.get(IngesterRunner) + return hook, ingester - hook_runner, source_runner = asyncio.run(resolve()) + hook_runner, ingester_runner = asyncio.run(resolve()) assert isinstance(hook_runner, OciHookRunner) - assert isinstance(source_runner, OciSourceRunner) + assert isinstance(ingester_runner, OciIngesterRunner) # --------------------------------------------------------------------------- diff --git a/server/tests/unit/domain/__init__.py b/server/tests/unit/domain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/domain/deposition/test_convention_service_v2.py b/server/tests/unit/domain/deposition/test_convention_service_v2.py index fb88cdc..9663e22 100644 --- a/server/tests/unit/domain/deposition/test_convention_service_v2.py +++ b/server/tests/unit/domain/deposition/test_convention_service_v2.py @@ -14,7 +14,7 @@ OciConfig, TableFeatureSpec, ) -from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.srn import Domain, SchemaSRN @@ -60,8 +60,8 @@ def _make_hook_def(name: str = "detect_pockets") -> HookDefinition: ) -def _make_source_def() -> SourceDefinition: - return SourceDefinition( +def 
_make_ingester_def() -> IngesterDefinition: + return IngesterDefinition( image="osa-sources/rcsb-pdb:latest", digest="sha256:abc123", config={"email": "test@example.com", "batch_size": 100}, @@ -129,31 +129,31 @@ async def test_convention_references_created_schema_srn(self): assert result.schema_srn == schema_srn @pytest.mark.asyncio - async def test_convention_saves_source_definition(self): + async def test_convention_saves_ingester_definition(self): service = _make_service() - source = _make_source_def() + ingester = _make_ingester_def() result = await service.create_convention( - title="With Source", + title="With Ingester", version="1.0.0", schema=_make_field_defs(), file_requirements=_make_file_reqs(), - source=source, + ingester=ingester, ) - assert result.source is not None - assert result.source.image == "osa-sources/rcsb-pdb:latest" - assert result.source.digest == "sha256:abc123" - assert result.source.config == {"email": "test@example.com", "batch_size": 100} + assert result.ingester is not None + assert result.ingester.image == "osa-sources/rcsb-pdb:latest" + assert result.ingester.digest == "sha256:abc123" + assert result.ingester.config == {"email": "test@example.com", "batch_size": 100} @pytest.mark.asyncio - async def test_convention_source_defaults_to_none(self): + async def test_convention_ingester_defaults_to_none(self): service = _make_service() result = await service.create_convention( - title="No Source", + title="No Ingester", version="1.0.0", schema=_make_field_defs(), file_requirements=_make_file_reqs(), ) - assert result.source is None + assert result.ingester is None @pytest.mark.asyncio async def test_convention_with_hooks_emits_hooks_in_event(self): @@ -182,7 +182,7 @@ async def test_create_convention_emits_convention_registered(self): version="1.0.0", schema=_make_field_defs(), file_requirements=_make_file_reqs(), - source=_make_source_def(), + ingester=_make_ingester_def(), ) outbox.append.assert_called_once() emitted = 
outbox.append.call_args[0][0] diff --git a/server/tests/unit/domain/deposition/test_create_deposition_from_source.py b/server/tests/unit/domain/deposition/test_create_deposition_from_source.py deleted file mode 100644 index 44e1e29..0000000 --- a/server/tests/unit/domain/deposition/test_create_deposition_from_source.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Unit tests for CreateDepositionFromSource event handler. - -Tests for User Story 3: Cross-domain decoupling. -""" - -from unittest.mock import AsyncMock, MagicMock -from uuid import uuid4 - -import pytest - -from osa.domain.deposition.handler.create_deposition_from_source import ( - CreateDepositionFromSource, -) -from osa.domain.deposition.model.aggregate import Deposition -from osa.domain.shared.event import EventId -from osa.domain.shared.model.srn import ConventionSRN, DepositionSRN -from osa.domain.source.event.source_record_ready import SourceRecordReady - - -def _make_conv_srn() -> ConventionSRN: - return ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0") - - -def _make_dep_srn() -> DepositionSRN: - return DepositionSRN.parse("urn:osa:localhost:dep:test-dep") - - -def _make_event() -> SourceRecordReady: - return SourceRecordReady( - id=EventId(uuid4()), - convention_srn=_make_conv_srn(), - metadata={"pdb_id": "4HHB", "title": "Hemoglobin"}, - file_paths=["4HHB/structure.cif"], - source_id="4HHB", - staging_dir="/tmp/staging/run-123", - ) - - -class TestCreateDepositionFromSource: - @pytest.mark.asyncio - async def test_creates_deposition_and_submits(self): - """Handler creates deposition, updates metadata, moves files, and submits.""" - dep = MagicMock(spec=Deposition) - dep.srn = _make_dep_srn() - - deposition_service = AsyncMock() - deposition_service.create.return_value = dep - file_storage = AsyncMock() - - handler = CreateDepositionFromSource( - deposition_service=deposition_service, - file_storage=file_storage, - ) - event = _make_event() - await handler.handle(event) - - # Creates deposition - 
deposition_service.create.assert_called_once() - create_kwargs = deposition_service.create.call_args[1] - assert create_kwargs["convention_srn"] == event.convention_srn - - # Updates metadata - deposition_service.update_metadata.assert_called_once_with( - srn=dep.srn, - metadata=event.metadata, - ) - - # Moves files - file_storage.move_source_files_to_deposition.assert_called_once() - - # Submits - deposition_service.submit.assert_called_once_with(srn=dep.srn) - - @pytest.mark.asyncio - async def test_uses_system_user_id(self): - """Handler creates deposition with SYSTEM_USER_ID.""" - from osa.domain.auth.model.value import SYSTEM_USER_ID - - dep = MagicMock(spec=Deposition) - dep.srn = _make_dep_srn() - - deposition_service = AsyncMock() - deposition_service.create.return_value = dep - file_storage = AsyncMock() - - handler = CreateDepositionFromSource( - deposition_service=deposition_service, - file_storage=file_storage, - ) - await handler.handle(_make_event()) - - create_kwargs = deposition_service.create.call_args[1] - assert create_kwargs["owner_id"] == SYSTEM_USER_ID diff --git a/server/tests/unit/domain/feature/test_insert_record_features.py b/server/tests/unit/domain/feature/test_insert_record_features.py index 114404a..dbd489e 100644 --- a/server/tests/unit/domain/feature/test_insert_record_features.py +++ b/server/tests/unit/domain/feature/test_insert_record_features.py @@ -9,7 +9,7 @@ from osa.domain.feature.service.feature import FeatureService from osa.domain.record.event.record_published import RecordPublished from osa.domain.shared.event import EventId -from osa.domain.shared.model.source import DepositionSource, HarvestSource +from osa.domain.shared.model.source import DepositionSource, IngestSource from osa.domain.shared.model.srn import ( ConventionSRN, RecordSRN, @@ -198,34 +198,34 @@ async def test_no_features_is_noop(self): feature_store.insert_features.assert_not_called() -class TestInsertRecordFeaturesHarvestSource: - """US2: 
InsertRecordFeatures works identically for harvest-sourced records.""" +class TestInsertRecordFeaturesIngestSource: + """US2: InsertRecordFeatures works identically for ingest-sourced records.""" @pytest.mark.asyncio - async def test_harvest_source_uses_source_fields(self): + async def test_ingest_source_uses_source_fields(self): """Handler uses source type and id from event regardless of source type.""" feature_service = AsyncMock() storage = MagicMock() - storage.get_hook_output_root.return_value = "/fake/harvest/dir" + storage.get_hook_output_root.return_value = "/fake/ingest/dir" handler = _make_handler(feature_service=feature_service, feature_storage=storage) event = RecordPublished( id=EventId(uuid4()), record_srn=_make_record_srn(), - source=HarvestSource( + source=IngestSource( id="run-123-pdb-456", - harvest_run_srn="urn:osa:localhost:val:run123", + ingest_run_srn="urn:osa:localhost:val:run123", upstream_source="pdb", ), - metadata={"title": "Harvested"}, + metadata={"title": "Ingested"}, convention_srn=_make_conv_srn(), expected_features=["pocket_detect"], ) await handler.handle(event) - storage.get_hook_output_root.assert_called_once_with("harvest", "run-123-pdb-456") + storage.get_hook_output_root.assert_called_once_with("ingest", "run-123-pdb-456") feature_service.insert_features_for_record.assert_called_once_with( - hook_output_dir="/fake/harvest/dir", + hook_output_dir="/fake/ingest/dir", record_srn=str(_make_record_srn()), expected_features=["pocket_detect"], ) diff --git a/server/tests/unit/domain/ingest/__init__.py b/server/tests/unit/domain/ingest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/domain/ingest/test_ingest_run.py b/server/tests/unit/domain/ingest/test_ingest_run.py new file mode 100644 index 0000000..b1e607b --- /dev/null +++ b/server/tests/unit/domain/ingest/test_ingest_run.py @@ -0,0 +1,144 @@ +"""T015/T017: Unit tests for IngestRun aggregate — status transitions, completion, counters.""" + 
+from datetime import UTC, datetime + +import pytest + +from osa.domain.ingest.model.ingest_run import IngestRun, IngestStatus +from osa.domain.shared.error import InvalidStateError + + +def _make_run(**overrides) -> IngestRun: + defaults = { + "srn": "urn:osa:localhost:ing:test-run", + "convention_srn": "urn:osa:localhost:conv:test-conv@1.0.0", + "status": IngestStatus.PENDING, + "started_at": datetime.now(UTC), + } + defaults.update(overrides) + return IngestRun(**defaults) + + +class TestStatusTransitions: + def test_pending_to_running(self) -> None: + run = _make_run() + run.mark_running() + assert run.status == IngestStatus.RUNNING + + def test_running_to_completed(self) -> None: + run = _make_run(status=IngestStatus.RUNNING) + run.transition_to(IngestStatus.COMPLETED) + assert run.status == IngestStatus.COMPLETED + + def test_running_to_failed(self) -> None: + run = _make_run(status=IngestStatus.RUNNING) + run.mark_failed(datetime.now(UTC)) + assert run.status == IngestStatus.FAILED + assert run.completed_at is not None + + def test_pending_to_failed(self) -> None: + run = _make_run() + run.mark_failed(datetime.now(UTC)) + assert run.status == IngestStatus.FAILED + + def test_completed_to_running_rejected(self) -> None: + run = _make_run(status=IngestStatus.COMPLETED) + with pytest.raises(InvalidStateError, match="Cannot transition"): + run.transition_to(IngestStatus.RUNNING) + + def test_failed_to_running_rejected(self) -> None: + run = _make_run(status=IngestStatus.FAILED) + with pytest.raises(InvalidStateError, match="Cannot transition"): + run.transition_to(IngestStatus.RUNNING) + + def test_completed_to_completed_rejected(self) -> None: + run = _make_run(status=IngestStatus.COMPLETED) + with pytest.raises(InvalidStateError): + run.transition_to(IngestStatus.COMPLETED) + + +class TestCompletionCondition: + def test_not_complete_when_source_not_finished(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + source_finished=False, + 
batches_sourced=3, + batches_completed=3, + ) + assert not run.is_complete + + def test_not_complete_when_batches_pending(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + source_finished=True, + batches_sourced=3, + batches_completed=2, + ) + assert not run.is_complete + + def test_complete_when_all_batches_done(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + source_finished=True, + batches_sourced=3, + batches_completed=3, + ) + assert run.is_complete + + def test_check_completion_transitions_status(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + source_finished=True, + batches_sourced=2, + batches_completed=2, + ) + now = datetime.now(UTC) + completed = run.check_completion(now) + assert completed is True + assert run.status == IngestStatus.COMPLETED + assert run.completed_at == now + + def test_check_completion_noop_when_not_complete(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + source_finished=True, + batches_sourced=3, + batches_completed=2, + ) + completed = run.check_completion(datetime.now(UTC)) + assert completed is False + assert run.status == IngestStatus.RUNNING + + +class TestCounterIncrements: + def test_increment_batches_sourced(self) -> None: + run = _make_run(status=IngestStatus.RUNNING) + run.increment_batches_sourced() + assert run.batches_sourced == 1 + + def test_record_batch_completed(self) -> None: + run = _make_run(status=IngestStatus.RUNNING) + run.record_batch_completed(published_count=42) + assert run.batches_completed == 1 + assert run.published_count == 42 + + def test_multiple_batch_completions(self) -> None: + run = _make_run(status=IngestStatus.RUNNING) + run.record_batch_completed(published_count=100) + run.record_batch_completed(published_count=50) + assert run.batches_completed == 2 + assert run.published_count == 150 + + def test_mark_source_finished(self) -> None: + run = _make_run(status=IngestStatus.RUNNING) + assert not run.source_finished + 
run.mark_source_finished() + assert run.source_finished + + def test_batch_size_default(self) -> None: + run = _make_run() + assert run.batch_size == 1000 + + def test_custom_batch_size(self) -> None: + run = _make_run(batch_size=500) + assert run.batch_size == 500 diff --git a/server/tests/unit/domain/ingest/test_ingest_service.py b/server/tests/unit/domain/ingest/test_ingest_service.py new file mode 100644 index 0000000..4742ec6 --- /dev/null +++ b/server/tests/unit/domain/ingest/test_ingest_service.py @@ -0,0 +1,112 @@ +"""T022: Unit tests for IngestService.start_ingest.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from osa.domain.ingest.model.ingest_run import IngestStatus +from osa.domain.ingest.service.ingest import IngestService +from osa.domain.shared.error import ConflictError, NotFoundError +from osa.domain.shared.model.source import IngesterDefinition +from osa.domain.shared.model.srn import Domain + + +def _make_convention(*, has_ingester: bool = True): + conv = MagicMock() + conv.srn = "urn:osa:localhost:conv:test-conv@1.0.0" + conv.ingester = ( + IngesterDefinition( + image="ghcr.io/example/ingester:v1", + digest="sha256:abc123", + ) + if has_ingester + else None + ) + return conv + + +def _make_service( + *, + convention=None, + running_ingest=None, + convention_not_found: bool = False, +) -> IngestService: + ingest_repo = AsyncMock() + ingest_repo.get_running_for_convention.return_value = running_ingest + ingest_repo.save = AsyncMock() + + convention_service = AsyncMock() + if convention_not_found: + convention_service.get_convention.side_effect = NotFoundError("Convention not found") + else: + convention_service.get_convention.return_value = convention or _make_convention() + + outbox = AsyncMock() + + return IngestService( + ingest_repo=ingest_repo, + convention_service=convention_service, + outbox=outbox, + node_domain=Domain("localhost"), + ) + + +class TestStartIngest: + @pytest.mark.asyncio + async def 
test_creates_pending_ingest(self) -> None: + service = _make_service() + run = await service.start_ingest( + convention_srn="urn:osa:localhost:conv:test-conv@1.0.0", + ) + assert run.status == IngestStatus.PENDING + assert run.convention_srn == "urn:osa:localhost:conv:test-conv@1.0.0" + assert run.batch_size == 1000 + + @pytest.mark.asyncio + async def test_saves_and_emits_event(self) -> None: + service = _make_service() + run = await service.start_ingest( + convention_srn="urn:osa:localhost:conv:test-conv@1.0.0", + ) + service.ingest_repo.save.assert_called_once() + service.outbox.append.assert_called_once() + + # Verify the event is IngestStarted + event = service.outbox.append.call_args[0][0] + assert event.__class__.__name__ == "IngestStarted" + assert event.ingest_run_srn == run.srn + assert event.convention_srn == run.convention_srn + + @pytest.mark.asyncio + async def test_custom_batch_size(self) -> None: + service = _make_service() + run = await service.start_ingest( + convention_srn="urn:osa:localhost:conv:test-conv@1.0.0", + batch_size=500, + ) + assert run.batch_size == 500 + + @pytest.mark.asyncio + async def test_rejects_convention_not_found(self) -> None: + service = _make_service(convention_not_found=True) + with pytest.raises(NotFoundError): + await service.start_ingest( + convention_srn="urn:osa:localhost:conv:nonexistent@1.0.0", + ) + + @pytest.mark.asyncio + async def test_rejects_no_ingester_configured(self) -> None: + service = _make_service(convention=_make_convention(has_ingester=False)) + with pytest.raises(NotFoundError, match="No ingester configured"): + await service.start_ingest( + convention_srn="urn:osa:localhost:conv:test-conv@1.0.0", + ) + + @pytest.mark.asyncio + async def test_rejects_ingest_already_running(self) -> None: + existing = MagicMock() + service = _make_service(running_ingest=existing) + with pytest.raises(ConflictError, match="already running"): + await service.start_ingest( + 
convention_srn="urn:osa:localhost:conv:test-conv@1.0.0", + ) diff --git a/server/tests/unit/domain/record/test_record_service.py b/server/tests/unit/domain/record/test_record_service.py index 2400d82..9dff7c1 100644 --- a/server/tests/unit/domain/record/test_record_service.py +++ b/server/tests/unit/domain/record/test_record_service.py @@ -11,7 +11,7 @@ from osa.domain.record.service.record import RecordService from osa.domain.shared.model.source import ( DepositionSource, - HarvestSource, + IngestSource, ) from osa.domain.shared.model.srn import ConventionSRN, DepositionSRN, Domain, LocalId from osa.domain.shared.outbox import Outbox @@ -124,24 +124,24 @@ async def test_publish_record_creates_version_1( assert record.srn.version.root == 1 -class TestRecordServiceHarvestSource: - """US2: Verify harvest-sourced records publish correctly.""" +class TestRecordServiceIngestSource: + """US2: Verify ingest-sourced records publish correctly.""" @pytest.mark.asyncio - async def test_publish_with_harvest_source( + async def test_publish_with_ingest_source( self, mock_record_repo: RecordRepository, mock_outbox: Outbox, node_domain: Domain, ): - """HarvestSource draft produces correct Record + RecordPublished event.""" + """IngestSource draft produces correct Record + RecordPublished event.""" draft = RecordDraft( - source=HarvestSource( + source=IngestSource( id="run-123-pdb-456", - harvest_run_srn="urn:osa:localhost:val:run123", + ingest_run_srn="urn:osa:localhost:val:run123", upstream_source="pdb", ), - metadata={"title": "Harvested Protein"}, + metadata={"title": "Ingested Protein"}, convention_srn=_make_conv_srn(), expected_features=["pocket_detect"], ) @@ -155,12 +155,12 @@ async def test_publish_with_harvest_source( record = await service.publish_record(draft) - assert record.source.type == "harvest" + assert record.source.type == "ingest" assert record.source.upstream_source == "pdb" assert record.convention_srn == _make_conv_srn() 
mock_record_repo.save.assert_called_once() event = mock_outbox.append.call_args[0][0] assert isinstance(event, RecordPublished) - assert event.source.type == "harvest" + assert event.source.type == "ingest" assert event.expected_features == ["pocket_detect"] diff --git a/server/tests/unit/domain/shared/test_record_source.py b/server/tests/unit/domain/shared/test_record_source.py index ff71708..c1fe20d 100644 --- a/server/tests/unit/domain/shared/test_record_source.py +++ b/server/tests/unit/domain/shared/test_record_source.py @@ -5,7 +5,7 @@ from osa.domain.shared.model.source import ( DepositionSource, - HarvestSource, + IngestSource, RecordSource, ) @@ -27,31 +27,31 @@ def test_serialization_roundtrip(self): assert restored == src -class TestHarvestSource: - def test_type_is_harvest(self): - src = HarvestSource( +class TestIngestSource: + def test_type_is_ingest(self): + src = IngestSource( id="run-123-source-456", - harvest_run_srn="urn:osa:localhost:val:run123", + ingest_run_srn="urn:osa:localhost:val:run123", upstream_source="pdb", ) - assert src.type == "harvest" + assert src.type == "ingest" - def test_requires_harvest_run_srn(self): + def test_requires_ingest_run_srn(self): with pytest.raises(ValidationError): - HarvestSource(id="run-123", upstream_source="pdb") + IngestSource(id="run-123", upstream_source="pdb") def test_requires_upstream_source(self): with pytest.raises(ValidationError): - HarvestSource(id="run-123", harvest_run_srn="urn:osa:localhost:val:run123") + IngestSource(id="run-123", ingest_run_srn="urn:osa:localhost:val:run123") def test_serialization_roundtrip(self): - src = HarvestSource( + src = IngestSource( id="run-123-source-456", - harvest_run_srn="urn:osa:localhost:val:run123", + ingest_run_srn="urn:osa:localhost:val:run123", upstream_source="pdb", ) data = src.model_dump() - restored = HarvestSource.model_validate(data) + restored = IngestSource.model_validate(data) assert restored == src @@ -61,16 +61,16 @@ def 
test_deserializes_deposition(self): src = adapter.validate_python({"type": "deposition", "id": "dep-abc"}) assert isinstance(src, DepositionSource) - def test_deserializes_harvest(self): + def test_deserializes_ingest(self): data = { - "type": "harvest", + "type": "ingest", "id": "run-123", - "harvest_run_srn": "urn:osa:localhost:val:run1", + "ingest_run_srn": "urn:osa:localhost:val:run1", "upstream_source": "geo", } adapter = TypeAdapter(RecordSource) src = adapter.validate_python(data) - assert isinstance(src, HarvestSource) + assert isinstance(src, IngestSource) assert src.upstream_source == "geo" def test_rejects_unknown_type(self): @@ -81,12 +81,12 @@ def test_rejects_unknown_type(self): def test_json_roundtrip(self): """Serialize to JSON and back via the union type.""" adapter = TypeAdapter(RecordSource) - src = HarvestSource( + src = IngestSource( id="run-1", - harvest_run_srn="urn:osa:localhost:val:run1", + ingest_run_srn="urn:osa:localhost:val:run1", upstream_source="pdb", ) json_str = adapter.dump_json(src) restored = adapter.validate_json(json_str) - assert isinstance(restored, HarvestSource) + assert isinstance(restored, IngestSource) assert restored == src diff --git a/server/tests/unit/domain/source/test_source_record_ready.py b/server/tests/unit/domain/source/test_source_record_ready.py deleted file mode 100644 index 296d820..0000000 --- a/server/tests/unit/domain/source/test_source_record_ready.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Unit tests for SourceRecordReady event. - -Tests for User Story 3: Cross-domain decoupling — source domain. 
-""" - -from uuid import uuid4 - -from osa.domain.shared.event import EventId -from osa.domain.shared.model.srn import ConventionSRN -from osa.domain.source.event.source_record_ready import SourceRecordReady - - -def _make_conv_srn() -> ConventionSRN: - return ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0") - - -class TestSourceRecordReady: - def test_creation_with_all_fields(self): - """SourceRecordReady carries all required fields.""" - event = SourceRecordReady( - id=EventId(uuid4()), - convention_srn=_make_conv_srn(), - metadata={"pdb_id": "4HHB", "title": "Hemoglobin"}, - file_paths=["4HHB/structure.cif"], - source_id="4HHB", - staging_dir="/tmp/staging/run-123", - ) - - assert event.convention_srn == _make_conv_srn() - assert event.metadata == {"pdb_id": "4HHB", "title": "Hemoglobin"} - assert event.file_paths == ["4HHB/structure.cif"] - assert event.source_id == "4HHB" - assert event.staging_dir == "/tmp/staging/run-123" - - def test_serialization_roundtrip(self): - """SourceRecordReady serializes and deserializes correctly.""" - event = SourceRecordReady( - id=EventId(uuid4()), - convention_srn=_make_conv_srn(), - metadata={"pdb_id": "1CRN"}, - file_paths=["1CRN/data.cif", "1CRN/meta.json"], - source_id="1CRN", - staging_dir="/tmp/staging/run-456", - ) - - data = event.model_dump() - restored = SourceRecordReady.model_validate(data) - - assert restored.convention_srn == event.convention_srn - assert restored.metadata == event.metadata - assert restored.file_paths == event.file_paths - assert restored.source_id == event.source_id - assert restored.staging_dir == event.staging_dir - - def test_registered_in_event_registry(self): - """SourceRecordReady should be auto-registered in Event._registry.""" - from osa.domain.shared.event import Event - - assert "SourceRecordReady" in Event._registry diff --git a/server/tests/unit/domain/source/test_source_service.py b/server/tests/unit/domain/source/test_source_service.py deleted file mode 100644 index 
46301dc..0000000 --- a/server/tests/unit/domain/source/test_source_service.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Unit tests for SourceService with OCI container model. - -Updated for cross-domain decoupling: SourceService no longer depends on -DepositionService, ConventionRepository, or FileStoragePort. Instead it -uses SourceStoragePort and emits SourceRecordReady events per record. -""" - -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from osa.domain.shared.model.source import SourceDefinition -from osa.domain.shared.model.srn import ConventionSRN -from osa.domain.source.event.source_record_ready import SourceRecordReady -from osa.domain.source.event.source_requested import SourceRequested -from osa.domain.source.event.source_run_completed import SourceRunCompleted -from osa.domain.source.port.source_runner import SourceOutput -from osa.domain.source.service.source import SourceService - - -def _make_conv_srn() -> ConventionSRN: - return ConventionSRN.parse("urn:osa:localhost:conv:test-conv-12345678@1.0.0") - - -def _make_source_def() -> SourceDefinition: - return SourceDefinition( - image="osa-sources/test:latest", - digest="sha256:abc123", - config={"batch_size": 100}, - ) - - -@pytest.fixture -def mock_outbox() -> AsyncMock: - return AsyncMock() - - -@pytest.fixture -def mock_source_storage() -> MagicMock: - storage = MagicMock() - storage.get_source_staging_dir.return_value = Path("/tmp/staging") - storage.get_source_output_dir.return_value = Path("/tmp/output") - return storage - - -@pytest.fixture -def mock_source_runner() -> AsyncMock: - runner = AsyncMock() - runner.run.return_value = SourceOutput( - records=[ - {"source_id": "4HHB", "metadata": {"pdb_id": "4HHB", "title": "Hemoglobin"}}, - {"source_id": "1CRN", "metadata": {"pdb_id": "1CRN", "title": "Crambin"}}, - ], - session=None, - files_dir=Path("/tmp/staging"), - ) - return runner - - -class TestSourceService: - @pytest.mark.asyncio - async def 
test_run_source_emits_per_record_events( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - result = await service.run_source( - convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - assert result.record_count == 2 - - # 2 SourceRecordReady + 1 SourceRunCompleted = 3 events - assert mock_outbox.append.call_count == 3 - first = mock_outbox.append.call_args_list[0][0][0] - second = mock_outbox.append.call_args_list[1][0][0] - assert isinstance(first, SourceRecordReady) - assert isinstance(second, SourceRecordReady) - assert first.source_id == "4HHB" - assert second.source_id == "1CRN" - - @pytest.mark.asyncio - async def test_run_source_carries_staging_dir( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - await service.run_source( - convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - event = mock_outbox.append.call_args_list[0][0][0] - assert isinstance(event, SourceRecordReady) - assert event.staging_dir == str(Path("/tmp/staging")) - - @pytest.mark.asyncio - async def test_run_source_emits_completion_event( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - await service.run_source( - convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - last = mock_outbox.append.call_args_list[-1][0][0] - assert isinstance(last, SourceRunCompleted) - assert last.record_count == 2 - assert last.convention_srn == _make_conv_srn() - assert last.is_final_chunk is True - - @pytest.mark.asyncio - async def test_run_source_emits_continuation_when_session( - self, mock_outbox, mock_source_storage, mock_source_runner - 
): - mock_source_runner.run.return_value = SourceOutput( - records=[{"source_id": "4HHB", "metadata": {"pdb_id": "4HHB"}}], - session={"cursor": "abc"}, - files_dir=Path("/tmp/staging"), - ) - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - await service.run_source( - convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - # 1 SourceRecordReady + 1 SourceRequested continuation + 1 SourceRunCompleted = 3 events - assert mock_outbox.append.call_count == 3 - continuation = mock_outbox.append.call_args_list[1][0][0] - assert isinstance(continuation, SourceRequested) - assert continuation.session == {"cursor": "abc"} - - @pytest.mark.asyncio - async def test_run_source_final_when_session_but_zero_records( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - """Source returns session but zero records -> treated as final chunk.""" - mock_source_runner.run.return_value = SourceOutput( - records=[], - session={"cursor": "x"}, - files_dir=Path("/tmp/staging"), - ) - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - await service.run_source( - convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - # Only SourceRunCompleted, no continuation - assert mock_outbox.append.call_count == 1 - event = mock_outbox.append.call_args_list[0][0][0] - assert isinstance(event, SourceRunCompleted) - assert event.is_final_chunk is True diff --git a/server/tests/unit/domain/source/test_source_service_decoupled.py b/server/tests/unit/domain/source/test_source_service_decoupled.py deleted file mode 100644 index 8e55af8..0000000 --- a/server/tests/unit/domain/source/test_source_service_decoupled.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Unit tests for decoupled SourceService. - -Tests for User Story 3: Cross-domain decoupling. 
-Verifies SourceService emits SourceRecordReady per record -instead of calling DepositionService directly. -""" - -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from osa.domain.shared.model.source import SourceDefinition -from osa.domain.shared.model.srn import ConventionSRN -from osa.domain.source.event.source_record_ready import SourceRecordReady -from osa.domain.source.event.source_run_completed import SourceRunCompleted -from osa.domain.source.port.source_runner import SourceOutput -from osa.domain.source.service.source import SourceService - - -def _make_conv_srn() -> ConventionSRN: - return ConventionSRN.parse("urn:osa:localhost:conv:test-conv-12345678@1.0.0") - - -def _make_source_def() -> SourceDefinition: - return SourceDefinition( - image="osa-sources/test:latest", - digest="sha256:abc123", - config={"batch_size": 100}, - ) - - -@pytest.fixture -def mock_outbox() -> AsyncMock: - return AsyncMock() - - -@pytest.fixture -def mock_source_storage() -> MagicMock: - storage = MagicMock() - storage.get_source_staging_dir.return_value = Path("/tmp/staging") - storage.get_source_output_dir.return_value = Path("/tmp/output") - return storage - - -@pytest.fixture -def mock_source_runner() -> AsyncMock: - runner = AsyncMock() - runner.run.return_value = SourceOutput( - records=[ - {"source_id": "4HHB", "metadata": {"pdb_id": "4HHB", "title": "Hemoglobin"}}, - {"source_id": "1CRN", "metadata": {"pdb_id": "1CRN", "title": "Crambin"}}, - ], - session=None, - files_dir=Path("/tmp/staging"), - ) - return runner - - -class TestDecoupledSourceService: - @pytest.mark.asyncio - async def test_emits_source_record_ready_per_record( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - """SourceService emits SourceRecordReady for each record.""" - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - await service.run_source( - 
convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - - # 2 SourceRecordReady + 1 SourceRunCompleted = 3 events - assert mock_outbox.append.call_count == 3 - first = mock_outbox.append.call_args_list[0][0][0] - second = mock_outbox.append.call_args_list[1][0][0] - assert isinstance(first, SourceRecordReady) - assert isinstance(second, SourceRecordReady) - assert first.source_id == "4HHB" - assert second.source_id == "1CRN" - - @pytest.mark.asyncio - async def test_source_record_ready_carries_staging_dir( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - """SourceRecordReady carries the staging_dir path.""" - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - await service.run_source( - convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - - event = mock_outbox.append.call_args_list[0][0][0] - assert isinstance(event, SourceRecordReady) - assert event.staging_dir == str(Path("/tmp/staging")) - - @pytest.mark.asyncio - async def test_emits_completion_event( - self, mock_outbox, mock_source_storage, mock_source_runner - ): - """Still emits SourceRunCompleted after all records.""" - service = SourceService( - source_runner=mock_source_runner, - source_storage=mock_source_storage, - outbox=mock_outbox, - ) - await service.run_source( - convention_srn=_make_conv_srn(), - source=_make_source_def(), - ) - - last = mock_outbox.append.call_args_list[-1][0][0] - assert isinstance(last, SourceRunCompleted) - assert last.record_count == 2 - - @pytest.mark.asyncio - async def test_no_deposition_service_dependency(self): - """SourceService no longer depends on DepositionService.""" - import inspect - - sig = inspect.signature(SourceService.__init__) - param_names = list(sig.parameters.keys()) - assert "deposition_service" not in param_names - assert "convention_repo" not in param_names diff --git 
a/server/tests/unit/domain/source/test_trigger_initial_source_run.py b/server/tests/unit/domain/source/test_trigger_initial_source_run.py deleted file mode 100644 index 5791e62..0000000 --- a/server/tests/unit/domain/source/test_trigger_initial_source_run.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Unit tests for TriggerInitialSourceRun event handler. - -Tests for User Story 2: Convention Initialization Chain. -Verifies handler consumes ConventionReady (not ConventionRegistered). -""" - -from datetime import UTC, datetime -from unittest.mock import AsyncMock -from uuid import uuid4 - -import pytest - -from osa.domain.deposition.model.convention import Convention -from osa.domain.deposition.model.value import FileRequirements -from osa.domain.feature.event.convention_ready import ConventionReady -from osa.domain.shared.event import EventId -from osa.domain.shared.model.source import InitialRunConfig, SourceDefinition -from osa.domain.shared.model.srn import ConventionSRN, SchemaSRN -from osa.domain.source.event.source_requested import SourceRequested -from osa.domain.source.handler.trigger_initial_source_run import TriggerInitialSourceRun - - -def _make_conv_srn() -> ConventionSRN: - return ConventionSRN.parse("urn:osa:localhost:conv:test-deploy-conv@1.0.0") - - -def _make_convention(source: SourceDefinition | None = None) -> Convention: - return Convention( - srn=_make_conv_srn(), - title="Test Convention", - schema_srn=SchemaSRN.parse("urn:osa:localhost:schema:test-schema12345678@1.0.0"), - file_requirements=FileRequirements( - accepted_types=[".csv"], - min_count=1, - max_count=3, - max_file_size=1_000_000, - ), - source=source, - created_at=datetime.now(UTC), - ) - - -def _make_event() -> ConventionReady: - return ConventionReady( - id=EventId(uuid4()), - convention_srn=_make_conv_srn(), - ) - - -class TestTriggerInitialSourceRun: - @pytest.mark.asyncio - async def test_emits_source_requested_when_initial_run_configured(self): - """Emits SourceRequested when 
convention has initial_run configured.""" - source = SourceDefinition( - image="osa-sources/test:latest", - digest="sha256:abc123", - initial_run=InitialRunConfig(limit=500), - ) - convention = _make_convention(source=source) - - convention_service = AsyncMock() - convention_service.get_convention.return_value = convention - outbox = AsyncMock() - - handler = TriggerInitialSourceRun( - convention_service=convention_service, - outbox=outbox, - ) - await handler.handle(_make_event()) - - outbox.append.assert_called_once() - emitted = outbox.append.call_args[0][0] - assert isinstance(emitted, SourceRequested) - assert emitted.convention_srn == convention.srn - assert emitted.limit == 500 - - @pytest.mark.asyncio - async def test_no_event_when_source_has_no_initial_run(self): - """No SourceRequested when initial_run is None.""" - source = SourceDefinition( - image="osa-sources/test:latest", - digest="sha256:abc123", - initial_run=None, - ) - convention = _make_convention(source=source) - - convention_service = AsyncMock() - convention_service.get_convention.return_value = convention - outbox = AsyncMock() - - handler = TriggerInitialSourceRun( - convention_service=convention_service, - outbox=outbox, - ) - await handler.handle(_make_event()) - - outbox.append.assert_not_called() - - @pytest.mark.asyncio - async def test_no_event_when_convention_has_no_source(self): - """No SourceRequested when convention has no source.""" - convention = _make_convention(source=None) - - convention_service = AsyncMock() - convention_service.get_convention.return_value = convention - outbox = AsyncMock() - - handler = TriggerInitialSourceRun( - convention_service=convention_service, - outbox=outbox, - ) - await handler.handle(_make_event()) - - outbox.append.assert_not_called() - - def test_handler_event_type_is_convention_ready(self): - """TriggerInitialSourceRun.__event_type__ should be ConventionReady.""" - assert TriggerInitialSourceRun.__event_type__ is ConventionReady diff --git 
a/server/tests/unit/domain/source/test_trigger_source_on_deploy.py b/server/tests/unit/domain/source/test_trigger_source_on_deploy.py deleted file mode 100644 index 63b16b3..0000000 --- a/server/tests/unit/domain/source/test_trigger_source_on_deploy.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Tests for TriggerSourceOnDeploy event handler.""" - -from datetime import UTC, datetime -from unittest.mock import AsyncMock -from uuid import uuid4 - -import pytest - -from osa.domain.deposition.event.convention_registered import ConventionRegistered -from osa.domain.deposition.model.convention import Convention -from osa.domain.deposition.model.value import FileRequirements -from osa.domain.shared.event import EventId -from osa.domain.shared.model.source import InitialRunConfig, SourceDefinition -from osa.domain.shared.model.srn import ConventionSRN, SchemaSRN -from osa.domain.source.event.source_requested import SourceRequested -from osa.domain.source.handler.trigger_source_on_deploy import TriggerSourceOnDeploy - - -def _make_conv_srn() -> ConventionSRN: - return ConventionSRN.parse("urn:osa:localhost:conv:test-deploy-conv@1.0.0") - - -def _make_schema_srn() -> SchemaSRN: - return SchemaSRN.parse("urn:osa:localhost:schema:test-schema12345678@1.0.0") - - -def _make_file_reqs() -> FileRequirements: - return FileRequirements( - accepted_types=[".csv"], - min_count=1, - max_count=3, - max_file_size=1_000_000, - ) - - -def _make_convention( - source: SourceDefinition | None = None, -) -> Convention: - return Convention( - srn=_make_conv_srn(), - title="Test Convention", - schema_srn=_make_schema_srn(), - file_requirements=_make_file_reqs(), - source=source, - created_at=datetime.now(UTC), - ) - - -def _make_event() -> ConventionRegistered: - return ConventionRegistered( - id=EventId(uuid4()), - convention_srn=_make_conv_srn(), - ) - - -class TestTriggerSourceOnDeploy: - @pytest.mark.asyncio - async def test_emits_source_requested_when_initial_run_configured(self): - source = 
SourceDefinition( - image="osa-sources/test:latest", - digest="sha256:abc123", - initial_run=InitialRunConfig(limit=500), - ) - convention = _make_convention(source=source) - - convention_service = AsyncMock() - convention_service.get_convention.return_value = convention - outbox = AsyncMock() - - handler = TriggerSourceOnDeploy( - convention_service=convention_service, - outbox=outbox, - ) - await handler.handle(_make_event()) - - outbox.append.assert_called_once() - emitted = outbox.append.call_args[0][0] - assert isinstance(emitted, SourceRequested) - assert emitted.convention_srn == convention.srn - assert emitted.limit == 500 - - @pytest.mark.asyncio - async def test_no_event_when_source_has_no_initial_run(self): - source = SourceDefinition( - image="osa-sources/test:latest", - digest="sha256:abc123", - initial_run=None, - ) - convention = _make_convention(source=source) - - convention_service = AsyncMock() - convention_service.get_convention.return_value = convention - outbox = AsyncMock() - - handler = TriggerSourceOnDeploy( - convention_service=convention_service, - outbox=outbox, - ) - await handler.handle(_make_event()) - - outbox.append.assert_not_called() - - @pytest.mark.asyncio - async def test_no_event_when_convention_has_no_source(self): - convention = _make_convention(source=None) - - convention_service = AsyncMock() - convention_service.get_convention.return_value = convention - outbox = AsyncMock() - - handler = TriggerSourceOnDeploy( - convention_service=convention_service, - outbox=outbox, - ) - await handler.handle(_make_event()) - - outbox.append.assert_not_called() diff --git a/server/tests/unit/domain/validation/test_hook_runner.py b/server/tests/unit/domain/validation/test_hook_runner.py index fa1e1fa..5b7052c 100644 --- a/server/tests/unit/domain/validation/test_hook_runner.py +++ b/server/tests/unit/domain/validation/test_hook_runner.py @@ -6,55 +6,57 @@ from osa.domain.shared.model.hook import HookDefinition from 
osa.domain.validation.model.hook_result import HookResult, HookStatus +from osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.port.hook_runner import HookInputs, HookRunner class TestHookInputs: def test_minimal_construction(self): inputs = HookInputs( - record_json={"srn": "urn:osa:localhost:rec:123"}, + records=[HookRecord(id="rec1", metadata={})], run_id="localhost_test123", ) - assert inputs.record_json == {"srn": "urn:osa:localhost:rec:123"} - assert inputs.files_dir is None + assert inputs.records == [HookRecord(id="rec1", metadata={})] + assert inputs.files_dirs == {} assert inputs.config is None - def test_with_files_dir(self): - files = Path("/tmp/files") + def test_with_files_dirs(self): inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="rec1", metadata={})], run_id="localhost_test123", - files_dir=files, + files_dirs={"rec1": Path("/tmp/files")}, ) - assert inputs.files_dir == files + assert inputs.files_dirs == {"rec1": Path("/tmp/files")} def test_with_config(self): inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="rec1", metadata={})], run_id="localhost_test123", config={"r_min": 3.0, "threshold": 0.5}, ) assert inputs.config == {"r_min": 3.0, "threshold": 0.5} def test_full_construction(self): - files = Path("/tmp/data/files") inputs = HookInputs( - record_json={"srn": "urn:osa:localhost:rec:456", "name": "test"}, + records=[ + HookRecord(id="rec1", metadata={"name": "test"}), + HookRecord(id="rec2", metadata={"name": "test2"}), + ], run_id="localhost_test456", - files_dir=files, + files_dirs={"rec1": Path("/tmp/data/files/rec1")}, config={"key": "value"}, ) - assert inputs.record_json["name"] == "test" - assert inputs.files_dir == files + assert len(inputs.records) == 2 + assert inputs.records[0].metadata["name"] == "test" assert inputs.config == {"key": "value"} def test_is_frozen(self): inputs = HookInputs( - record_json={"srn": "test"}, + 
records=[HookRecord(id="rec1", metadata={})], run_id="localhost_test123", ) with pytest.raises(AttributeError): - inputs.record_json = {} # type: ignore[misc] + inputs.records = [] # type: ignore[misc] def test_is_dataclass(self): """HookInputs is a frozen dataclass.""" diff --git a/server/tests/unit/domain/validation/test_validation_service.py b/server/tests/unit/domain/validation/test_validation_service.py index 1ce570e..55abf33 100644 --- a/server/tests/unit/domain/validation/test_validation_service.py +++ b/server/tests/unit/domain/validation/test_validation_service.py @@ -14,6 +14,7 @@ from osa.domain.shared.model.srn import DepositionSRN, Domain from osa.domain.validation.model import RunStatus from osa.domain.validation.model.hook_result import HookResult, HookStatus +from osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.port.hook_runner import HookInputs from osa.domain.validation.service.validation import ValidationService @@ -63,9 +64,9 @@ def _make_service( def _make_inputs() -> HookInputs: return HookInputs( - record_json={"srn": "urn:osa:localhost:dep:test123", "metadata": {"name": "test"}}, + records=[HookRecord(id="urn:osa:localhost:dep:test123", metadata={"name": "test"})], run_id="localhost_test123", - files_dir=Path("/tmp/staging/files"), + files_dirs={"urn:osa:localhost:dep:test123": Path("/tmp/staging/files")}, ) diff --git a/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py b/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py index 14a9b47..4802709 100644 --- a/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py +++ b/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py @@ -15,6 +15,7 @@ TableFeatureSpec, ) from osa.domain.validation.model.hook_result import HookStatus +from osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.port.hook_runner import HookInputs from osa.infrastructure.k8s.runner import K8sHookRunner @@ -460,7 +461,7 @@ async 
def test_successful_run(self, tmp_path: Path): b'{"step":"Check","status":"completed","message":"OK"}\n' ) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, core_api, @@ -509,7 +510,7 @@ async def test_timeout_deadline_exceeded(self, tmp_path: Path): hook = _make_hook() work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" work_dir.mkdir(parents=True) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, @@ -574,7 +575,7 @@ async def test_oom_exit_137(self, tmp_path: Path): hook = _make_hook() work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" work_dir.mkdir(parents=True) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, @@ -633,7 +634,7 @@ async def test_nonzero_exit(self, tmp_path: Path): hook = _make_hook() work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" work_dir.mkdir(parents=True) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, @@ -687,7 +688,7 @@ async def test_orphan_running_job_attaches(self, tmp_path: Path): work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" output_dir = work_dir / "output" output_dir.mkdir(parents=True) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, @@ -723,7 +724,7 @@ async def 
test_orphan_completed_job_reads_output(self, tmp_path: Path): work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" output_dir = work_dir / "output" output_dir.mkdir(parents=True) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, @@ -778,7 +779,7 @@ async def test_orphan_failed_job_creates_new(self, tmp_path: Path): work_dir = tmp_path / "depositions" / "localhost_abc" / "hooks" / "validate_dna" output_dir = work_dir / "output" output_dir.mkdir(parents=True) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, @@ -843,7 +844,7 @@ async def test_rejection_via_progress(self, tmp_path: Path): runner._s3.get_object.return_value = ( b'{"step":"Validate","status":"rejected","message":"Missing atoms"}\n' ) - inputs = HookInputs(record_json={"srn": "test"}, run_id=_RUN_ID) + inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( batch_api, @@ -904,7 +905,7 @@ async def test_run_uses_run_id_from_inputs(self, tmp_path: Path): hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="my-real-run-id", ) diff --git a/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py b/server/tests/unit/infrastructure/k8s/test_k8s_ingester_runner.py similarity index 91% rename from server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py rename to server/tests/unit/infrastructure/k8s/test_k8s_ingester_runner.py index 39f19ae..6df39d4 100644 --- a/server/tests/unit/infrastructure/k8s/test_k8s_source_runner.py +++ b/server/tests/unit/infrastructure/k8s/test_k8s_ingester_runner.py @@ -1,4 +1,4 @@ -"""Unit tests for K8sSourceRunner — Job spec 
differences, source lifecycle.""" +"""Unit tests for K8sIngesterRunner — Job spec differences, source lifecycle.""" from pathlib import Path from typing import Any @@ -8,10 +8,10 @@ from osa.config import K8sConfig from osa.domain.shared.error import ExternalServiceError -from osa.domain.shared.model.source import SourceDefinition, SourceLimits +from osa.domain.shared.model.source import IngesterDefinition, IngesterLimits from osa.domain.shared.model.srn import ConventionSRN -from osa.domain.source.port.source_runner import SourceInputs -from osa.infrastructure.k8s.source_runner import K8sSourceRunner +from osa.domain.shared.port.ingester_runner import IngesterInputs +from osa.infrastructure.k8s.ingester_runner import K8sIngesterRunner _CONV_SRN = ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0") @@ -23,12 +23,12 @@ def _make_source( memory: str = "4g", cpu: str = "2.0", config: dict[str, Any] | None = None, -) -> SourceDefinition: - return SourceDefinition( +) -> IngesterDefinition: + return IngesterDefinition( image=image, digest=digest, config=config, - limits=SourceLimits(timeout_seconds=timeout, memory=memory, cpu=cpu), + limits=IngesterLimits(timeout_seconds=timeout, memory=memory, cpu=cpu), ) @@ -53,10 +53,10 @@ def _make_s3_mock() -> AsyncMock: return s3 -def _make_runner(config: K8sConfig | None = None) -> K8sSourceRunner: +def _make_runner(config: K8sConfig | None = None) -> K8sIngesterRunner: api_client = MagicMock() s3 = _make_s3_mock() - return K8sSourceRunner(api_client=api_client, config=config or _make_config(), s3=s3) + return K8sIngesterRunner(api_client=api_client, config=config or _make_config(), s3=s3) # --------------------------------------------------------------------------- @@ -138,7 +138,7 @@ def test_env_vars(self): source, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), - inputs=SourceInputs(convention_srn=_CONV_SRN, limit=100, offset=50), + 
inputs=IngesterInputs(convention_srn=_CONV_SRN, limit=100, offset=50), ) env = spec.spec.template.spec.containers[0].env env_dict = {e.name: e.value for e in env} @@ -158,7 +158,7 @@ def test_since_env_var(self): source, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), - inputs=SourceInputs(convention_srn=_CONV_SRN, since=since), + inputs=IngesterInputs(convention_srn=_CONV_SRN, since=since), ) env = spec.spec.template.spec.containers[0].env env_dict = {e.name: e.value for e in env} @@ -210,7 +210,7 @@ class TestSourceLifecycle: @pytest.mark.asyncio async def test_successful_run_with_records(self, tmp_path: Path): config = _make_config(data_mount_path=str(tmp_path)) - runner = K8sSourceRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) + runner = K8sIngesterRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) batch_api = AsyncMock() core_api = AsyncMock() @@ -258,7 +258,7 @@ async def s3_get(key: str) -> bytes: runner._s3.get_object.side_effect = s3_get - inputs = SourceInputs(convention_srn=_CONV_SRN) + inputs = IngesterInputs(convention_srn=_CONV_SRN) result = await runner._run_job( batch_api, core_api, @@ -276,7 +276,7 @@ async def s3_get(key: str) -> bytes: @pytest.mark.asyncio async def test_timeout_raises_external_service_error(self, tmp_path: Path): config = _make_config(data_mount_path=str(tmp_path)) - runner = K8sSourceRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) + runner = K8sIngesterRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) batch_api = AsyncMock() core_api = AsyncMock() @@ -309,7 +309,7 @@ async def test_timeout_raises_external_service_error(self, tmp_path: Path): work_dir.mkdir(parents=True) files_dir = work_dir / "files" files_dir.mkdir(parents=True) - inputs = SourceInputs(convention_srn=_CONV_SRN) + inputs = IngesterInputs(convention_srn=_CONV_SRN) with pytest.raises(ExternalServiceError, 
match="[Tt]imed out|[Dd]eadline"): await runner._run_job( @@ -324,7 +324,7 @@ async def test_timeout_raises_external_service_error(self, tmp_path: Path): @pytest.mark.asyncio async def test_oom_raises_external_service_error(self, tmp_path: Path): config = _make_config(data_mount_path=str(tmp_path)) - runner = K8sSourceRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) + runner = K8sIngesterRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) batch_api = AsyncMock() core_api = AsyncMock() @@ -369,7 +369,7 @@ async def test_oom_raises_external_service_error(self, tmp_path: Path): work_dir.mkdir(parents=True) files_dir = work_dir / "files" files_dir.mkdir(parents=True) - inputs = SourceInputs(convention_srn=_CONV_SRN) + inputs = IngesterInputs(convention_srn=_CONV_SRN) with pytest.raises(ExternalServiceError, match="[Oo]OM"): await runner._run_job( @@ -383,7 +383,7 @@ async def test_oom_raises_external_service_error(self, tmp_path: Path): # --------------------------------------------------------------------------- -# Identity threading from SourceInputs +# Identity threading from IngesterInputs # --------------------------------------------------------------------------- @@ -395,7 +395,7 @@ async def test_run_uses_convention_srn_from_inputs(self, tmp_path: Path): from unittest.mock import patch config = _make_config(data_mount_path=str(tmp_path)) - runner = K8sSourceRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) + runner = K8sIngesterRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) batch_api = AsyncMock() core_api = AsyncMock() @@ -428,7 +428,7 @@ async def test_run_uses_convention_srn_from_inputs(self, tmp_path: Path): files_dir = work_dir / "files" files_dir.mkdir(parents=True) - inputs = SourceInputs( + inputs = IngesterInputs( convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:my-conv@1.0.0") ) diff --git a/server/tests/unit/infrastructure/oci/__init__.py 
b/server/tests/unit/infrastructure/oci/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/infrastructure/test_hook_output_parsing.py b/server/tests/unit/infrastructure/test_hook_output_parsing.py new file mode 100644 index 0000000..f1d740e --- /dev/null +++ b/server/tests/unit/infrastructure/test_hook_output_parsing.py @@ -0,0 +1,132 @@ +"""T004: Unit tests for JSONL batch output parsing via FilesystemStorageAdapter.""" + +import json +from pathlib import Path + +import pytest + +from osa.infrastructure.persistence.adapter.storage import FilesystemStorageAdapter + + +def _write_jsonl(path: Path, lines: list[dict]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w") as f: + for line in lines: + f.write(json.dumps(line) + "\n") + + +def _hook_output_dir(base: Path, hook_name: str = "validate_dna") -> Path: + """Create the expected directory structure: {base}/hooks/{hook_name}/output/""" + d = base / "hooks" / hook_name / "output" + d.mkdir(parents=True, exist_ok=True) + return d + + +@pytest.fixture +def adapter(tmp_path: Path) -> FilesystemStorageAdapter: + return FilesystemStorageAdapter(str(tmp_path)) + + +HOOK = "validate_dna" + + +class TestReadBatchOutcomes: + """Parse features.jsonl, rejections.jsonl, errors.jsonl via storage adapter.""" + + @pytest.mark.anyio + async def test_single_line_features( + self, adapter: FilesystemStorageAdapter, tmp_path: Path + ) -> None: + output = _hook_output_dir(tmp_path) + _write_jsonl(output / "features.jsonl", [{"id": "rec1", "features": [{"score": 0.9}]}]) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert len(outcomes) == 1 + assert outcomes["rec1"].status == "passed" + assert outcomes["rec1"].features == [{"score": 0.9}] + + @pytest.mark.anyio + async def test_multi_line_features( + self, adapter: FilesystemStorageAdapter, tmp_path: Path + ) -> None: + output = _hook_output_dir(tmp_path) + _write_jsonl( + output / "features.jsonl", 
+ [ + {"id": "rec1", "features": [{"score": 0.9}]}, + {"id": "rec2", "features": [{"score": 0.7}]}, + ], + ) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert len(outcomes) == 2 + assert outcomes["rec1"].status == "passed" + assert outcomes["rec2"].status == "passed" + + @pytest.mark.anyio + async def test_rejections(self, adapter: FilesystemStorageAdapter, tmp_path: Path) -> None: + output = _hook_output_dir(tmp_path) + _write_jsonl( + output / "rejections.jsonl", [{"id": "rec3", "reason": "Missing required field"}] + ) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert outcomes["rec3"].status == "rejected" + assert outcomes["rec3"].reason == "Missing required field" + + @pytest.mark.anyio + async def test_errors(self, adapter: FilesystemStorageAdapter, tmp_path: Path) -> None: + output = _hook_output_dir(tmp_path) + _write_jsonl(output / "errors.jsonl", [{"id": "rec4", "error": "OOM", "retryable": True}]) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert outcomes["rec4"].status == "errored" + assert outcomes["rec4"].error == "OOM" + + @pytest.mark.anyio + async def test_mixed_outcomes(self, adapter: FilesystemStorageAdapter, tmp_path: Path) -> None: + output = _hook_output_dir(tmp_path) + _write_jsonl(output / "features.jsonl", [{"id": "a", "features": []}]) + _write_jsonl(output / "rejections.jsonl", [{"id": "b", "reason": "bad"}]) + _write_jsonl(output / "errors.jsonl", [{"id": "c", "error": "fail"}]) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert len(outcomes) == 3 + assert outcomes["a"].status == "passed" + assert outcomes["b"].status == "rejected" + assert outcomes["c"].status == "errored" + + @pytest.mark.anyio + async def test_empty_directory(self, adapter: FilesystemStorageAdapter, tmp_path: Path) -> None: + _hook_output_dir(tmp_path) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert outcomes == {} + + @pytest.mark.anyio 
+ async def test_malformed_json_line_skipped( + self, adapter: FilesystemStorageAdapter, tmp_path: Path + ) -> None: + output = _hook_output_dir(tmp_path) + (output / "features.jsonl").write_text( + '{"id": "ok", "features": []}\n' + "not valid json\n" + '{"id": "also_ok", "features": [{"x": 1}]}\n' + ) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert len(outcomes) == 2 + assert "ok" in outcomes + assert "also_ok" in outcomes + + @pytest.mark.anyio + async def test_missing_id_field_skipped( + self, adapter: FilesystemStorageAdapter, tmp_path: Path + ) -> None: + output = _hook_output_dir(tmp_path) + (output / "features.jsonl").write_text( + '{"id": "ok", "features": []}\n{"features": [{"x": 1}]}\n' + ) + outcomes = await adapter.read_batch_outcomes(str(tmp_path), HOOK) + assert len(outcomes) == 1 + assert "ok" in outcomes + + @pytest.mark.anyio + async def test_nonexistent_output_dir( + self, adapter: FilesystemStorageAdapter, tmp_path: Path + ) -> None: + outcomes = await adapter.read_batch_outcomes(str(tmp_path / "nonexistent"), HOOK) + assert outcomes == {} diff --git a/server/tests/unit/infrastructure/test_oci_hook_runner.py b/server/tests/unit/infrastructure/test_oci_hook_runner.py index ae43290..3220f94 100644 --- a/server/tests/unit/infrastructure/test_oci_hook_runner.py +++ b/server/tests/unit/infrastructure/test_oci_hook_runner.py @@ -13,6 +13,7 @@ TableFeatureSpec, ) from osa.domain.validation.model.hook_result import HookStatus, ProgressEntry +from osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.port.hook_runner import HookInputs from osa.infrastructure.oci.runner import OciHookRunner from osa.infrastructure.runner_utils import ( @@ -199,7 +200,7 @@ async def test_successful_hook_returns_passed(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", 
) @@ -225,7 +226,7 @@ async def test_nonzero_exit_returns_failed(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", ) @@ -248,7 +249,7 @@ async def test_oom_killed_returns_failed(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", ) @@ -278,7 +279,7 @@ async def hang(): runner = OciHookRunner(docker=docker) hook = _make_hook(timeout=1) # 1 second timeout inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", ) @@ -305,7 +306,7 @@ async def test_rejection_via_progress(self, tmp_path: Path): work_dir.mkdir() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", ) @@ -335,7 +336,7 @@ async def test_security_hardening(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook(memory="4g", cpu="4.0") inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", ) @@ -369,7 +370,7 @@ async def test_env_vars_set(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", ) @@ -399,9 +400,9 @@ async def test_nested_bind_mounts(self, tmp_path: Path): files_dir = tmp_path / "files" files_dir.mkdir() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", - files_dir=files_dir, + files_dirs={"test": files_dir}, ) work_dir = tmp_path / "hook_work" @@ -413,7 +414,7 @@ async def test_nested_bind_mounts(self, tmp_path: Path): config = call_args[0][0] if call_args[0] else 
call_args[1].get("config", {}) binds = config["HostConfig"]["Binds"] - # Should have 3 binds: input:ro, output:rw, files:ro + # Should have 3 binds: input:ro, output:rw, files/{id}:ro assert len(binds) == 3 # input/ and output/ are sibling dirs under work_dir @@ -421,7 +422,7 @@ async def test_nested_bind_mounts(self, tmp_path: Path): out_bind = [b for b in binds if b.endswith(":/osa/out:rw")][0] assert str(work_dir / "input") in in_bind assert str(work_dir / "output") in out_bind - assert any(b.endswith(":/osa/in/files:ro") for b in binds) + assert any(":/osa/files/test:ro" in b for b in binds) @pytest.mark.asyncio async def test_no_files_bind_when_no_files_dir(self, tmp_path: Path): @@ -435,9 +436,8 @@ async def test_no_files_bind_when_no_files_dir(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", - files_dir=None, ) output_dir = tmp_path / "output" @@ -449,7 +449,10 @@ async def test_no_files_bind_when_no_files_dir(self, tmp_path: Path): config = call_args[0][0] if call_args[0] else call_args[1].get("config", {}) binds = config["HostConfig"]["Binds"] - assert len(binds) == 2 # staging + output only + # staging + output + empty files base dir + assert len(binds) == 3 + # No per-record file mounts + assert not any(":/osa/files/" in b and ":ro" in b for b in binds if b.count("/") > 3) @pytest.mark.asyncio async def test_container_deleted_on_failure(self, tmp_path: Path): @@ -464,7 +467,7 @@ async def test_container_deleted_on_failure(self, tmp_path: Path): runner = OciHookRunner(docker=docker) hook = _make_hook() inputs = HookInputs( - record_json={"srn": "test"}, + records=[HookRecord(id="test", metadata={})], run_id="test-run", ) From 5b790de9c292afabc94a496f746e061e165174fb Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Wed, 25 Mar 2026 22:26:46 +0000 Subject: [PATCH 2/9] refactor: rename source to ingester 
throughout codebase Rename variables, functions, comments, and documentation from "source" terminology to "ingester" to better reflect the actual purpose of these components in the data ingestion pipeline. --- server/Justfile | 1 + server/osa/domain/deposition/port/storage.py | 5 +- .../domain/feature/event/convention_ready.py | 3 +- .../domain/ingest/handler/publish_batch.py | 18 ++--- server/osa/domain/ingest/handler/run_hooks.py | 10 +-- .../osa/domain/ingest/handler/run_ingester.py | 6 +- server/osa/domain/shared/event.py | 13 ++-- .../osa/domain/shared/port/ingester_runner.py | 2 +- .../osa/domain/validation/port/hook_runner.py | 2 +- server/osa/infrastructure/k8s/__init__.py | 2 +- .../osa/infrastructure/k8s/ingester_runner.py | 52 +++++++------- .../osa/infrastructure/oci/ingester_runner.py | 44 ++++++------ .../persistence/adapter/storage.py | 2 +- server/osa/infrastructure/runner_utils.py | 2 +- server/osa/infrastructure/s3/storage.py | 6 +- .../k8s/test_k8s_ingester_runner.py | 70 +++++++++---------- 16 files changed, 119 insertions(+), 119 deletions(-) diff --git a/server/Justfile b/server/Justfile index 5310a3f..eb93bc9 100644 --- a/server/Justfile +++ b/server/Justfile @@ -36,6 +36,7 @@ test-cov: # Run linter and type checker lint: + @just fix uv run ruff check osa uv run ty check osa diff --git a/server/osa/domain/deposition/port/storage.py b/server/osa/domain/deposition/port/storage.py index ad8f7be..7476c48 100644 --- a/server/osa/domain/deposition/port/storage.py +++ b/server/osa/domain/deposition/port/storage.py @@ -11,9 +11,8 @@ class FileStoragePort(Port, Protocol): """Storage operations scoped to the deposition domain. - Hook output, hook features, and source staging methods have been - moved to their respective domain ports (HookStoragePort, - FeatureStoragePort, SourceStoragePort). + Hook output and hook features methods have been moved to their + respective domain ports (HookStoragePort, FeatureStoragePort). 
""" @abstractmethod diff --git a/server/osa/domain/feature/event/convention_ready.py b/server/osa/domain/feature/event/convention_ready.py index 1c7d7d0..42b627e 100644 --- a/server/osa/domain/feature/event/convention_ready.py +++ b/server/osa/domain/feature/event/convention_ready.py @@ -7,8 +7,7 @@ class ConventionReady(Event): """Emitted when feature tables have been created for a convention. - Downstream handlers (e.g. TriggerInitialSourceRun) react to this - to kick off initial source runs, knowing that feature tables are ready. + Downstream handlers react to this knowing that feature tables are ready. """ id: EventId diff --git a/server/osa/domain/ingest/handler/publish_batch.py b/server/osa/domain/ingest/handler/publish_batch.py index 3d6637c..bf5d987 100644 --- a/server/osa/domain/ingest/handler/publish_batch.py +++ b/server/osa/domain/ingest/handler/publish_batch.py @@ -46,17 +46,17 @@ async def handle(self, event: HookBatchCompleted) -> None: ConventionSRN.parse(ingest_run.convention_srn) ) - # Read source records from batch dir + # Read ingester records from batch dir ingest_dir = self.paths.data_dir / "ingests" / _safe_srn(event.ingest_run_srn) batch_dir = ingest_dir / "batches" / str(event.batch_index) - source_records = _read_source_records(batch_dir / "source" / "records.jsonl") + ingester_records = _read_ingester_records(batch_dir / "ingester" / "records.jsonl") # Read hook outcomes for all hooks expected_features = [h.name for h in convention.hooks] # Determine which records passed all hooks passed_records = _get_passed_records( - source_records=source_records, + ingester_records=ingester_records, batch_dir=batch_dir, hooks=expected_features, feature_storage=self.feature_storage, @@ -134,8 +134,8 @@ async def handle(self, event: HookBatchCompleted) -> None: ) -def _read_source_records(records_file: Path) -> list[dict]: - """Read source records from JSONL file.""" +def _read_ingester_records(records_file: Path) -> list[dict]: + """Read ingester 
records from JSONL file.""" records: list[dict] = [] if not records_file.exists(): return records @@ -146,19 +146,19 @@ def _read_source_records(records_file: Path) -> list[dict]: try: records.append(json.loads(line)) except json.JSONDecodeError: - logger.warning("Skipping malformed source record line") + logger.warning("Skipping malformed ingester record line") return records def _get_passed_records( - source_records: list[dict], + ingester_records: list[dict], batch_dir: Path, hooks: list[str], feature_storage: FeatureStoragePort, ) -> list[dict]: """Determine which records passed all hooks by reading features.jsonl.""" if not hooks: - return source_records + return ingester_records # For simplicity, read the last hook's features.jsonl to get passed IDs # (hooks run sequentially, so the last hook's output has the final set) @@ -180,7 +180,7 @@ def _get_passed_records( except json.JSONDecodeError: logger.warning("Skipping malformed features.jsonl line") - return [r for r in source_records if r.get("source_id", r.get("id", "")) in passed_ids] + return [r for r in ingester_records if r.get("source_id", r.get("id", "")) in passed_ids] def _safe_srn(srn: str) -> str: diff --git a/server/osa/domain/ingest/handler/run_hooks.py b/server/osa/domain/ingest/handler/run_hooks.py index bc143c4..3b02136 100644 --- a/server/osa/domain/ingest/handler/run_hooks.py +++ b/server/osa/domain/ingest/handler/run_hooks.py @@ -39,11 +39,11 @@ async def handle(self, event: IngesterBatchReady) -> None: ConventionSRN.parse(ingest_run.convention_srn) ) - # Read records from batch source dir + # Read records from batch ingester dir ingest_dir = self.paths.data_dir / "ingests" / _safe_srn(event.ingest_run_srn) batch_dir = ingest_dir / "batches" / str(event.batch_index) - source_dir = batch_dir / "source" - records_file = source_dir / "records.jsonl" + ingester_dir = batch_dir / "ingester" + records_file = ingester_dir / "records.jsonl" records: list[dict] = [] if records_file.exists(): @@ 
-60,8 +60,8 @@ async def handle(self, event: IngesterBatchReady) -> None: if not records: logger.warning("No records in batch %d for %s", event.batch_index, event.ingest_run_srn) - # Build files_dirs from source files - files_base = source_dir / "files" + # Build files_dirs from ingester files + files_base = ingester_dir / "files" files_dirs: dict[str, Path] = {} if files_base.exists(): for record in records: diff --git a/server/osa/domain/ingest/handler/run_ingester.py b/server/osa/domain/ingest/handler/run_ingester.py index 111dd21..2d44546 100644 --- a/server/osa/domain/ingest/handler/run_ingester.py +++ b/server/osa/domain/ingest/handler/run_ingester.py @@ -50,7 +50,7 @@ async def handle(self, event: IngestStarted) -> None: # Prepare scratch directory ingest_dir = self.paths.data_dir / "ingests" / self._safe_srn(event.ingest_run_srn) - batch_dir = ingest_dir / "batches" / str(batch_index) / "source" + batch_dir = ingest_dir / "batches" / str(batch_index) / "ingester" batch_dir.mkdir(parents=True, exist_ok=True) # Load session state for continuation @@ -70,13 +70,13 @@ async def handle(self, event: IngestStarted) -> None: files_dir.mkdir(parents=True, exist_ok=True) output = await self.ingester_runner.run( - source=convention.ingester, + ingester=convention.ingester, inputs=inputs, files_dir=files_dir, work_dir=batch_dir, ) - # Write records.jsonl to batch source dir + # Write records.jsonl to batch ingester dir records_file = batch_dir / "records.jsonl" with records_file.open("w") as f: for record in output.records: diff --git a/server/osa/domain/shared/event.py b/server/osa/domain/shared/event.py index 7c3d565..b6ed59b 100644 --- a/server/osa/domain/shared/event.py +++ b/server/osa/domain/shared/event.py @@ -209,14 +209,11 @@ class EventHandler(Generic[E], metaclass=_EventHandlerMeta): __claim_timeout__: Seconds before claim considered stale (default: 300.0) Example (single event): - class TriggerInitialSourceRun(EventHandler[ServerStarted]): - _config: Config 
- _outbox: Outbox - - async def handle(self, event: ServerStarted) -> None: - for source in self._config.sources: - if source.initial_run and source.initial_run.enabled: - await self._outbox.append(SourceRequested(...)) + class HandleRecordPublished(EventHandler[RecordPublished]): + _service: IndexingService + + async def handle(self, event: RecordPublished) -> None: + await self._service.index(event.record_srn) Example (batch processing): class VectorIndexHandler(EventHandler[IndexRecord]): diff --git a/server/osa/domain/shared/port/ingester_runner.py b/server/osa/domain/shared/port/ingester_runner.py index 4a54742..c4bb7a0 100644 --- a/server/osa/domain/shared/port/ingester_runner.py +++ b/server/osa/domain/shared/port/ingester_runner.py @@ -41,7 +41,7 @@ class IngesterRunner(Protocol): async def run( self, - source: IngesterDefinition, + ingester: IngesterDefinition, inputs: IngesterInputs, files_dir: Path, work_dir: Path, diff --git a/server/osa/domain/validation/port/hook_runner.py b/server/osa/domain/validation/port/hook_runner.py index d1695f9..55ff484 100644 --- a/server/osa/domain/validation/port/hook_runner.py +++ b/server/osa/domain/validation/port/hook_runner.py @@ -16,7 +16,7 @@ class HookInputs: """Inputs to pass to a hook container. Uses the unified batch contract: records is a list of HookRecord - (1 for depositions, N for harvests). + (1 for depositions, N for ingests). files_dirs maps record ID → directory containing that record's files. """ diff --git a/server/osa/infrastructure/k8s/__init__.py b/server/osa/infrastructure/k8s/__init__.py index 2511c39..36bc8cf 100644 --- a/server/osa/infrastructure/k8s/__init__.py +++ b/server/osa/infrastructure/k8s/__init__.py @@ -1,6 +1,6 @@ """Kubernetes runner infrastructure. kubernetes-asyncio is an optional dependency. 
Modules that require it -(di.py, runner.py, source_runner.py, health.py) perform lazy imports +(di.py, runner.py, ingester_runner.py, health.py) perform lazy imports and raise ConfigurationError if the package is not installed. """ diff --git a/server/osa/infrastructure/k8s/ingester_runner.py b/server/osa/infrastructure/k8s/ingester_runner.py index 066920c..fa29ec1 100644 --- a/server/osa/infrastructure/k8s/ingester_runner.py +++ b/server/osa/infrastructure/k8s/ingester_runner.py @@ -54,7 +54,7 @@ def _s3_prefix(self, work_dir: Path, subdir: str) -> str: async def run( self, - source: IngesterDefinition, + ingester: IngesterDefinition, inputs: IngesterInputs, files_dir: Path, work_dir: Path, @@ -74,8 +74,8 @@ async def run( # Write input files to S3 (container reads them via PVC/S3 CSI) input_prefix = self._s3_prefix(work_dir, "input") - if inputs.config or source.config: - config = {**(source.config or {}), **(inputs.config or {})} + if inputs.config or ingester.config: + config = {**(ingester.config or {}), **(inputs.config or {})} await self._s3.put_object(f"{input_prefix}/config.json", json.dumps(config)) if inputs.session: @@ -84,7 +84,7 @@ async def run( return await self._run_job( batch_api, core_api, - source, + ingester, inputs, work_dir, files_dir, @@ -95,21 +95,21 @@ async def _run_job( self, batch_api: BatchV1Api, core_api: CoreV1Api, - source: IngesterDefinition, + ingester: IngesterDefinition, inputs: IngesterInputs, work_dir: Path, files_dir: Path, *, convention_srn: ConventionSRN | None = None, ) -> IngesterOutput: - """Core Job lifecycle for source execution.""" + """Core Job lifecycle for ingester execution.""" namespace = self._config.namespace job_name_to_watch = None try: # Check for existing Jobs existing = await self._check_existing_job( - batch_api, namespace, convention_srn, source.digest + batch_api, namespace, convention_srn, ingester.digest ) if existing == "succeeded": @@ -119,7 +119,7 @@ async def _run_job( job_name_to_watch = 
existing.split(":", 1)[1] else: spec = self._build_job_spec( - source, + ingester, work_dir=work_dir, files_dir=files_dir, inputs=inputs, @@ -129,11 +129,11 @@ async def _run_job( await batch_api.create_namespaced_job(namespace, spec) logger.info( - "Created K8s source Job", + "Created K8s ingester Job", extra={ "job_name": job_name_to_watch, "namespace": namespace, - "image": f"{source.image}@{source.digest}", + "image": f"{ingester.image}@{ingester.digest}", }, ) @@ -145,7 +145,7 @@ async def _run_job( batch_api, job_name_to_watch, namespace, - timeout_seconds=source.limits.timeout_seconds + 30, + timeout_seconds=ingester.limits.timeout_seconds + 30, ) if result == "succeeded": @@ -161,7 +161,7 @@ async def _run_job( return output # Failed — diagnose and raise - await self._diagnose_and_raise(core_api, job_name_to_watch, namespace, source, result) + await self._diagnose_and_raise(core_api, job_name_to_watch, namespace, ingester, result) # unreachable but satisfies type checker raise ExternalServiceError("Source failed") @@ -187,7 +187,7 @@ async def _check_existing_job( convention_srn: ConventionSRN | None, digest: str = "", ) -> str | None: - label_parts = ["osa.io/role=source"] + label_parts = ["osa.io/role=ingester"] if convention_srn is not None: label_parts.append(f"osa.io/convention={label_value(convention_srn)}") if digest: @@ -208,7 +208,7 @@ async def _check_existing_job( def _build_job_spec( self, - source: IngesterDefinition, + ingester: IngesterDefinition, *, work_dir: Path, files_dir: Path, @@ -234,15 +234,15 @@ def _build_job_spec( V1VolumeMount, ) - name = job_name("source", "src", str(convention_srn) if convention_srn else "unknown") + name = job_name("ingester", "ing", str(convention_srn) if convention_srn else "unknown") relative_work = self._relative_path(work_dir) input_subpath = f"{relative_work}/input" output_subpath = f"{relative_work}/output" relative_files = self._relative_path(files_dir) labels: dict[str, str] = { - "osa.io/role": 
"source", - "osa.io/digest": sanitize_label(source.digest), + "osa.io/role": "ingester", + "osa.io/digest": sanitize_label(ingester.digest), } if convention_srn is not None: labels["osa.io/convention"] = label_value(convention_srn) @@ -278,13 +278,13 @@ def _build_job_spec( ] container = V1Container( - name="source", - image=f"{source.image}@{source.digest}", + name="ingester", + image=f"{ingester.image}@{ingester.digest}", env=env, resources=V1ResourceRequirements( limits={ - "memory": to_k8s_quantity(source.limits.memory), - "cpu": source.limits.cpu, + "memory": to_k8s_quantity(ingester.limits.memory), + "cpu": ingester.limits.cpu, }, ), security_context=V1SecurityContext( @@ -320,7 +320,7 @@ def _build_job_spec( metadata=V1ObjectMeta(name=name, namespace=self._config.namespace, labels=labels), spec=V1JobSpec( backoff_limit=0, - active_deadline_seconds=SCHEDULING_TIMEOUT + source.limits.timeout_seconds, + active_deadline_seconds=SCHEDULING_TIMEOUT + ingester.limits.timeout_seconds, ttl_seconds_after_finished=self._config.job_ttl_seconds, template=V1PodTemplateSpec( metadata=V1ObjectMeta(labels=labels), @@ -420,12 +420,14 @@ async def _diagnose_and_raise( core_api: CoreV1Api, job_name: str, namespace: str, - source: IngesterDefinition, + ingester: IngesterDefinition, failure_info: str, ) -> None: """Determine failure reason and raise appropriate error.""" if "DeadlineExceeded" in failure_info: - raise ExternalServiceError(f"Source timed out after {source.limits.timeout_seconds}s") + raise ExternalServiceError( + f"Ingester timed out after {ingester.limits.timeout_seconds}s" + ) try: label_selector = f"job-name={job_name}" @@ -457,4 +459,4 @@ async def _cleanup_job(self, batch_api: BatchV1Api, job_name: str, namespace: st except Exception as exc: if getattr(exc, "status", None) == 404: return - logger.warning("Failed to clean up K8s source Job", extra={"job_name": job_name}) + logger.warning("Failed to clean up K8s ingester Job", extra={"job_name": job_name}) diff 
--git a/server/osa/infrastructure/oci/ingester_runner.py b/server/osa/infrastructure/oci/ingester_runner.py index ebdb35f..bb926ee 100644 --- a/server/osa/infrastructure/oci/ingester_runner.py +++ b/server/osa/infrastructure/oci/ingester_runner.py @@ -21,12 +21,12 @@ class OciIngesterRunner(IngesterRunner): - """Executes sources in OCI containers via aiodocker. + """Executes ingesters in OCI containers via aiodocker. Key differences from OciHookRunner: - - Network access enabled (sources call upstream APIs) + - Network access enabled (ingesters call upstream APIs) - Three bind mounts: $OSA_IN (ro), $OSA_OUT (rw), $OSA_FILES (rw) - - No ReadonlyRootfs (sources may need writable FS for pip cache, etc.) + - No ReadonlyRootfs (ingesters may need writable FS for pip cache, etc.) - Higher default limits (3600s timeout, 4g memory) - Output is records.jsonl (line-delimited JSON), not features.json @@ -47,12 +47,12 @@ def __init__( async def run( self, - source: IngesterDefinition, + ingester: IngesterDefinition, inputs: IngesterInputs, files_dir: Path, work_dir: Path, ) -> IngesterOutput: - timeout = source.limits.timeout_seconds + timeout = ingester.limits.timeout_seconds from shutil import rmtree @@ -68,8 +68,8 @@ def _force_remove(func, path, exc): container_output = work_dir / "output" container_output.mkdir(parents=True, exist_ok=True) try: - if inputs.config or source.config: - config = {**(source.config or {}), **(inputs.config or {})} + if inputs.config or ingester.config: + config = {**(ingester.config or {}), **(inputs.config or {})} (staging_dir / "config.json").write_text(json.dumps(config)) if inputs.session: @@ -80,9 +80,9 @@ def _force_remove(func, path, exc): try: async def _resolve_and_run(): - image_ref = await self._resolve_image(source.image, source.digest) + image_ref = await self._resolve_image(ingester.image, ingester.digest) return await self._run_container( - image_ref, staging_dir, files_dir, container_output, source, inputs + image_ref, 
staging_dir, files_dir, container_output, ingester, inputs ) result = await asyncio.wait_for( @@ -93,12 +93,12 @@ async def _resolve_and_run(): except asyncio.TimeoutError: duration = time.monotonic() - start_time logfire.error( - "Source timed out", - image=source.image, + "Ingester timed out", + image=ingester.image, timeout=timeout, duration=duration, ) - raise ExternalServiceError(f"Source timed out after {timeout}s") + raise ExternalServiceError(f"Ingester timed out after {timeout}s") finally: rmtree(staging_dir, onexc=_force_remove) @@ -108,7 +108,7 @@ async def _run_container( staging_dir: Path, files_dir: Path, output_dir: Path, - source: IngesterDefinition, + ingester: IngesterDefinition, inputs: IngesterInputs, ) -> IngesterOutput: container = None @@ -137,11 +137,11 @@ async def _run_container( "Env": env, "HostConfig": { "Binds": binds, - "Memory": parse_memory(source.limits.memory), - "MemorySwap": parse_memory(source.limits.memory), - "NanoCpus": int(float(source.limits.cpu) * 1e9), + "Memory": parse_memory(ingester.limits.memory), + "MemorySwap": parse_memory(ingester.limits.memory), + "NanoCpus": int(float(ingester.limits.cpu) * 1e9), # No NetworkMode: "none" — sources need network access - # No ReadonlyRootfs — sources may need writable FS + # No ReadonlyRootfs — ingesters may need writable FS "CapDrop": ["ALL"], "SecurityOpt": ["no-new-privileges"], "PidsLimit": 256, @@ -159,19 +159,21 @@ async def _run_container( oom_killed = inspect_data.get("State", {}).get("OOMKilled", False) if oom_killed: - raise ExternalServiceError("Source killed by OOM") + raise ExternalServiceError("Ingester killed by OOM") if exit_code != 0: logs = await container.log(stdout=True, stderr=True) logs_str = "".join(logs) if logs else "" - raise ExternalServiceError(f"Source exited with code {exit_code}: {logs_str[:500]}") + raise ExternalServiceError( + f"Ingester exited with code {exit_code}: {logs_str[:500]}" + ) records = parse_records_file(output_dir) session = 
parse_session_file(output_dir) return IngesterOutput(records=records, session=session, files_dir=files_dir) except aiodocker.DockerError as e: - logfire.error("Docker error running source", error=str(e)) + logfire.error("Docker error running ingester", error=str(e)) raise ExternalServiceError(f"Docker error: {e}") from e finally: if container is not None: @@ -210,6 +212,6 @@ async def _resolve_image(self, image: str, digest: str) -> str: pass # Pull from registry as last resort - logfire.info("Pulling source image", image=image) + logfire.info("Pulling ingester image", image=image) await self._docker.images.pull(image) return image diff --git a/server/osa/infrastructure/persistence/adapter/storage.py b/server/osa/infrastructure/persistence/adapter/storage.py index c08711a..e45b869 100644 --- a/server/osa/infrastructure/persistence/adapter/storage.py +++ b/server/osa/infrastructure/persistence/adapter/storage.py @@ -20,7 +20,7 @@ class FilesystemStorageAdapter(FileStoragePort): """Local filesystem adapter satisfying all domain storage ports. - Implements FileStoragePort (deposition files), SourceStoragePort, + Implements FileStoragePort (deposition files), HookStoragePort, and FeatureStoragePort via structural subtyping. 
""" diff --git a/server/osa/infrastructure/runner_utils.py b/server/osa/infrastructure/runner_utils.py index 336a93d..bf58525 100644 --- a/server/osa/infrastructure/runner_utils.py +++ b/server/osa/infrastructure/runner_utils.py @@ -123,7 +123,7 @@ def relative_path(path: Path, data_mount_path: str) -> str: def parse_records_file(output_dir: Path) -> list[dict[str, Any]]: - """Parse records.jsonl from source output directory.""" + """Parse records.jsonl from ingester output directory.""" import logfire records: list[dict[str, Any]] = [] diff --git a/server/osa/infrastructure/s3/storage.py b/server/osa/infrastructure/s3/storage.py index 5b0dcf0..5b340fa 100644 --- a/server/osa/infrastructure/s3/storage.py +++ b/server/osa/infrastructure/s3/storage.py @@ -26,7 +26,7 @@ class S3StorageAdapter(FileStoragePort): """S3-backed adapter satisfying all domain storage ports. - Implements FileStoragePort, SourceStoragePort, HookStoragePort, + Implements FileStoragePort, HookStoragePort, and FeatureStoragePort via structural subtyping — same as FilesystemStorageAdapter but using S3 API instead of POSIX calls. 
@@ -116,7 +116,7 @@ async def delete_files_for_deposition( prefix = f"{self._dep_prefix(deposition_id)}/" await self._s3.delete_objects(prefix) - # ── SourceStoragePort ──────────────────────────────────────────── + # ── Ingester storage ────────────────────────────────────────────── def get_source_staging_dir(self, convention_srn: ConventionSRN, run_id: str) -> Path: """Return path for PVC subpath computation (no I/O).""" @@ -144,7 +144,7 @@ async def move_source_files_to_deposition( source_id: str, deposition_srn: DepositionSRN, ) -> None: - """S3 server-side copy from source staging to deposition files prefix.""" + """S3 server-side copy from ingester staging to deposition files prefix.""" source_prefix = f"{relative_path(staging_dir, self._data_mount_path)}/{source_id}/" dest_prefix = self._files_prefix(deposition_srn) diff --git a/server/tests/unit/infrastructure/k8s/test_k8s_ingester_runner.py b/server/tests/unit/infrastructure/k8s/test_k8s_ingester_runner.py index 6df39d4..398905e 100644 --- a/server/tests/unit/infrastructure/k8s/test_k8s_ingester_runner.py +++ b/server/tests/unit/infrastructure/k8s/test_k8s_ingester_runner.py @@ -1,4 +1,4 @@ -"""Unit tests for K8sIngesterRunner — Job spec differences, source lifecycle.""" +"""Unit tests for K8sIngesterRunner — Job spec differences, ingester lifecycle.""" from pathlib import Path from typing import Any @@ -16,8 +16,8 @@ _CONV_SRN = ConventionSRN.parse("urn:osa:localhost:conv:test@1.0.0") -def _make_source( - image: str = "ghcr.io/example/source:v1", +def _make_ingester( + image: str = "ghcr.io/example/ingester:v1", digest: str = "sha256:abc123", timeout: int = 3600, memory: str = "4g", @@ -68,9 +68,9 @@ class TestSourceJobSpec: def test_network_enabled(self): """Source Jobs have normal DNS policy (network access).""" runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), 
files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), ) @@ -80,9 +80,9 @@ def test_network_enabled(self): def test_writable_rootfs(self): """Source containers do not have readOnlyRootFilesystem.""" runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), ) @@ -92,23 +92,23 @@ def test_writable_rootfs(self): def test_higher_defaults(self): """Source Jobs use higher defaults (3600s, 4g).""" runner = _make_runner() - source = _make_source(timeout=3600, memory="4g") + ingester = _make_ingester(timeout=3600, memory="4g") spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), ) resources = spec.spec.template.spec.containers[0].resources assert resources.limits["memory"] == "4Gi" - # activeDeadlineSeconds = scheduling_timeout + source timeout + # activeDeadlineSeconds = scheduling_timeout + ingester timeout assert spec.spec.active_deadline_seconds == 120 + 3600 def test_three_volume_mounts(self): """Source Jobs have input, output, and files mounts.""" runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), ) @@ -121,9 +121,9 @@ def test_three_volume_mounts(self): def test_files_mount_writable(self): """Source files mount is writable.""" runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), ) @@ -133,9 +133,9 @@ def 
test_files_mount_writable(self): def test_env_vars(self): runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), inputs=IngesterInputs(convention_srn=_CONV_SRN, limit=100, offset=50), @@ -152,10 +152,10 @@ def test_since_env_var(self): from datetime import datetime, UTC runner = _make_runner() - source = _make_source() + ingester = _make_ingester() since = datetime(2026, 1, 1, tzinfo=UTC) spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), inputs=IngesterInputs(convention_srn=_CONV_SRN, since=since), @@ -164,35 +164,35 @@ def test_since_env_var(self): env_dict = {e.name: e.value for e in env} assert "OSA_SINCE" in env_dict - def test_source_role_label(self): + def test_ingester_role_label(self): runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), ) labels = spec.spec.template.metadata.labels - assert labels["osa.io/role"] == "source" + assert labels["osa.io/role"] == "ingester" def test_human_readable_name(self): runner = _make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:conv1@1.0.0"), ) name = spec.metadata.name - assert name.startswith("osa-source-") + assert name.startswith("osa-ingester-") assert len(name) <= 63 def test_convention_srn_in_labels(self): runner = 
_make_runner() - source = _make_source() + ingester = _make_ingester() spec = runner._build_job_spec( - source, + ingester, work_dir=Path("/data/sources/localhost_conv1/staging/run1"), files_dir=Path("/data/sources/localhost_conv1/staging/run1/files"), convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:conv1@1.0.0"), @@ -239,7 +239,7 @@ async def test_successful_run_with_records(self, tmp_path: Path): completed_job.status.failed = None batch_api.read_namespaced_job.return_value = completed_job - source = _make_source() + ingester = _make_ingester() work_dir = tmp_path / "sources" / "localhost_conv1" / "staging" / "run1" files_dir = work_dir / "files" @@ -262,7 +262,7 @@ async def s3_get(key: str) -> bytes: result = await runner._run_job( batch_api, core_api, - source, + ingester, inputs, work_dir, files_dir, @@ -304,7 +304,7 @@ async def test_timeout_raises_external_service_error(self, tmp_path: Path): failed_job.status.failed = 1 batch_api.read_namespaced_job.return_value = failed_job - source = _make_source() + ingester = _make_ingester() work_dir = tmp_path / "sources" / "localhost_conv1" / "staging" / "run1" work_dir.mkdir(parents=True) files_dir = work_dir / "files" @@ -315,7 +315,7 @@ async def test_timeout_raises_external_service_error(self, tmp_path: Path): await runner._run_job( batch_api, core_api, - source, + ingester, inputs, work_dir, files_dir, @@ -364,7 +364,7 @@ async def test_oom_raises_external_service_error(self, tmp_path: Path): core_api.list_namespaced_pod.side_effect = [pod_list, oom_pod_list] - source = _make_source() + ingester = _make_ingester() work_dir = tmp_path / "sources" / "localhost_conv1" / "staging" / "run1" work_dir.mkdir(parents=True) files_dir = work_dir / "files" @@ -375,7 +375,7 @@ async def test_oom_raises_external_service_error(self, tmp_path: Path): await runner._run_job( batch_api, core_api, - source, + ingester, inputs, work_dir, files_dir, @@ -421,7 +421,7 @@ async def 
test_run_uses_convention_srn_from_inputs(self, tmp_path: Path): completed_job.status.failed = None batch_api.read_namespaced_job.return_value = completed_job - source = _make_source() + ingester = _make_ingester() work_dir = tmp_path / "sources" / "run1" output_dir = work_dir / "output" output_dir.mkdir(parents=True) @@ -436,7 +436,7 @@ async def test_run_uses_convention_srn_from_inputs(self, tmp_path: Path): patch("kubernetes_asyncio.client.BatchV1Api", return_value=batch_api), patch("kubernetes_asyncio.client.CoreV1Api", return_value=core_api), ): - await runner.run(source, inputs, files_dir, work_dir) + await runner.run(ingester, inputs, files_dir, work_dir) # Verify convention_srn from inputs ends up in the Job labels call_args = batch_api.create_namespaced_job.call_args From 9eaa3b276f8b5dd344ca2812531e719be1ac271c Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Wed, 25 Mar 2026 22:57:16 +0000 Subject: [PATCH 3/9] refactor: consolidate storage path logic into StorageLayout class Replace hardcoded path construction with centralized StorageLayout to eliminate duplication and provide a single source of truth for directory structure across ingest handlers. --- .../feature/handler/insert_batch_features.py | 39 ++++++--- .../domain/ingest/handler/publish_batch.py | 80 +++++++++++-------- server/osa/domain/ingest/handler/run_hooks.py | 21 +++-- .../osa/domain/ingest/handler/run_ingester.py | 15 ++-- server/osa/infrastructure/ingest/di.py | 8 +- server/osa/infrastructure/storage/__init__.py | 0 server/osa/infrastructure/storage/layout.py | 48 +++++++++++ 7 files changed, 143 insertions(+), 68 deletions(-) create mode 100644 server/osa/infrastructure/storage/__init__.py create mode 100644 server/osa/infrastructure/storage/layout.py diff --git a/server/osa/domain/feature/handler/insert_batch_features.py b/server/osa/domain/feature/handler/insert_batch_features.py index e8785ea..d459de8 100644 --- a/server/osa/domain/feature/handler/insert_batch_features.py +++
b/server/osa/domain/feature/handler/insert_batch_features.py @@ -6,6 +6,7 @@ from osa.domain.feature.service.feature import FeatureService from osa.domain.ingest.event.events import IngestBatchPublished from osa.domain.shared.event import EventHandler +from osa.infrastructure.storage.layout import StorageLayout logger = logging.getLogger(__name__) @@ -14,30 +15,44 @@ class InsertBatchFeatures(EventHandler[IngestBatchPublished]): """Reads hook outputs for an ingest batch and inserts features in bulk. Handles IngestBatchPublished (batch-level event) rather than - per-record RecordPublished. Shares core insertion logic with - InsertRecordFeatures via FeatureService. + per-record RecordPublished. Uses read_batch_outcomes to parse + the JSONL output format (not the single-record features.json). """ feature_service: FeatureService feature_storage: FeatureStoragePort + layout: StorageLayout async def handle(self, event: IngestBatchPublished) -> None: if not event.expected_features or not event.published_srns: return - hook_output_root = self.feature_storage.get_hook_output_root("ingest", event.ingest_run_srn) + batch_output_dir = str( + self.layout.ingest_batch_dir(event.ingest_run_srn, event.batch_index) + ) + + total_inserted = 0 + + for hook_name in event.expected_features: + # Read JSONL outcomes for this hook + outcomes = await self.feature_storage.read_batch_outcomes(batch_output_dir, hook_name) + + # Insert features for each published record that passed this hook + for record_id, outcome in outcomes.items(): + if outcome.status != "passed" or not outcome.features: + continue - # Read batch outcomes for each hook and insert features - for record_srn in event.published_srns: - await self.feature_service.insert_features_for_record( - hook_output_dir=hook_output_root, - record_srn=record_srn, - expected_features=event.expected_features, - ) + count = await self.feature_service.insert_features( + hook_name=hook_name, + record_srn=record_id, + rows=outcome.features, + ) 
+ total_inserted += count logger.info( - "Inserted features for %d records in batch %d of %s", - len(event.published_srns), + "Inserted %d feature rows for batch %d of %s (%d hooks)", + total_inserted, event.batch_index, event.ingest_run_srn, + len(event.expected_features), ) diff --git a/server/osa/domain/ingest/handler/publish_batch.py b/server/osa/domain/ingest/handler/publish_batch.py index bf5d987..1f73def 100644 --- a/server/osa/domain/ingest/handler/publish_batch.py +++ b/server/osa/domain/ingest/handler/publish_batch.py @@ -22,7 +22,7 @@ from osa.domain.shared.model.srn import ConventionSRN from osa.domain.shared.outbox import Outbox from osa.domain.feature.port.storage import FeatureStoragePort -from osa.util.paths import OSAPaths +from osa.infrastructure.storage.layout import StorageLayout logger = logging.getLogger(__name__) @@ -35,7 +35,7 @@ class PublishBatch(EventHandler[HookBatchCompleted]): record_service: RecordService feature_storage: FeatureStoragePort outbox: Outbox - paths: OSAPaths + layout: StorageLayout async def handle(self, event: HookBatchCompleted) -> None: ingest_run = await self.ingest_repo.get(event.ingest_run_srn) @@ -47,9 +47,11 @@ async def handle(self, event: HookBatchCompleted) -> None: ) # Read ingester records from batch dir - ingest_dir = self.paths.data_dir / "ingests" / _safe_srn(event.ingest_run_srn) - batch_dir = ingest_dir / "batches" / str(event.batch_index) - ingester_records = _read_ingester_records(batch_dir / "ingester" / "records.jsonl") + batch_dir = self.layout.ingest_batch_dir(event.ingest_run_srn, event.batch_index) + ingester_dir = self.layout.ingest_batch_ingester_dir( + event.ingest_run_srn, event.batch_index + ) + ingester_records = _read_ingester_records(ingester_dir / "records.jsonl") # Read hook outcomes for all hooks expected_features = [h.name for h in convention.hooks] @@ -62,6 +64,7 @@ async def handle(self, event: HookBatchCompleted) -> None: feature_storage=self.feature_storage, ) + published_count 
= 0 if not passed_records: logger.info("No passing records in batch %d", event.batch_index) else: @@ -82,16 +85,18 @@ async def handle(self, event: HookBatchCompleted) -> None: ) ) - # Bulk publish + # Bulk publish — ON CONFLICT DO NOTHING skips duplicates, + # so published may be shorter than drafts published = await self.record_service.bulk_publish(drafts) published_srns = [str(r.srn) for r in published] published_count = len(published) logger.info( - "Published %d records from batch %d of %s", + "Published %d records from batch %d of %s (%d duplicates skipped)", published_count, event.batch_index, event.ingest_run_srn, + len(drafts) - published_count, ) # Emit IngestBatchPublished for feature insertion @@ -108,11 +113,11 @@ async def handle(self, event: HookBatchCompleted) -> None: ) ) - # Update counters atomically — use published_count or 0 - count = len(passed_records) if passed_records else 0 + # Update counters atomically — use actual published_count (not passed_records) + # to avoid over-counting when ON CONFLICT DO NOTHING skips duplicates updated = await self.ingest_repo.increment_completed( event.ingest_run_srn, - published_count=count, + published_count=published_count, ) # Check completion condition @@ -156,32 +161,41 @@ def _get_passed_records( hooks: list[str], feature_storage: FeatureStoragePort, ) -> list[dict]: - """Determine which records passed all hooks by reading features.jsonl.""" + """Determine which records passed ALL hooks by intersecting features.jsonl across hooks. + + Each hook processes the full batch independently. A record must appear in + every hook's features.jsonl to be considered passed. Records rejected or + errored by any hook are excluded. 
+ """ if not hooks: return ingester_records - # For simplicity, read the last hook's features.jsonl to get passed IDs - # (hooks run sequentially, so the last hook's output has the final set) - last_hook = hooks[-1] - features_file = batch_dir / "hooks" / last_hook / "output" / "features.jsonl" - if not features_file.exists(): - return [] + passed_ids: set[str] | None = None + + for hook_name in hooks: + features_file = batch_dir / "hooks" / hook_name / "output" / "features.jsonl" + if not features_file.exists(): + return [] # If any hook produced no features file, nothing passed + + hook_passed: set[str] = set() + for line in features_file.open(): + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + record_id = data.get("id") + if record_id: + hook_passed.add(record_id) + except json.JSONDecodeError: + logger.warning("Skipping malformed features.jsonl line in hook %s", hook_name) + + if passed_ids is None: + passed_ids = hook_passed + else: + passed_ids &= hook_passed - passed_ids: set[str] = set() - for line in features_file.open(): - line = line.strip() - if not line: - continue - try: - data = json.loads(line) - record_id = data.get("id") - if record_id: - passed_ids.add(record_id) - except json.JSONDecodeError: - logger.warning("Skipping malformed features.jsonl line") + if not passed_ids: + return [] return [r for r in ingester_records if r.get("source_id", r.get("id", "")) in passed_ids] - - -def _safe_srn(srn: str) -> str: - return srn.replace(":", "_").replace("@", "_") diff --git a/server/osa/domain/ingest/handler/run_hooks.py b/server/osa/domain/ingest/handler/run_hooks.py index 3b02136..eda1e1b 100644 --- a/server/osa/domain/ingest/handler/run_hooks.py +++ b/server/osa/domain/ingest/handler/run_hooks.py @@ -14,7 +14,7 @@ from osa.domain.shared.outbox import Outbox from osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.port.hook_runner import HookInputs, HookRunner -from osa.util.paths 
import OSAPaths +from osa.infrastructure.storage.layout import StorageLayout logger = logging.getLogger(__name__) @@ -28,7 +28,7 @@ class RunHooks(EventHandler[IngesterBatchReady]): convention_service: ConventionService hook_runner: HookRunner outbox: Outbox - paths: OSAPaths + layout: StorageLayout async def handle(self, event: IngesterBatchReady) -> None: ingest_run = await self.ingest_repo.get(event.ingest_run_srn) @@ -40,9 +40,9 @@ async def handle(self, event: IngesterBatchReady) -> None: ) # Read records from batch ingester dir - ingest_dir = self.paths.data_dir / "ingests" / _safe_srn(event.ingest_run_srn) - batch_dir = ingest_dir / "batches" / str(event.batch_index) - ingester_dir = batch_dir / "ingester" + ingester_dir = self.layout.ingest_batch_ingester_dir( + event.ingest_run_srn, event.batch_index + ) records_file = ingester_dir / "records.jsonl" records: list[dict] = [] @@ -72,7 +72,9 @@ async def handle(self, event: IngesterBatchReady) -> None: # Run each hook sequentially for hook in convention.hooks: - hook_output_dir = batch_dir / "hooks" / hook.name + hook_output_dir = self.layout.ingest_batch_hook_dir( + event.ingest_run_srn, event.batch_index, hook.name + ) hook_output_dir.mkdir(parents=True, exist_ok=True) inputs = HookInputs( @@ -83,7 +85,7 @@ async def handle(self, event: IngesterBatchReady) -> None: ) for r in records ], - run_id=f"{_safe_srn(event.ingest_run_srn)}_batch{event.batch_index}", + run_id=f"{event.ingest_run_srn}_batch{event.batch_index}", files_dirs=files_dirs, config=None, ) @@ -105,8 +107,3 @@ async def handle(self, event: IngesterBatchReady) -> None: event.ingest_run_srn, len(records), ) - - -def _safe_srn(srn: str) -> str: - """Convert SRN to filesystem-safe string.""" - return srn.replace(":", "_").replace("@", "_") diff --git a/server/osa/domain/ingest/handler/run_ingester.py b/server/osa/domain/ingest/handler/run_ingester.py index 2d44546..0299edb 100644 --- a/server/osa/domain/ingest/handler/run_ingester.py +++ 
b/server/osa/domain/ingest/handler/run_ingester.py @@ -13,7 +13,7 @@ from osa.domain.shared.model.srn import ConventionSRN from osa.domain.shared.outbox import Outbox from osa.domain.shared.port.ingester_runner import IngesterInputs, IngesterRunner -from osa.util.paths import OSAPaths +from osa.infrastructure.storage.layout import StorageLayout logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ class RunIngester(EventHandler[IngestStarted]): convention_service: ConventionService ingester_runner: IngesterRunner outbox: Outbox - paths: OSAPaths + layout: StorageLayout async def handle(self, event: IngestStarted) -> None: ingest_run = await self.ingest_repo.get(event.ingest_run_srn) @@ -49,12 +49,11 @@ async def handle(self, event: IngestStarted) -> None: batch_index = ingest_run.batches_sourced # Prepare scratch directory - ingest_dir = self.paths.data_dir / "ingests" / self._safe_srn(event.ingest_run_srn) - batch_dir = ingest_dir / "batches" / str(batch_index) / "ingester" + batch_dir = self.layout.ingest_batch_ingester_dir(event.ingest_run_srn, batch_index) batch_dir.mkdir(parents=True, exist_ok=True) # Load session state for continuation - session_file = ingest_dir / "session.json" + session_file = self.layout.ingest_session_file(event.ingest_run_srn) session = None if session_file.exists(): session = json.loads(session_file.read_text()) @@ -84,6 +83,7 @@ async def handle(self, event: IngestStarted) -> None: # Save session for continuation if output.session: + session_file.parent.mkdir(parents=True, exist_ok=True) session_file.write_text(json.dumps(output.session)) has_more = output.session is not None and len(output.records) > 0 @@ -122,8 +122,3 @@ async def handle(self, event: IngestStarted) -> None: batch_size=ingest_run.batch_size, ) ) - - @staticmethod - def _safe_srn(srn: str) -> str: - """Convert SRN to filesystem-safe string.""" - return srn.replace(":", "_").replace("@", "_") diff --git a/server/osa/infrastructure/ingest/di.py 
b/server/osa/infrastructure/ingest/di.py index 5aef81b..045ac82 100644 --- a/server/osa/infrastructure/ingest/di.py +++ b/server/osa/infrastructure/ingest/di.py @@ -10,12 +10,18 @@ from osa.domain.shared.model.srn import Domain from osa.domain.shared.outbox import Outbox from osa.infrastructure.persistence.repository.ingest import PostgresIngestRunRepository +from osa.infrastructure.storage.layout import StorageLayout from osa.util.di.base import Provider from osa.util.di.scope import Scope +from osa.util.paths import OSAPaths class IngestProvider(Provider): - """Provides IngestService, IngestRunRepository, and StartIngestHandler.""" + """Provides IngestService, IngestRunRepository, StorageLayout, and StartIngestHandler.""" + + @provide(scope=Scope.APP) + def get_storage_layout(self, paths: OSAPaths) -> StorageLayout: + return StorageLayout(paths.data_dir) @provide(scope=Scope.UOW) def get_ingest_repo(self, session: AsyncSession) -> IngestRunRepository: diff --git a/server/osa/infrastructure/storage/__init__.py b/server/osa/infrastructure/storage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/osa/infrastructure/storage/layout.py b/server/osa/infrastructure/storage/layout.py new file mode 100644 index 0000000..a00e734 --- /dev/null +++ b/server/osa/infrastructure/storage/layout.py @@ -0,0 +1,48 @@ +"""Storage layout — single source of truth for directory structure. + +Composable path methods that define where data lives on disk/S3. +Storage adapters and runners consume this instead of hardcoding paths. + +See #106 for the full consolidation plan. Currently covers ingest paths only; +deposition paths will be migrated here in a follow-up. +""" + +from pathlib import Path + + +def _safe_srn(srn: str) -> str: + """Convert an SRN to a filesystem-safe string.""" + return srn.replace(":", "_").replace("@", "_") + + +class StorageLayout: + """Computes storage paths relative to a data root. + + All methods return Path objects. 
Storage adapters prefix with their + own root (filesystem base_path or S3 key prefix). + """ + + def __init__(self, data_dir: Path) -> None: + self._data_dir = data_dir + + # ── Ingest paths ───────────────────────────────────────────────── + + def ingest_run_dir(self, ingest_run_srn: str) -> Path: + """Root directory for an ingest run.""" + return self._data_dir / "ingests" / _safe_srn(ingest_run_srn) + + def ingest_batch_dir(self, ingest_run_srn: str, batch_index: int) -> Path: + """Directory for a specific batch within an ingest run.""" + return self.ingest_run_dir(ingest_run_srn) / "batches" / str(batch_index) + + def ingest_batch_ingester_dir(self, ingest_run_srn: str, batch_index: int) -> Path: + """Ingester output directory (records.jsonl, files/) for a batch.""" + return self.ingest_batch_dir(ingest_run_srn, batch_index) / "ingester" + + def ingest_batch_hook_dir(self, ingest_run_srn: str, batch_index: int, hook_name: str) -> Path: + """Hook output directory for a batch.""" + return self.ingest_batch_dir(ingest_run_srn, batch_index) / "hooks" / hook_name + + def ingest_session_file(self, ingest_run_srn: str) -> Path: + """Session state file for ingester continuation.""" + return self.ingest_run_dir(ingest_run_srn) / "session.json" From 080e431a99a327e2f2cd1065b51cbd8a2c55754a Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 26 Mar 2026 13:13:03 +0000 Subject: [PATCH 4/9] feat: add configurable record limit and concurrency for ingest runs - Add record_limit column to ingest_runs table for limiting total records - Add OSA_BASE_URL environment variable support in Docker configs - Implement concurrency control for event handlers via __concurrency__ - Fix feature insertion to use record SRN instead of upstream ID - Increase ingester memory limit from 512m to 1g - Update config validation to require OSA_BASE_URL for localhost --- deploy/docker-compose.dev.yml | 1 + deploy/docker-compose.yml | 1 + .../migrations/versions/add_harvest_runs.py | 1 + 
server/osa/config.py | 26 +++++++++---- .../feature/handler/insert_batch_features.py | 17 ++++++-- .../osa/domain/ingest/command/start_ingest.py | 4 ++ server/osa/domain/ingest/event/events.py | 1 + .../domain/ingest/handler/publish_batch.py | 6 +++ server/osa/domain/ingest/handler/run_hooks.py | 3 +- .../osa/domain/ingest/handler/run_ingester.py | 22 ++++++++++- server/osa/domain/ingest/model/ingest_run.py | 1 + server/osa/domain/ingest/service/ingest.py | 2 + server/osa/domain/shared/event.py | 1 + server/osa/domain/shared/model/source.py | 2 +- server/osa/infrastructure/event/worker.py | 39 ++++++++++++++----- .../persistence/repository/ingest.py | 2 + .../persistence/repository/record.py | 8 ++-- .../osa/infrastructure/persistence/tables.py | 1 + .../unit/application/test_app_factory.py | 1 + server/tests/unit/config/test_config.py | 1 + server/tests/unit/config/test_paths_config.py | 7 ++++ 21 files changed, 120 insertions(+), 27 deletions(-) diff --git a/deploy/docker-compose.dev.yml b/deploy/docker-compose.dev.yml index ce5db40..d5ff008 100644 --- a/deploy/docker-compose.dev.yml +++ b/deploy/docker-compose.dev.yml @@ -20,6 +20,7 @@ services: OSA_DATABASE__URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-osa}@db:5432/${POSTGRES_DB:-osa} OSA_DATA_DIR: /data OSA_CONFIG_FILE: /app/osa.yaml + OSA_BASE_URL: http://localhost:8000 OSA_LOGGING__LEVEL: ${LOG_LEVEL:-DEBUG} WATCHFILES_FORCE_POLLING: "true" entrypoint: [] diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 172633f..3be0e9a 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -21,6 +21,7 @@ services: environment: OSA_DATABASE__URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-osa}@db:5432/${POSTGRES_DB:-osa} OSA_DATA_DIR: /data + OSA_BASE_URL: ${OSA_BASE_URL:-http://localhost:8000} OSA_LOGGING__LEVEL: ${LOG_LEVEL:-INFO} OSA_AUTH__JWT__SECRET: ${JWT_SECRET:-change-me-in-production-must-be-32-chars-long} depends_on: 
diff --git a/server/migrations/versions/add_harvest_runs.py b/server/migrations/versions/add_harvest_runs.py index 9c3564e..0d2f61c 100644 --- a/server/migrations/versions/add_harvest_runs.py +++ b/server/migrations/versions/add_harvest_runs.py @@ -66,6 +66,7 @@ def upgrade() -> None: nullable=False, server_default=sa.text("1000"), ), + sa.Column("record_limit", sa.Integer(), nullable=True), sa.Column("started_at", sa.DateTime(timezone=True), nullable=False), sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), sa.CheckConstraint( diff --git a/server/osa/config.py b/server/osa/config.py index bf399c8..1cbea64 100644 --- a/server/osa/config.py +++ b/server/osa/config.py @@ -225,7 +225,8 @@ class Config(BaseSettings): name: str = "Open Science Archive" version: str = "0.0.1" description: str = "An open platform for depositing scientific data" - domain: str = "localhost" # Node domain for SRN construction + domain: str = "localhost" # Node identity for SRN construction (DNS name) + base_url: str = "" # Public URL where users reach the server (e.g. http://localhost:8000) # These are BaseModel, so env_nested_delimiter handles their env vars frontend: Frontend = Frontend() @@ -243,15 +244,26 @@ class Config(BaseSettings): "env_nested_delimiter": "__", # Allows OSA_DATABASE__URL override } - @property - def base_url(self) -> str: - """Public base URL derived from domain. HTTPS unless localhost.""" - scheme = "http" if self.domain == "localhost" else "https" - return f"{scheme}://{self.domain}" + @model_validator(mode="after") + def derive_base_url(self) -> Self: + """Derive base_url from domain if not explicitly set. + + For non-localhost domains, HTTPS on port 443 is assumed. + For localhost, base_url must be set explicitly (port matters). + """ + if self.base_url: + return self + if self.domain == "localhost": + raise ValueError( + "OSA_BASE_URL is required when domain is localhost " + "(e.g. 
OSA_BASE_URL=http://localhost:8000)" + ) + self.base_url = f"https://{self.domain}" + return self @model_validator(mode="after") def derive_frontend_url(self) -> Self: - """Derive frontend URL from domain if still the default localhost value.""" + """Derive frontend URL from base_url if still the default localhost value.""" if self.frontend.url == "http://localhost:3000": self.frontend = Frontend(url=self.base_url) return self diff --git a/server/osa/domain/feature/handler/insert_batch_features.py b/server/osa/domain/feature/handler/insert_batch_features.py index d459de8..a548405 100644 --- a/server/osa/domain/feature/handler/insert_batch_features.py +++ b/server/osa/domain/feature/handler/insert_batch_features.py @@ -37,14 +37,25 @@ async def handle(self, event: IngestBatchPublished) -> None: # Read JSONL outcomes for this hook outcomes = await self.feature_storage.read_batch_outcomes(batch_output_dir, hook_name) - # Insert features for each published record that passed this hook - for record_id, outcome in outcomes.items(): + # Insert features for each published record that passed this hook. + # Map upstream source ID → published record SRN so features + # are keyed by the record SRN, not the upstream ID. 
+ for upstream_id, outcome in outcomes.items(): if outcome.status != "passed" or not outcome.features: continue + record_srn = event.upstream_to_record_srn.get(upstream_id) + if not record_srn: + logger.warning( + "No record SRN mapping for upstream ID %s in batch %d", + upstream_id, + event.batch_index, + ) + continue + count = await self.feature_service.insert_features( hook_name=hook_name, - record_srn=record_id, + record_srn=record_srn, rows=outcome.features, ) total_inserted += count diff --git a/server/osa/domain/ingest/command/start_ingest.py b/server/osa/domain/ingest/command/start_ingest.py index bef2fca..1b65bbd 100644 --- a/server/osa/domain/ingest/command/start_ingest.py +++ b/server/osa/domain/ingest/command/start_ingest.py @@ -10,6 +10,7 @@ class StartIngest(Command): convention_srn: str batch_size: int = 1000 + limit: int | None = None # Max total records to ingest (None = unlimited) class IngestRunCreated(Result): @@ -26,14 +27,17 @@ class StartIngestHandler(CommandHandler[StartIngest, IngestRunCreated]): __auth__ = at_least(Role.ADMIN) + from osa.domain.auth.model.principal import Principal from osa.domain.ingest.service.ingest import IngestService + principal: Principal service: IngestService async def run(self, cmd: StartIngest) -> IngestRunCreated: ingest_run = await self.service.start_ingest( convention_srn=cmd.convention_srn, batch_size=cmd.batch_size, + limit=cmd.limit, ) return IngestRunCreated( srn=ingest_run.srn, diff --git a/server/osa/domain/ingest/event/events.py b/server/osa/domain/ingest/event/events.py index f47f3fb..ee0eaab 100644 --- a/server/osa/domain/ingest/event/events.py +++ b/server/osa/domain/ingest/event/events.py @@ -49,6 +49,7 @@ class IngestBatchPublished(Event): published_srns: list[str] published_count: int expected_features: list[str] + upstream_to_record_srn: dict[str, str] # upstream source ID → published record SRN class IngestCompleted(Event): diff --git a/server/osa/domain/ingest/handler/publish_batch.py 
b/server/osa/domain/ingest/handler/publish_batch.py index 1f73def..2c60445 100644 --- a/server/osa/domain/ingest/handler/publish_batch.py +++ b/server/osa/domain/ingest/handler/publish_batch.py @@ -91,6 +91,11 @@ async def handle(self, event: HookBatchCompleted) -> None: published_srns = [str(r.srn) for r in published] published_count = len(published) + # Build upstream ID → record SRN mapping for feature insertion + upstream_to_record_srn: dict[str, str] = {} + for record in published: + upstream_to_record_srn[record.source.upstream_source] = str(record.srn) + logger.info( "Published %d records from batch %d of %s (%d duplicates skipped)", published_count, @@ -110,6 +115,7 @@ async def handle(self, event: HookBatchCompleted) -> None: published_srns=published_srns, published_count=published_count, expected_features=expected_features, + upstream_to_record_srn=upstream_to_record_srn, ) ) diff --git a/server/osa/domain/ingest/handler/run_hooks.py b/server/osa/domain/ingest/handler/run_hooks.py index eda1e1b..3cf824f 100644 --- a/server/osa/domain/ingest/handler/run_hooks.py +++ b/server/osa/domain/ingest/handler/run_hooks.py @@ -23,6 +23,7 @@ class RunHooks(EventHandler[IngesterBatchReady]): """Runs hook containers on an ingester batch and emits HookBatchCompleted.""" __claim_timeout__ = 3600.0 # Hook runs can be long + __concurrency__ = 4 # Run hook containers for multiple batches in parallel ingest_repo: IngestRunRepository convention_service: ConventionService @@ -102,7 +103,7 @@ async def handle(self, event: IngesterBatchReady) -> None: ) logger.info( - "Hooks completed for batch %d of %s (%d records)", + "Hooks completed for batch %d of %s (%d records processed)", event.batch_index, event.ingest_run_srn, len(records), diff --git a/server/osa/domain/ingest/handler/run_ingester.py b/server/osa/domain/ingest/handler/run_ingester.py index 0299edb..7eb8daf 100644 --- a/server/osa/domain/ingest/handler/run_ingester.py +++ 
b/server/osa/domain/ingest/handler/run_ingester.py @@ -58,11 +58,25 @@ async def handle(self, event: IngestStarted) -> None: if session_file.exists(): session = json.loads(session_file.read_text()) + # Compute effective limit for this batch + # If a total limit is set, don't request more than remaining + effective_batch_limit = ingest_run.batch_size + if ingest_run.limit is not None: + sourced_so_far = ingest_run.batches_sourced * ingest_run.batch_size + remaining = ingest_run.limit - sourced_so_far + if remaining <= 0: + # Already sourced enough — mark finished + await self.ingest_repo.increment_batches_sourced( + event.ingest_run_srn, set_source_finished=True + ) + return + effective_batch_limit = min(ingest_run.batch_size, remaining) + # Run ingester container inputs = IngesterInputs( convention_srn=convention.srn, config=convention.ingester.config, - limit=ingest_run.batch_size, + limit=effective_batch_limit, session=session, ) files_dir = batch_dir / "files" @@ -88,6 +102,12 @@ async def handle(self, event: IngestStarted) -> None: has_more = output.session is not None and len(output.records) > 0 + # If total limit is set, check whether we've sourced enough + if has_more and ingest_run.limit is not None: + total_sourced = (ingest_run.batches_sourced + 1) * ingest_run.batch_size + if total_sourced >= ingest_run.limit: + has_more = False + # Update counters atomically await self.ingest_repo.increment_batches_sourced( event.ingest_run_srn, diff --git a/server/osa/domain/ingest/model/ingest_run.py b/server/osa/domain/ingest/model/ingest_run.py index 33e7f6f..02642b2 100644 --- a/server/osa/domain/ingest/model/ingest_run.py +++ b/server/osa/domain/ingest/model/ingest_run.py @@ -37,6 +37,7 @@ class IngestRun(Aggregate): batches_completed: int = 0 published_count: int = 0 batch_size: int = 1000 + limit: int | None = None # Max total records (None = unlimited) started_at: datetime completed_at: datetime | None = None diff --git 
a/server/osa/domain/ingest/service/ingest.py b/server/osa/domain/ingest/service/ingest.py index 89c032c..2a58419 100644 --- a/server/osa/domain/ingest/service/ingest.py +++ b/server/osa/domain/ingest/service/ingest.py @@ -29,6 +29,7 @@ async def start_ingest( self, convention_srn: str, batch_size: int = 1000, + limit: int | None = None, ) -> IngestRun: """Create an ingest run for a convention. @@ -61,6 +62,7 @@ async def start_ingest( convention_srn=convention_srn, status=IngestStatus.PENDING, batch_size=batch_size, + limit=limit, started_at=now, ) diff --git a/server/osa/domain/shared/event.py b/server/osa/domain/shared/event.py index b6ed59b..0fc19b3 100644 --- a/server/osa/domain/shared/event.py +++ b/server/osa/domain/shared/event.py @@ -233,6 +233,7 @@ async def handle_batch(self, events: list[IndexRecord]) -> None: __poll_interval__: ClassVar[float] = 0.5 __max_retries__: ClassVar[int] = 3 __claim_timeout__: ClassVar[float] = 300.0 + __concurrency__: ClassVar[int] = 1 async def handle(self, event: E) -> None: """Handle a single event. Override for single-event processing. diff --git a/server/osa/domain/shared/model/source.py b/server/osa/domain/shared/model/source.py index 1cd0d6c..9d4b1ef 100644 --- a/server/osa/domain/shared/model/source.py +++ b/server/osa/domain/shared/model/source.py @@ -11,7 +11,7 @@ class IngesterLimits(ValueObject): """Resource limits for ingester container execution.""" timeout_seconds: int = 3600 - memory: str = "512m" + memory: str = "1g" cpu: str = "0.25" diff --git a/server/osa/infrastructure/event/worker.py b/server/osa/infrastructure/event/worker.py index d629adc..7a8e29d 100644 --- a/server/osa/infrastructure/event/worker.py +++ b/server/osa/infrastructure/event/worker.py @@ -45,8 +45,10 @@ class name as its consumer_group. Deliveries are claimed per consumer group, enabling multiple handlers to independently process the same event. 
""" - def __init__(self, handler_type: type[EventHandler[Any]]) -> None: + def __init__(self, handler_type: type[EventHandler[Any]], *, instance_id: int = 0) -> None: self._handler_type = handler_type + self._instance_id = instance_id + # All instances share the same consumer group so SKIP LOCKED distributes work self._consumer_group = handler_type.__name__ # Read config from handler classvars @@ -73,8 +75,10 @@ def __init__(self, handler_type: type[EventHandler[Any]]) -> None: @property def name(self) -> str: - """Worker name (handler class name).""" - return self._handler_type.__name__ + """Worker name (handler class name + instance suffix if concurrent).""" + if self._instance_id == 0: + return self._handler_type.__name__ + return f"{self._handler_type.__name__}-{self._instance_id}" @property def consumer_group(self) -> str: @@ -237,13 +241,28 @@ def workers(self) -> list[Worker]: return self._workers def register(self, handler_type: type[EventHandler[Any]]) -> Worker: - """Register an EventHandler type and create a Worker for it.""" - worker = Worker(handler_type) - if self._container is not None: - worker.set_container(self._container) - self._workers.append(worker) - logger.debug(f"Registered handler '{handler_type.__name__}' as worker") - return worker + """Register an EventHandler type and create Worker(s) for it. + + If the handler declares ``__concurrency__ = N``, spawns N workers + that share the same consumer group. Deliveries are distributed + across them via FOR UPDATE SKIP LOCKED. 
+ """ + concurrency = getattr(handler_type, "__concurrency__", 1) + first_worker = None + for i in range(concurrency): + worker = Worker(handler_type, instance_id=i) + if self._container is not None: + worker.set_container(self._container) + self._workers.append(worker) + if first_worker is None: + first_worker = worker + if concurrency > 1: + logger.debug( + f"Registered handler '{handler_type.__name__}' with {concurrency} concurrent workers" + ) + else: + logger.debug(f"Registered handler '{handler_type.__name__}' as worker") + return first_worker # type: ignore[return-value] def add_worker(self, worker: Worker) -> None: """Add a worker to the pool.""" diff --git a/server/osa/infrastructure/persistence/repository/ingest.py b/server/osa/infrastructure/persistence/repository/ingest.py index f46fed2..a6ced87 100644 --- a/server/osa/infrastructure/persistence/repository/ingest.py +++ b/server/osa/infrastructure/persistence/repository/ingest.py @@ -30,6 +30,7 @@ async def save(self, ingest_run: IngestRun) -> None: "batches_completed": ingest_run.batches_completed, "published_count": ingest_run.published_count, "batch_size": ingest_run.batch_size, + "record_limit": ingest_run.limit, "started_at": ingest_run.started_at, "completed_at": ingest_run.completed_at, } @@ -122,6 +123,7 @@ def _row_to_ingest_run(row: dict) -> IngestRun: batches_completed=row["batches_completed"], published_count=row["published_count"], batch_size=row["batch_size"], + limit=row.get("record_limit"), started_at=row["started_at"], completed_at=row.get("completed_at"), ) diff --git a/server/osa/infrastructure/persistence/repository/record.py b/server/osa/infrastructure/persistence/repository/record.py index f143cf0..52a0176 100644 --- a/server/osa/infrastructure/persistence/repository/record.py +++ b/server/osa/infrastructure/persistence/repository/record.py @@ -1,6 +1,6 @@ """PostgreSQL implementation of RecordRepository.""" -from sqlalchemy import func, select +from sqlalchemy import func, select, 
text from sqlalchemy.dialects.postgresql import insert from sqlalchemy.ext.asyncio import AsyncSession @@ -37,9 +37,9 @@ async def save_many(self, records: list[Record]) -> list[Record]: .values(values) .on_conflict_do_nothing( index_elements=[ - records_table.c.source["type"].as_string(), - records_table.c.source["id"].as_string(), - ] + text("(source->>'type')"), + text("(source->>'id')"), + ], ) .returning(records_table.c.srn) ) diff --git a/server/osa/infrastructure/persistence/tables.py b/server/osa/infrastructure/persistence/tables.py index b7b9dce..a5b073c 100644 --- a/server/osa/infrastructure/persistence/tables.py +++ b/server/osa/infrastructure/persistence/tables.py @@ -318,6 +318,7 @@ Column("batches_completed", Integer, nullable=False, server_default=text("0")), Column("published_count", Integer, nullable=False, server_default=text("0")), Column("batch_size", Integer, nullable=False, server_default=text("1000")), + Column("record_limit", Integer, nullable=True), Column("started_at", DateTime(timezone=True), nullable=False), Column("completed_at", DateTime(timezone=True), nullable=True), ) diff --git a/server/tests/unit/application/test_app_factory.py b/server/tests/unit/application/test_app_factory.py index 609fb00..f0233d4 100644 --- a/server/tests/unit/application/test_app_factory.py +++ b/server/tests/unit/application/test_app_factory.py @@ -35,6 +35,7 @@ "OSA_AUTH__JWT__SECRET", "test-secret-that-is-at-least-32-characters-long", ) +os.environ.setdefault("OSA_BASE_URL", "http://localhost:8000") # --------------------------------------------------------------------------- diff --git a/server/tests/unit/config/test_config.py b/server/tests/unit/config/test_config.py index 426950f..e5fb1cb 100644 --- a/server/tests/unit/config/test_config.py +++ b/server/tests/unit/config/test_config.py @@ -24,6 +24,7 @@ def config_from_yaml(data: dict, env_overrides: dict[str, str] | None = None) -> raw = make_config_yaml(data) env = { "OSA_AUTH__JWT__SECRET": 
"test-secret-key-that-is-at-least-32-chars-long", + "OSA_BASE_URL": "http://localhost:8000", **(env_overrides or {}), } diff --git a/server/tests/unit/config/test_paths_config.py b/server/tests/unit/config/test_paths_config.py index 0e5bbc7..8006ebb 100644 --- a/server/tests/unit/config/test_paths_config.py +++ b/server/tests/unit/config/test_paths_config.py @@ -7,6 +7,13 @@ from osa.config import Config +@pytest.fixture(autouse=True) +def _config_defaults(monkeypatch: pytest.MonkeyPatch) -> None: + """Ensure required config env vars are set for all tests.""" + monkeypatch.setenv("OSA_AUTH__JWT__SECRET", "test-secret-key-that-is-at-least-32-chars-long") + monkeypatch.setenv("OSA_BASE_URL", "http://localhost:8000") + + class TestDatabaseUrlDerivation: """Tests for database URL derivation from OSAPaths.""" From 254ed994af2553544fe87e7d6ff88c050ae504ef Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Thu, 26 Mar 2026 13:19:49 +0000 Subject: [PATCH 5/9] feat: add configurable hook concurrency setting Add hook_concurrency config option to WorkerConfig and remove hardcoded concurrency from RunHooks handler. Update WorkerPool to use config-based concurrency overrides, allowing runtime configuration of hook worker parallelism instead of compile-time constants. 
--- server/osa/config.py | 1 + server/osa/domain/ingest/handler/run_hooks.py | 1 - server/osa/infrastructure/event/di.py | 4 ++- server/osa/infrastructure/event/worker.py | 29 +++++++++++++++---- 4 files changed, 28 insertions(+), 7 deletions(-) diff --git a/server/osa/config.py b/server/osa/config.py index 1cbea64..67f8f13 100644 --- a/server/osa/config.py +++ b/server/osa/config.py @@ -81,6 +81,7 @@ class WorkerConfig(BaseModel): poll_interval: float = 0.5 # Seconds between outbox polls batch_size: int = 100 # Maximum events to fetch per poll cycle + hook_concurrency: int = 8 # Number of concurrent hook workers class K8sConfig(BaseModel): diff --git a/server/osa/domain/ingest/handler/run_hooks.py b/server/osa/domain/ingest/handler/run_hooks.py index 3cf824f..c1b0cba 100644 --- a/server/osa/domain/ingest/handler/run_hooks.py +++ b/server/osa/domain/ingest/handler/run_hooks.py @@ -23,7 +23,6 @@ class RunHooks(EventHandler[IngesterBatchReady]): """Runs hook containers on an ingester batch and emits HookBatchCompleted.""" __claim_timeout__ = 3600.0 # Hook runs can be long - __concurrency__ = 4 # Run hook containers for multiple batches in parallel ingest_repo: IngestRunRepository convention_service: ConventionService diff --git a/server/osa/infrastructure/event/di.py b/server/osa/infrastructure/event/di.py index e9c6d11..fb665c4 100644 --- a/server/osa/infrastructure/event/di.py +++ b/server/osa/infrastructure/event/di.py @@ -5,6 +5,7 @@ from dishka import AsyncContainer, provide +from osa.config import Config from osa.domain.curation.handler import AutoApproveCuration from osa.domain.deposition.handler import ReturnToDraft from osa.domain.feature.handler import ( @@ -131,12 +132,13 @@ def get_worker_pool( self, container: AsyncContainer, handler_types: HandlerTypes, + config: Config, ) -> WorkerPool: """WorkerPool with pull-based event handlers.""" pool = WorkerPool(container=container, stale_claim_interval=60.0) for handler_type in handler_types: - 
pool.register(handler_type) + pool.register(handler_type, config=config) logger.info(f"WorkerPool created with {len(pool.workers)} workers") return pool diff --git a/server/osa/infrastructure/event/worker.py b/server/osa/infrastructure/event/worker.py index 7a8e29d..3594127 100644 --- a/server/osa/infrastructure/event/worker.py +++ b/server/osa/infrastructure/event/worker.py @@ -4,7 +4,10 @@ import logging from contextlib import AsyncExitStack from dataclasses import dataclass, field -from typing import Any, NewType +from typing import TYPE_CHECKING, Any, NewType + +if TYPE_CHECKING: + from osa.config import Config from apscheduler import AsyncScheduler from apscheduler.triggers.cron import CronTrigger @@ -240,14 +243,30 @@ def workers(self) -> list[Worker]: """List of managed workers.""" return self._workers - def register(self, handler_type: type[EventHandler[Any]]) -> Worker: + def register( + self, + handler_type: type[EventHandler[Any]], + config: "Config | None" = None, + ) -> Worker: """Register an EventHandler type and create Worker(s) for it. - If the handler declares ``__concurrency__ = N``, spawns N workers - that share the same consumer group. Deliveries are distributed - across them via FOR UPDATE SKIP LOCKED. + Concurrency is determined by (in priority order): + 1. Config override (e.g. ``config.worker.hook_concurrency`` for RunHooks) + 2. Handler classvar ``__concurrency__`` + 3. Default of 1 + + Multiple workers share the same consumer group so deliveries are + distributed across them via FOR UPDATE SKIP LOCKED. 
""" concurrency = getattr(handler_type, "__concurrency__", 1) + + # Apply config overrides + if config is not None: + from osa.domain.ingest.handler.run_hooks import RunHooks + + if handler_type is RunHooks: + concurrency = config.worker.hook_concurrency + first_worker = None for i in range(concurrency): worker = Worker(handler_type, instance_id=i) From 43c8383baa7f09f2478a52a0eef47ba9f8f18de7 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Fri, 27 Mar 2026 15:05:34 +0000 Subject: [PATCH 6/9] eat: add ingest_runs table migration for bulk ingestion tracking feat: replace Python logging with logfire structured logging system refactor: replace dict-based ingester records with typed IngesterRecord model feat: add comprehensive logging for ingest pipeline with batch progress tracking feat: increase default hook memory limit from 512m to 1g fix: sanitize record IDs in Docker bind mounts to avoid colon conflicts feat: add OOM detection and logging for hook container failures --- ...add_harvest_runs.py => add_ingest_runs.py} | 4 +- server/osa/application/api/rest/app.py | 47 ++++- server/osa/config.py | 10 +- .../feature/handler/insert_batch_features.py | 31 ++-- .../domain/ingest/handler/publish_batch.py | 149 +++++++++------- server/osa/domain/ingest/handler/run_hooks.py | 79 ++++---- .../osa/domain/ingest/handler/run_ingester.py | 34 ++-- .../domain/ingest/model/ingester_record.py | 17 ++ server/osa/domain/ingest/service/ingest.py | 12 +- server/osa/domain/record/service/record.py | 1 - server/osa/domain/shared/model/hook.py | 2 +- .../domain/validation/service/validation.py | 7 +- server/osa/infrastructure/event/worker.py | 4 +- server/osa/infrastructure/k8s/runner.py | 7 +- server/osa/infrastructure/logging.py | 168 ++++++++++++++++++ server/osa/infrastructure/oci/runner.py | 35 +++- .../unit/domain/shared/test_hook_models.py | 4 +- 17 files changed, 446 insertions(+), 165 deletions(-) rename server/migrations/versions/{add_harvest_runs.py => add_ingest_runs.py} (97%) 
create mode 100644 server/osa/domain/ingest/model/ingester_record.py create mode 100644 server/osa/infrastructure/logging.py diff --git a/server/migrations/versions/add_harvest_runs.py b/server/migrations/versions/add_ingest_runs.py similarity index 97% rename from server/migrations/versions/add_harvest_runs.py rename to server/migrations/versions/add_ingest_runs.py index 0d2f61c..816e6c2 100644 --- a/server/migrations/versions/add_harvest_runs.py +++ b/server/migrations/versions/add_ingest_runs.py @@ -2,7 +2,7 @@ Add ingest_runs table for bulk ingestion tracking. -Revision ID: add_harvest_runs +Revision ID: add_ingest_runs Revises: source_agnostic_records Create Date: 2026-03-25 @@ -14,7 +14,7 @@ from alembic import op # revision identifiers, used by Alembic. -revision: str = "add_harvest_runs" +revision: str = "add_ingest_runs" down_revision: Union[str, Sequence[str], None] = "source_agnostic_records" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/server/osa/application/api/rest/app.py b/server/osa/application/api/rest/app.py index 3cfbc18..0f41bc0 100644 --- a/server/osa/application/api/rest/app.py +++ b/server/osa/application/api/rest/app.py @@ -1,4 +1,5 @@ import logging +import sys from contextlib import asynccontextmanager from typing import Any @@ -26,7 +27,7 @@ validation, ) from osa.application.di import create_container -from osa.config import Config, configure_logging +from osa.config import Config from osa.domain.shared.authorization.startup import validate_all_handlers from osa.domain.shared.error import OSAError from osa.domain.shared.event import EventHandler @@ -81,9 +82,41 @@ def create_app( # Pydantic Settings populates from env vars at runtime config = Config() # type: ignore[call-arg] - # Configure logging early - configure_logging(config.logging) - logger.info("Starting OSA server: %s v%s", config.name, config.version) + # Configure logfire as the single logging system + 
import logging as _logging + + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + + from osa.infrastructure.logging import OSAConsoleExporter + + logfire.configure( + send_to_logfire="if-token-present", + service_name=config.name, + console=False, # Disable default console — we use OSAConsoleExporter + inspect_arguments=False, + additional_span_processors=[ + SimpleSpanProcessor( + OSAConsoleExporter( + output=sys.stderr, + include_timestamp=True, + min_log_level=config.logging.level, + ) + ), + ], + ) + + # Route Python logging through logfire so old-style logger.info() calls + # appear in the same output stream + root = _logging.getLogger() + root.setLevel(config.logging.level.upper()) + for h in root.handlers[:]: + root.removeHandler(h) + root.addHandler(logfire.LogfireLoggingHandler()) + + # Suppress duplicate access logs — logfire FastAPI instrumentation handles HTTP logging + _logging.getLogger("uvicorn.access").setLevel(_logging.WARNING) + + logfire.info("Starting OSA server: {name} v{version}", name=config.name, version=config.version) # Validate all handlers have authorization declarations (fail fast) validate_all_handlers() @@ -95,9 +128,11 @@ def create_app( lifespan=lifespan, ) - # Instrument FastAPI for automatic tracing of HTTP requests logfire.instrument_httpx() - logfire.instrument_fastapi(app_instance) + logfire.instrument_fastapi( + app_instance, + excluded_urls="/api/v1/health", + ) # Setup dependency injection container = create_container( diff --git a/server/osa/config.py b/server/osa/config.py index 67f8f13..d9b9f9c 100644 --- a/server/osa/config.py +++ b/server/osa/config.py @@ -1,12 +1,13 @@ +from logfire import LevelName import logging import os import re import sys from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, Annotated import yaml -from pydantic import BaseModel, field_validator, model_validator +from pydantic import BaseModel, field_validator, model_validator, StringConstraints 
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource from typing_extensions import Self @@ -63,7 +64,9 @@ class DatabaseConfig(BaseModel): class LoggingConfig(BaseModel): """Logging configuration (nested in Config, uses env_nested_delimiter).""" - level: str = "DEBUG" # Root log level (DEBUG for development) + level: Annotated[LevelName, StringConstraints(to_lower=True)] = ( + "debug" # Root log level (DEBUG for development) + ) format: str = "%(asctime)s %(levelname)-8s [%(name)s] %(message)s" date_format: str = "%Y-%m-%d %H:%M:%S" @@ -379,5 +382,6 @@ def configure_logging(config: LoggingConfig) -> None: logging.getLogger("asyncio").setLevel(logging.WARNING) logging.getLogger("aiosqlite").setLevel(logging.WARNING) logging.getLogger("apscheduler").setLevel(logging.WARNING) # Suppress job completion spam + logging.getLogger("uvicorn.access").setLevel(logging.WARNING) # Logfire handles HTTP logging logging.debug("Logging configured: level=%s, file=%s", config.level, config.file) diff --git a/server/osa/domain/feature/handler/insert_batch_features.py b/server/osa/domain/feature/handler/insert_batch_features.py index a548405..cf7f366 100644 --- a/server/osa/domain/feature/handler/insert_batch_features.py +++ b/server/osa/domain/feature/handler/insert_batch_features.py @@ -1,14 +1,13 @@ """InsertBatchFeatures — bulk feature insertion for ingest batches.""" -import logging - from osa.domain.feature.port.storage import FeatureStoragePort from osa.domain.feature.service.feature import FeatureService from osa.domain.ingest.event.events import IngestBatchPublished from osa.domain.shared.event import EventHandler +from osa.infrastructure.logging import get_logger from osa.infrastructure.storage.layout import StorageLayout -logger = logging.getLogger(__name__) +log = get_logger(__name__) class InsertBatchFeatures(EventHandler[IngestBatchPublished]): @@ -32,6 +31,7 @@ async def handle(self, event: IngestBatchPublished) -> None: ) total_inserted = 0 + 
skipped_dupes = 0 for hook_name in event.expected_features: # Read JSONL outcomes for this hook @@ -46,11 +46,10 @@ async def handle(self, event: IngestBatchPublished) -> None: record_srn = event.upstream_to_record_srn.get(upstream_id) if not record_srn: - logger.warning( - "No record SRN mapping for upstream ID %s in batch %d", - upstream_id, - event.batch_index, - ) + # Expected for cross-batch duplicates — the record was already + # published in an earlier batch, so ON CONFLICT DO NOTHING + # skipped it and features were already inserted then. + skipped_dupes += 1 continue count = await self.feature_service.insert_features( @@ -60,10 +59,14 @@ async def handle(self, event: IngestBatchPublished) -> None: ) total_inserted += count - logger.info( - "Inserted %d feature rows for batch %d of %s (%d hooks)", - total_inserted, - event.batch_index, - event.ingest_run_srn, - len(event.expected_features), + short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + dupe_msg = f", {skipped_dupes} duplicates skipped" if skipped_dupes else "" + log.info( + "[{short_id}] batch {batch_index}: inserted {total_inserted} feature rows ({hook_count} hooks{dupe_msg})", + short_id=short_id, + batch_index=event.batch_index, + total_inserted=total_inserted, + hook_count=len(event.expected_features), + dupe_msg=dupe_msg, + ingest_run_srn=event.ingest_run_srn, ) diff --git a/server/osa/domain/ingest/handler/publish_batch.py b/server/osa/domain/ingest/handler/publish_batch.py index 2c60445..4dbd500 100644 --- a/server/osa/domain/ingest/handler/publish_batch.py +++ b/server/osa/domain/ingest/handler/publish_batch.py @@ -1,18 +1,17 @@ """PublishBatch — reads hook outputs, bulk-publishes passing records.""" -import json -import logging from datetime import UTC, datetime -from pathlib import Path from uuid import uuid4 from osa.domain.deposition.service.convention import ConventionService +from osa.domain.feature.port.storage import FeatureStoragePort from osa.domain.ingest.event.events import 
( HookBatchCompleted, IngestBatchPublished, IngestCompleted, ) from osa.domain.ingest.model.ingest_run import IngestStatus +from osa.domain.ingest.model.ingester_record import IngesterRecord from osa.domain.ingest.port.repository import IngestRunRepository from osa.domain.record.model.draft import RecordDraft from osa.domain.record.service import RecordService @@ -21,10 +20,10 @@ from osa.domain.shared.model.source import IngestSource from osa.domain.shared.model.srn import ConventionSRN from osa.domain.shared.outbox import Outbox -from osa.domain.feature.port.storage import FeatureStoragePort +from osa.infrastructure.logging import get_logger from osa.infrastructure.storage.layout import StorageLayout -logger = logging.getLogger(__name__) +log = get_logger(__name__) class PublishBatch(EventHandler[HookBatchCompleted]): @@ -51,42 +50,62 @@ async def handle(self, event: HookBatchCompleted) -> None: ingester_dir = self.layout.ingest_batch_ingester_dir( event.ingest_run_srn, event.batch_index ) - ingester_records = _read_ingester_records(ingester_dir / "records.jsonl") + ingester_records = _read_ingester_records(ingester_dir) # Read hook outcomes for all hooks expected_features = [h.name for h in convention.hooks] - # Determine which records passed all hooks - passed_records = _get_passed_records( + # Determine which records passed all hooks (via storage port — works on filesystem + S3) + # TODO: is this efficient, are we hitting S3 a lot? 
+ passed_records = await _get_passed_records( ingester_records=ingester_records, - batch_dir=batch_dir, + batch_dir=str(batch_dir), hooks=expected_features, feature_storage=self.feature_storage, ) + # Log outcome breakdown per hook + total = len(ingester_records) + for hook_name in expected_features: + outcomes = await self.feature_storage.read_batch_outcomes(str(batch_dir), hook_name) + passed = sum(1 for o in outcomes.values() if o.status == "passed") + rejected = sum(1 for o in outcomes.values() if o.status == "rejected") + errored = sum(1 for o in outcomes.values() if o.status == "errored") + missing = total - len(outcomes) + short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + log.info( + "[{short_id}] batch {batch_index} hook={hook_name}: " + "{passed}/{total} passed, {rejected} rejected, {errored} errored, {missing} missing", + short_id=short_id, + batch_index=event.batch_index, + hook_name=hook_name, + total=total, + passed=passed, + rejected=rejected, + errored=errored, + missing=missing, + ingest_run_srn=event.ingest_run_srn, + ) + published_count = 0 - if not passed_records: - logger.info("No passing records in batch %d", event.batch_index) - else: + if passed_records: # Construct RecordDrafts drafts: list[RecordDraft] = [] for record in passed_records: - source_id = record.get("source_id", record.get("id", "")) drafts.append( RecordDraft( source=IngestSource( - id=f"{ingest_run.convention_srn}:{source_id}", + id=f"{ingest_run.convention_srn}:{record.source_id}", ingest_run_srn=ingest_run.srn, - upstream_source=source_id, + upstream_source=record.source_id, ), - metadata=record.get("metadata", {}), + metadata=record.metadata, convention_srn=ConventionSRN.parse(ingest_run.convention_srn), expected_features=expected_features, ) ) - # Bulk publish — ON CONFLICT DO NOTHING skips duplicates, - # so published may be shorter than drafts + # Bulk publish — ON CONFLICT DO NOTHING skips duplicates published = await self.record_service.bulk_publish(drafts) 
published_srns = [str(r.srn) for r in published] published_count = len(published) @@ -94,14 +113,23 @@ async def handle(self, event: HookBatchCompleted) -> None: # Build upstream ID → record SRN mapping for feature insertion upstream_to_record_srn: dict[str, str] = {} for record in published: + # TODO: should we make RecordDraft generic over source type so we don't have to check this at runtime? + if not isinstance(record.source, IngestSource): + log.warn( + "Skipping record with unsupported source type: {source_type}", + source_type=type(record.source).__name__, + ) + continue upstream_to_record_srn[record.source.upstream_source] = str(record.srn) - logger.info( - "Published %d records from batch %d of %s (%d duplicates skipped)", - published_count, - event.batch_index, - event.ingest_run_srn, - len(drafts) - published_count, + log.info( + "[{short_id}] batch {batch_index}: published {published}/{passed} records ({duplicates} duplicates skipped)", + short_id=short_id, + batch_index=event.batch_index, + published=published_count, + passed=len(passed_records), + duplicates=len(drafts) - published_count, + ingest_run_srn=event.ingest_run_srn, ) # Emit IngestBatchPublished for feature insertion @@ -119,8 +147,7 @@ async def handle(self, event: HookBatchCompleted) -> None: ) ) - # Update counters atomically — use actual published_count (not passed_records) - # to avoid over-counting when ON CONFLICT DO NOTHING skips duplicates + # Update counters atomically updated = await self.ingest_repo.increment_completed( event.ingest_run_srn, published_count=published_count, @@ -138,16 +165,22 @@ async def handle(self, event: HookBatchCompleted) -> None: total_published=updated.published_count, ) ) - logger.info( - "Ingest completed: %s (total published: %d)", - event.ingest_run_srn, - updated.published_count, + short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + log.info( + "[{short_id}] COMPLETE: {total_published} records published", + short_id=short_id, + 
total_published=updated.published_count, + ingest_run_srn=event.ingest_run_srn, ) -def _read_ingester_records(records_file: Path) -> list[dict]: - """Read ingester records from JSONL file.""" - records: list[dict] = [] +def _read_ingester_records(ingester_dir) -> list[IngesterRecord]: + """Read ingester records from JSONL file into typed objects.""" + import json + from pathlib import Path + + records_file = Path(ingester_dir) / "records.jsonl" + records: list[IngesterRecord] = [] if not records_file.exists(): return records for line in records_file.open(): @@ -155,46 +188,36 @@ def _read_ingester_records(records_file: Path) -> list[dict]: if not line: continue try: - records.append(json.loads(line)) - except json.JSONDecodeError: - logger.warning("Skipping malformed ingester record line") + data = json.loads(line) + records.append( + IngesterRecord( + source_id=data.get("source_id", data.get("id", "")), + metadata=data.get("metadata", {}), + file_paths=data.get("file_paths", []), + ) + ) + except (json.JSONDecodeError, ValueError): + log.warn("Skipping malformed ingester record line") return records -def _get_passed_records( - ingester_records: list[dict], - batch_dir: Path, +async def _get_passed_records( + ingester_records: list[IngesterRecord], + batch_dir: str, hooks: list[str], feature_storage: FeatureStoragePort, -) -> list[dict]: - """Determine which records passed ALL hooks by intersecting features.jsonl across hooks. - - Each hook processes the full batch independently. A record must appear in - every hook's features.jsonl to be considered passed. Records rejected or - errored by any hook are excluded. 
- """ +) -> list[IngesterRecord]: + """Determine which records passed ALL hooks via the storage port.""" if not hooks: return ingester_records passed_ids: set[str] | None = None for hook_name in hooks: - features_file = batch_dir / "hooks" / hook_name / "output" / "features.jsonl" - if not features_file.exists(): - return [] # If any hook produced no features file, nothing passed - - hook_passed: set[str] = set() - for line in features_file.open(): - line = line.strip() - if not line: - continue - try: - data = json.loads(line) - record_id = data.get("id") - if record_id: - hook_passed.add(record_id) - except json.JSONDecodeError: - logger.warning("Skipping malformed features.jsonl line in hook %s", hook_name) + outcomes = await feature_storage.read_batch_outcomes(batch_dir, hook_name) + if not outcomes: + return [] + hook_passed = {rid for rid, o in outcomes.items() if o.status == "passed"} if passed_ids is None: passed_ids = hook_passed @@ -204,4 +227,4 @@ def _get_passed_records( if not passed_ids: return [] - return [r for r in ingester_records if r.get("source_id", r.get("id", "")) in passed_ids] + return [r for r in ingester_records if r.source_id in passed_ids] diff --git a/server/osa/domain/ingest/handler/run_hooks.py b/server/osa/domain/ingest/handler/run_hooks.py index c1b0cba..69979a7 100644 --- a/server/osa/domain/ingest/handler/run_hooks.py +++ b/server/osa/domain/ingest/handler/run_hooks.py @@ -1,12 +1,12 @@ """RunHooks — runs hook containers on an ingester batch.""" import json -import logging from pathlib import Path from uuid import uuid4 from osa.domain.deposition.service.convention import ConventionService from osa.domain.ingest.event.events import HookBatchCompleted, IngesterBatchReady +from osa.domain.ingest.model.ingester_record import IngesterRecord from osa.domain.ingest.port.repository import IngestRunRepository from osa.domain.shared.error import NotFoundError from osa.domain.shared.event import EventHandler, EventId @@ -14,15 +14,16 @@ 
from osa.domain.shared.outbox import Outbox from osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.port.hook_runner import HookInputs, HookRunner +from osa.infrastructure.logging import get_logger from osa.infrastructure.storage.layout import StorageLayout -logger = logging.getLogger(__name__) +log = get_logger(__name__) class RunHooks(EventHandler[IngesterBatchReady]): """Runs hook containers on an ingester batch and emits HookBatchCompleted.""" - __claim_timeout__ = 3600.0 # Hook runs can be long + __claim_timeout__ = 3600.0 ingest_repo: IngestRunRepository convention_service: ConventionService @@ -45,30 +46,23 @@ async def handle(self, event: IngesterBatchReady) -> None: ) records_file = ingester_dir / "records.jsonl" - records: list[dict] = [] - if records_file.exists(): - for line in records_file.open(): - line = line.strip() - if line: - try: - records.append(json.loads(line)) - except json.JSONDecodeError: - logger.warning( - "Skipping malformed record line in batch %d", event.batch_index - ) + records = _read_ingester_records(records_file) if not records: - logger.warning("No records in batch %d for %s", event.batch_index, event.ingest_run_srn) + log.warn( + "ingest batch {batch_index}: no records to process", + batch_index=event.batch_index, + ingest_run_srn=event.ingest_run_srn, + ) # Build files_dirs from ingester files files_base = ingester_dir / "files" files_dirs: dict[str, Path] = {} if files_base.exists(): for record in records: - record_id = record.get("source_id", record.get("id", "")) - record_files = files_base / str(record_id) + record_files = files_base / record.source_id if record_files.exists(): - files_dirs[str(record_id)] = record_files + files_dirs[record.source_id] = record_files # Run each hook sequentially for hook in convention.hooks: @@ -78,19 +72,24 @@ async def handle(self, event: IngesterBatchReady) -> None: hook_output_dir.mkdir(parents=True, exist_ok=True) inputs = HookInputs( - records=[ - 
HookRecord( - id=r.get("source_id", r.get("id", "")), - metadata=r.get("metadata", {}), - ) - for r in records - ], + records=[HookRecord(id=r.source_id, metadata=r.metadata) for r in records], run_id=f"{event.ingest_run_srn}_batch{event.batch_index}", files_dirs=files_dirs, config=None, ) - await self.hook_runner.run(hook, inputs, hook_output_dir) + result = await self.hook_runner.run(hook, inputs, hook_output_dir) + + short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + log.info( + "[{short_id}] batch {batch_index} hook={hook_name}: {status} in {duration:.1f}s", + short_id=short_id, + batch_index=event.batch_index, + hook_name=hook.name, + status=result.status.value, + duration=result.duration_seconds, + ingest_run_srn=event.ingest_run_srn, + ) # Emit HookBatchCompleted await self.outbox.append( @@ -101,9 +100,25 @@ async def handle(self, event: IngesterBatchReady) -> None: ) ) - logger.info( - "Hooks completed for batch %d of %s (%d records processed)", - event.batch_index, - event.ingest_run_srn, - len(records), - ) + +def _read_ingester_records(records_file: Path) -> list[IngesterRecord]: + """Read ingester records from JSONL file into typed objects.""" + records: list[IngesterRecord] = [] + if not records_file.exists(): + return records + for line in records_file.open(): + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + records.append( + IngesterRecord( + source_id=data.get("source_id", data.get("id", "")), + metadata=data.get("metadata", {}), + file_paths=data.get("file_paths", []), + ) + ) + except (json.JSONDecodeError, ValueError): + log.warn("Skipping malformed ingester record line") + return records diff --git a/server/osa/domain/ingest/handler/run_ingester.py b/server/osa/domain/ingest/handler/run_ingester.py index 7eb8daf..11e1615 100644 --- a/server/osa/domain/ingest/handler/run_ingester.py +++ b/server/osa/domain/ingest/handler/run_ingester.py @@ -1,7 +1,6 @@ """RunIngester — runs ingester container on 
IngestStarted or continuation.""" import json -import logging from uuid import uuid4 from osa.domain.deposition.service.convention import ConventionService @@ -13,15 +12,16 @@ from osa.domain.shared.model.srn import ConventionSRN from osa.domain.shared.outbox import Outbox from osa.domain.shared.port.ingester_runner import IngesterInputs, IngesterRunner +from osa.infrastructure.logging import get_logger from osa.infrastructure.storage.layout import StorageLayout -logger = logging.getLogger(__name__) +log = get_logger(__name__) class RunIngester(EventHandler[IngestStarted]): """Runs ingester container and emits IngesterBatchReady per batch.""" - __claim_timeout__ = 3600.0 # Ingester runs can be long + __claim_timeout__ = 3600.0 ingest_repo: IngestRunRepository convention_service: ConventionService @@ -34,7 +34,6 @@ async def handle(self, event: IngestStarted) -> None: if ingest_run is None: raise NotFoundError(f"Ingest run not found: {event.ingest_run_srn}") - # Transition to RUNNING on first ingester pull if ingest_run.status == IngestStatus.PENDING: ingest_run.mark_running() await self.ingest_repo.save(ingest_run) @@ -45,34 +44,27 @@ async def handle(self, event: IngestStarted) -> None: if convention.ingester is None: raise NotFoundError(f"No ingester for convention {event.convention_srn}") - # Determine batch index from current batches_sourced batch_index = ingest_run.batches_sourced - # Prepare scratch directory batch_dir = self.layout.ingest_batch_ingester_dir(event.ingest_run_srn, batch_index) batch_dir.mkdir(parents=True, exist_ok=True) - # Load session state for continuation session_file = self.layout.ingest_session_file(event.ingest_run_srn) session = None if session_file.exists(): session = json.loads(session_file.read_text()) - # Compute effective limit for this batch - # If a total limit is set, don't request more than remaining effective_batch_limit = ingest_run.batch_size if ingest_run.limit is not None: sourced_so_far = ingest_run.batches_sourced * 
ingest_run.batch_size remaining = ingest_run.limit - sourced_so_far if remaining <= 0: - # Already sourced enough — mark finished await self.ingest_repo.increment_batches_sourced( event.ingest_run_srn, set_source_finished=True ) return effective_batch_limit = min(ingest_run.batch_size, remaining) - # Run ingester container inputs = IngesterInputs( convention_srn=convention.srn, config=convention.ingester.config, @@ -89,32 +81,27 @@ async def handle(self, event: IngestStarted) -> None: work_dir=batch_dir, ) - # Write records.jsonl to batch ingester dir records_file = batch_dir / "records.jsonl" with records_file.open("w") as f: for record in output.records: f.write(json.dumps(record) + "\n") - # Save session for continuation if output.session: session_file.parent.mkdir(parents=True, exist_ok=True) session_file.write_text(json.dumps(output.session)) has_more = output.session is not None and len(output.records) > 0 - # If total limit is set, check whether we've sourced enough if has_more and ingest_run.limit is not None: total_sourced = (ingest_run.batches_sourced + 1) * ingest_run.batch_size if total_sourced >= ingest_run.limit: has_more = False - # Update counters atomically await self.ingest_repo.increment_batches_sourced( event.ingest_run_srn, set_source_finished=not has_more, ) - # Emit batch ready event await self.outbox.append( IngesterBatchReady( id=EventId(uuid4()), @@ -124,15 +111,16 @@ async def handle(self, event: IngestStarted) -> None: ) ) - logger.info( - "Ingester batch %d ready for %s (%d records, has_more=%s)", - batch_index, - event.ingest_run_srn, - len(output.records), - has_more, + short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + log.info( + "[{short_id}] batch {batch_index}: pulled {record_count} records (has_more={has_more})", + short_id=short_id, + batch_index=batch_index, + record_count=len(output.records), + has_more=has_more, + ingest_run_srn=event.ingest_run_srn, ) - # Emit continuation event for next batch if has_more: await 
self.outbox.append( IngestStarted( diff --git a/server/osa/domain/ingest/model/ingester_record.py b/server/osa/domain/ingest/model/ingester_record.py new file mode 100644 index 0000000..110d976 --- /dev/null +++ b/server/osa/domain/ingest/model/ingester_record.py @@ -0,0 +1,17 @@ +"""IngesterRecord — typed representation of a record from an ingester container.""" + +from typing import Any + +from osa.domain.shared.model.value import ValueObject + + +class IngesterRecord(ValueObject): + """A record produced by an ingester container, parsed from records.jsonl. + + Replaces raw dicts with typed fields so downstream handlers + don't need fragile `.get("source_id", .get("id", ""))` patterns. + """ + + source_id: str + metadata: dict[str, Any] + file_paths: list[str] = [] diff --git a/server/osa/domain/ingest/service/ingest.py b/server/osa/domain/ingest/service/ingest.py index 2a58419..35c338a 100644 --- a/server/osa/domain/ingest/service/ingest.py +++ b/server/osa/domain/ingest/service/ingest.py @@ -1,6 +1,5 @@ """IngestService — orchestrates ingest lifecycle.""" -import logging from datetime import UTC, datetime from uuid import uuid4 @@ -13,8 +12,9 @@ from osa.domain.shared.model.srn import ConventionSRN, Domain from osa.domain.shared.outbox import Outbox from osa.domain.shared.service import Service +from osa.infrastructure.logging import get_logger -logger = logging.getLogger(__name__) +log = get_logger(__name__) class IngestService(Service): @@ -77,5 +77,11 @@ async def start_ingest( ) ) - logger.info("Ingest started: %s for convention %s", srn, convention_srn) + log.info( + "ingest started for {convention_srn}", + ingest_run_srn=srn, + convention_srn=convention_srn, + batch_size=batch_size, + limit=limit, + ) return ingest_run diff --git a/server/osa/domain/record/service/record.py b/server/osa/domain/record/service/record.py index 844045b..e2409ce 100644 --- a/server/osa/domain/record/service/record.py +++ b/server/osa/domain/record/service/record.py @@ -77,7 
+77,6 @@ async def bulk_publish(self, drafts: list[RecordDraft]) -> list[Record]: ) published = await self.record_repo.save_many(records) - logger.info("Bulk-published %d records (of %d drafts)", len(published), len(drafts)) return published async def publish_record(self, draft: RecordDraft) -> Record: diff --git a/server/osa/domain/shared/model/hook.py b/server/osa/domain/shared/model/hook.py index 7fe5e29..1a90db7 100644 --- a/server/osa/domain/shared/model/hook.py +++ b/server/osa/domain/shared/model/hook.py @@ -32,7 +32,7 @@ class OciLimits(ValueObject): """Resource limits for OCI hook execution.""" timeout_seconds: int = 300 - memory: str = "512m" + memory: str = "1g" cpu: str = "0.5" diff --git a/server/osa/domain/validation/service/validation.py b/server/osa/domain/validation/service/validation.py index c530b3c..3a37a9c 100644 --- a/server/osa/domain/validation/service/validation.py +++ b/server/osa/domain/validation/service/validation.py @@ -106,14 +106,15 @@ async def validate_deposition( Uses the unified batch contract: constructs a 1-record batch for depositions. 
""" - record = HookRecord(id=str(deposition_srn), metadata=metadata) - run_id = f"{deposition_srn.domain.root}_{deposition_srn.id.root}" + local_id = deposition_srn.id.root + record = HookRecord(id=local_id, metadata=metadata) + run_id = f"{deposition_srn.domain.root}_{local_id}" files_dir = self.hook_storage.get_files_dir(deposition_srn) inputs = HookInputs( records=[record], run_id=run_id, - files_dirs={str(deposition_srn): files_dir} if files_dir else {}, + files_dirs={local_id: files_dir} if files_dir else {}, ) run = await self.create_run(inputs=inputs) diff --git a/server/osa/infrastructure/event/worker.py b/server/osa/infrastructure/event/worker.py index 3594127..7721c2a 100644 --- a/server/osa/infrastructure/event/worker.py +++ b/server/osa/infrastructure/event/worker.py @@ -112,9 +112,11 @@ def start(self) -> asyncio.Task: if self._container is None: raise RuntimeError("Container not set. Call set_container() first.") + import logfire + self._shutdown = False self._task = asyncio.create_task(self._run(), name=f"worker-{self.name}") - logger.info(f"Worker '{self.name}' started") + logfire.info("worker started: {worker_name}", worker_name=self.name) return self._task def stop(self) -> None: diff --git a/server/osa/infrastructure/k8s/runner.py b/server/osa/infrastructure/k8s/runner.py index b03f91c..1f0a040 100644 --- a/server/osa/infrastructure/k8s/runner.py +++ b/server/osa/infrastructure/k8s/runner.py @@ -115,11 +115,12 @@ async def _run_job( job_name_to_watch = existing.split(":", 1)[1] else: # Create new Job (no existing or failed) - # For depositions (batch of 1), use the single record's files dir + # Mount the parent of all per-record file dirs — works for + # both depositions (one subdir) and ingests (N subdirs) files_dir = None if inputs.files_dirs: - first_id = next(iter(inputs.files_dirs)) - files_dir = inputs.files_dirs[first_id] + first_dir = next(iter(inputs.files_dirs.values())) + files_dir = first_dir.parent spec = self._build_job_spec( hook, 
work_dir, diff --git a/server/osa/infrastructure/logging.py b/server/osa/infrastructure/logging.py new file mode 100644 index 0000000..90070b0 --- /dev/null +++ b/server/osa/infrastructure/logging.py @@ -0,0 +1,168 @@ +"""OSA logging — custom logfire console exporter and structured logger. + +Provides: +- ``OSAConsoleExporter``: logfire console formatter with aligned columns + (timestamp, level, module, message) and indented continuation lines. +- ``get_logger(name)``: thin wrapper around logfire that auto-tags with module name. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import TYPE_CHECKING, Any, cast + +import logfire as _logfire +from logfire._internal.exporters.console import ( + ATTRIBUTES_TAGS_KEY, + ONE_SECOND_IN_NANOSECONDS, + SimpleConsoleSpanExporter, + _ERROR_LEVEL, + _WARN_LEVEL, +) + +if TYPE_CHECKING: + from logfire._internal.exporters.console import Record, TextParts + + +_LEVEL_NAMES: dict[int, str] = { + 0: "TRACE", + 1: "DEBUG", + 5: "DEBUG", + 9: "INFO ", + 13: "WARN ", + 17: "ERROR", + 21: "FATAL", +} + +# Module column width (inside brackets) +_MODULE_WIDTH = 20 +# Fixed prefix: HH:MM:SS.mmm(12) + space(1) + LEVEL(5) + space(1) + [module](22) + space(1) = 42 +_PREFIX_WIDTH = 42 + + +class OSAConsoleExporter(SimpleConsoleSpanExporter): + """Logfire console exporter with aligned columns. + + Format: ``HH:MM:SS.mmm LEVEL module.name message`` + + Continuation lines are indented to align with the message column. + Module name is extracted from ``_tags`` and shortened for readability. 
+ """ + + def _span_text_parts(self, span: Record, indent: int) -> tuple[str, TextParts]: + parts: TextParts = [] + + # Timestamp + if self._include_timestamp: + ts = datetime.fromtimestamp(span.timestamp / ONE_SECOND_IN_NANOSECONDS) + ts_str = f"{ts:%H:%M:%S.%f}"[:-3] + parts += [(ts_str, "green"), (" ", "")] + + # Level (fixed 5 chars) + level: int = span.level + level_name = _LEVEL_NAMES.get(level, f"L{level:<3}") + if level >= _ERROR_LEVEL: + level_style = "red" + elif level >= _WARN_LEVEL: + level_style = "yellow" + else: + level_style = "dim" + parts += [(level_name, level_style), (" ", "")] + + # Module tag in brackets, fixed width + if self._include_tags: + tags = span.attributes.get(ATTRIBUTES_TAGS_KEY) + if tags: + tag = cast("list[str]", tags)[0] + short = _shorten_module(tag) + bracketed = f"[{short}]" + parts += [(f"{bracketed:<{_MODULE_WIDTH + 2}}", "cyan"), (" ", "")] + else: + parts += [(" " * (_MODULE_WIDTH + 3), "")] + + if indent: + parts += [(indent * " ", "")] + + # Message with aligned continuation lines + msg: str = span.message + pad = " " * _PREFIX_WIDTH + msg = msg.replace("\n", "\n" + pad) + + if level >= _ERROR_LEVEL: + parts += [(msg, "red")] + elif level >= _WARN_LEVEL: + parts += [(msg, "yellow")] + else: + parts += [(msg, "")] + + return msg, parts + + +def _shorten_module(name: str) -> str: + """Shorten module path to fit ~20 chars. 
+ + ``osa.domain.ingest.handler.run_ingester`` → ``ingest.run_ingester`` + ``osa.infrastructure.oci.runner`` → ``infra.oci.runner`` + ``osa.domain.feature.handler.insert_batch_features`` → ``feat.ins_batch_feat`` + """ + short = ( + name.replace("osa.domain.", "") + .replace("osa.infrastructure.", "infra.") + .replace(".handler.", ".") + .replace(".service.", ".") + .replace(".util.", ".") + ) + if len(short) <= _MODULE_WIDTH: + return short + # Truncate: keep first and last segment, abbreviate middle + parts = short.split(".") + if len(parts) <= 2: + return short[:_MODULE_WIDTH] + # Keep first and last, drop middle segments until it fits + first, *middle, last = parts + while middle and len(f"{first}.{'.'.join(middle)}.{last}") > _MODULE_WIDTH: + middle.pop(0) + result = f"{first}.{'.'.join(middle)}.{last}" if middle else f"{first}.{last}" + return result[:_MODULE_WIDTH] + + +class Logger: + """Structured logger that wraps logfire with automatic module tagging. + + Usage:: + + from osa.infrastructure.logging import get_logger + + log = get_logger(__name__) + log.info("batch {idx}: pulled {n} records", idx=0, n=101) + + Produces structured logfire spans with ``_tags=[module_name]``, + preserving key-value attributes for logfire cloud while showing + the module name in console output. + """ + + __slots__ = ("_tags",) + + def __init__(self, name: str) -> None: + self._tags = [name] + + def info(self, msg: str, **kwargs: Any) -> None: + _logfire.info(msg, _tags=self._tags, **kwargs) + + def warn(self, msg: str, **kwargs: Any) -> None: + _logfire.warn(msg, _tags=self._tags, **kwargs) + + def error(self, msg: str, **kwargs: Any) -> None: + _logfire.error(msg, _tags=self._tags, **kwargs) + + def debug(self, msg: str, **kwargs: Any) -> None: + _logfire.debug(msg, _tags=self._tags, **kwargs) + + +def get_logger(name: str) -> Logger: + """Create a structured logger for a module. + + Args: + name: Module name, typically ``__name__``. 
+ """ + return Logger(name) diff --git a/server/osa/infrastructure/oci/runner.py b/server/osa/infrastructure/oci/runner.py index f19e8b8..362ceb9 100644 --- a/server/osa/infrastructure/oci/runner.py +++ b/server/osa/infrastructure/oci/runner.py @@ -4,16 +4,16 @@ import json import os import stat +import sys import time from pathlib import Path from shutil import rmtree import aiodocker -import logfire - from osa.domain.shared.model.hook import HookDefinition from osa.domain.validation.model.hook_result import HookResult, HookStatus from osa.domain.validation.port.hook_runner import HookInputs, HookRunner +from osa.infrastructure.logging import get_logger from osa.infrastructure.runner_utils import ( detect_rejection, parse_memory, @@ -27,6 +27,9 @@ def _force_remove(func, path, exc): func(path) +log = get_logger(__name__) + + class OciHookRunner(HookRunner): """Executes hooks in OCI containers via aiodocker.""" @@ -97,7 +100,7 @@ async def _resolve_and_run(): ) except asyncio.TimeoutError: duration = time.monotonic() - start_time - logfire.error( + log.error( "Hook timed out", hook=hook.name, run_id=inputs.run_id, @@ -130,10 +133,12 @@ async def _run_container( ] # Mount per-record file directories under /osa/files/{id}/ + # Sanitize IDs to avoid colons breaking Docker's bind mount syntax if files_dirs: for record_id, fdir in files_dirs.items(): if fdir and fdir.exists(): - binds.append(f"{self._host_path(fdir)}:/osa/files/{record_id}:ro") + safe_id = record_id.replace(":", "_").replace("@", "_") + binds.append(f"{self._host_path(fdir)}:/osa/files/{safe_id}:ro") elif files_base.exists(): binds.append(f"{self._host_path(files_base)}:/osa/files:ro") @@ -172,9 +177,23 @@ async def _run_container( oom_killed = inspect_data.get("State", {}).get("OOMKilled", False) if oom_killed: + # Grab tail of container logs before deletion + try: + tail_logs = await container.log(stdout=True, stderr=True, tail=3) + tail_text = "".join(tail_logs).strip() if tail_logs else "" + except 
Exception: + tail_text = "" + log.error( + "OOM: hook={hook_name} limit={memory}", + hook_name=hook.name, + memory=hook.runtime.limits.memory, + ) + if tail_text: + for line in tail_text.splitlines(): + print(f" OOM [{hook.name}] {line}", file=sys.stderr, flush=True) return { "status": HookStatus.FAILED, - "error_message": "Hook killed by OOM", + "error_message": f"Hook killed by OOM (limit: {hook.runtime.limits.memory})", } # Parse progress file @@ -204,13 +223,13 @@ async def _run_container( } except aiodocker.DockerError as e: - logfire.error("Docker error running hook", error=str(e)) + log.error("Docker error running hook", error=str(e)) return { "status": HookStatus.FAILED, "error_message": f"Docker error: {e}", } except Exception as e: - logfire.error("Unexpected error running hook", error=str(e)) + log.error("Unexpected error running hook", error=str(e)) return { "status": HookStatus.FAILED, "error_message": f"Unexpected error: {e}", @@ -247,6 +266,6 @@ async def _resolve_image(self, image: str, digest: str) -> str: pass # Pull from registry as last resort - logfire.info("Pulling hook image", image=image) + log.info("Pulling hook image", image=image) await self._docker.images.pull(image) return image diff --git a/server/tests/unit/domain/shared/test_hook_models.py b/server/tests/unit/domain/shared/test_hook_models.py index 249d06f..cd44924 100644 --- a/server/tests/unit/domain/shared/test_hook_models.py +++ b/server/tests/unit/domain/shared/test_hook_models.py @@ -92,7 +92,7 @@ def test_oci_limits_defaults(): limits = OciLimits() assert limits.timeout_seconds == 300 - assert limits.memory == "512m" + assert limits.memory == "1g" assert limits.cpu == "0.5" @@ -176,7 +176,7 @@ def test_hook_definition_default_limits(): feature=TableFeatureSpec(cardinality="one", columns=[]), ) assert hook_def.runtime.limits.timeout_seconds == 300 - assert hook_def.runtime.limits.memory == "512m" + assert hook_def.runtime.limits.memory == "1g" def 
test_hook_definition_serialization_roundtrip(): From 38a988850c4b7a7f4bbb1db174fc71a54a8d6fb7 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Tue, 31 Mar 2026 23:54:34 +0100 Subject: [PATCH 7/9] refactor: rename source-related fields to ingestion-related fields Rename source_finished to ingestion_finished and batches_sourced to batches_ingested across the codebase to better reflect the actual ingestion process rather than just sourcing data --- server/migrations/versions/add_ingest_runs.py | 4 +-- .../osa/domain/ingest/handler/run_ingester.py | 20 ++++++----- server/osa/domain/ingest/model/ingest_run.py | 14 ++++---- server/osa/domain/ingest/port/repository.py | 6 ++-- .../persistence/repository/ingest.py | 20 +++++------ .../osa/infrastructure/persistence/tables.py | 4 +-- .../unit/domain/ingest/test_ingest_run.py | 34 +++++++++---------- 7 files changed, 53 insertions(+), 49 deletions(-) diff --git a/server/migrations/versions/add_ingest_runs.py b/server/migrations/versions/add_ingest_runs.py index 816e6c2..6a23b7c 100644 --- a/server/migrations/versions/add_ingest_runs.py +++ b/server/migrations/versions/add_ingest_runs.py @@ -37,13 +37,13 @@ def upgrade() -> None: server_default=sa.text("'pending'"), ), sa.Column( - "source_finished", + "ingestion_finished", sa.Boolean(), nullable=False, server_default=sa.text("false"), ), sa.Column( - "batches_sourced", + "batches_ingested", sa.Integer(), nullable=False, server_default=sa.text("0"), diff --git a/server/osa/domain/ingest/handler/run_ingester.py b/server/osa/domain/ingest/handler/run_ingester.py index 11e1615..00eb8c5 100644 --- a/server/osa/domain/ingest/handler/run_ingester.py +++ b/server/osa/domain/ingest/handler/run_ingester.py @@ -30,6 +30,10 @@ class RunIngester(EventHandler[IngestStarted]): layout: StorageLayout async def handle(self, event: IngestStarted) -> None: + """Run ingester for the given ingest run and emit IngesterBatchReady. + + TODO: move this logic into a service method.
+ """ ingest_run = await self.ingest_repo.get(event.ingest_run_srn) if ingest_run is None: raise NotFoundError(f"Ingest run not found: {event.ingest_run_srn}") @@ -44,7 +48,7 @@ async def handle(self, event: IngestStarted) -> None: if convention.ingester is None: raise NotFoundError(f"No ingester for convention {event.convention_srn}") - batch_index = ingest_run.batches_sourced + batch_index = ingest_run.batches_ingested batch_dir = self.layout.ingest_batch_ingester_dir(event.ingest_run_srn, batch_index) batch_dir.mkdir(parents=True, exist_ok=True) @@ -56,11 +60,11 @@ async def handle(self, event: IngestStarted) -> None: effective_batch_limit = ingest_run.batch_size if ingest_run.limit is not None: - sourced_so_far = ingest_run.batches_sourced * ingest_run.batch_size - remaining = ingest_run.limit - sourced_so_far + ingested_so_far = ingest_run.batches_ingested * ingest_run.batch_size + remaining = ingest_run.limit - ingested_so_far if remaining <= 0: - await self.ingest_repo.increment_batches_sourced( - event.ingest_run_srn, set_source_finished=True + await self.ingest_repo.increment_batches_ingested( + event.ingest_run_srn, set_ingestion_finished=True ) return effective_batch_limit = min(ingest_run.batch_size, remaining) @@ -93,13 +97,13 @@ async def handle(self, event: IngestStarted) -> None: has_more = output.session is not None and len(output.records) > 0 if has_more and ingest_run.limit is not None: - total_sourced = (ingest_run.batches_sourced + 1) * ingest_run.batch_size + total_sourced = (ingest_run.batches_ingested + 1) * ingest_run.batch_size if total_sourced >= ingest_run.limit: has_more = False - await self.ingest_repo.increment_batches_sourced( + await self.ingest_repo.increment_batches_ingested( event.ingest_run_srn, - set_source_finished=not has_more, + set_ingestion_finished=not has_more, ) await self.outbox.append( diff --git a/server/osa/domain/ingest/model/ingest_run.py b/server/osa/domain/ingest/model/ingest_run.py index 02642b2..c95f604 100644 
--- a/server/osa/domain/ingest/model/ingest_run.py +++ b/server/osa/domain/ingest/model/ingest_run.py @@ -32,8 +32,8 @@ class IngestRun(Aggregate): srn: str convention_srn: str status: IngestStatus = IngestStatus.PENDING - source_finished: bool = False - batches_sourced: int = 0 + ingestion_finished: bool = False + batches_ingested: int = 0 batches_completed: int = 0 published_count: int = 0 batch_size: int = 1000 @@ -54,11 +54,11 @@ def mark_failed(self, completed_at: datetime) -> None: self.transition_to(IngestStatus.FAILED) self.completed_at = completed_at - def mark_source_finished(self) -> None: - self.source_finished = True + def mark_ingestion_finished(self) -> None: + self.ingestion_finished = True - def increment_batches_sourced(self) -> None: - self.batches_sourced += 1 + def increment_batches_ingested(self) -> None: + self.batches_ingested += 1 def record_batch_completed(self, published_count: int) -> None: """Record a completed batch with its published count. @@ -72,7 +72,7 @@ def record_batch_completed(self, published_count: int) -> None: @property def is_complete(self) -> bool: """Check the completion condition: all sourced batches are completed.""" - return self.source_finished and self.batches_sourced == self.batches_completed + return self.ingestion_finished and self.batches_ingested == self.batches_completed def check_completion(self, completed_at: datetime) -> bool: """Check completion condition and transition if met. diff --git a/server/osa/domain/ingest/port/repository.py b/server/osa/domain/ingest/port/repository.py index ecbfbd0..f5839d6 100644 --- a/server/osa/domain/ingest/port/repository.py +++ b/server/osa/domain/ingest/port/repository.py @@ -31,10 +31,10 @@ async def get_running_for_convention(self, convention_srn: str) -> IngestRun | N ... 
@abstractmethod - async def increment_batches_sourced( - self, srn: str, *, set_source_finished: bool = False + async def increment_batches_ingested( + self, srn: str, *, set_ingestion_finished: bool = False ) -> IngestRun: - """Atomically increment batches_sourced and optionally set source_finished. + """Atomically increment batches_ingested and optionally set ingestion_finished. Returns the updated IngestRun with DB-authoritative counter values. """ diff --git a/server/osa/infrastructure/persistence/repository/ingest.py b/server/osa/infrastructure/persistence/repository/ingest.py index a6ced87..94d9beb 100644 --- a/server/osa/infrastructure/persistence/repository/ingest.py +++ b/server/osa/infrastructure/persistence/repository/ingest.py @@ -25,8 +25,8 @@ async def save(self, ingest_run: IngestRun) -> None: "srn": ingest_run.srn, "convention_srn": ingest_run.convention_srn, "status": ingest_run.status.value, - "source_finished": ingest_run.source_finished, - "batches_sourced": ingest_run.batches_sourced, + "ingestion_finished": ingest_run.ingestion_finished, + "batches_ingested": ingest_run.batches_ingested, "batches_completed": ingest_run.batches_completed, "published_count": ingest_run.published_count, "batch_size": ingest_run.batch_size, @@ -70,16 +70,16 @@ async def get_running_for_convention(self, convention_srn: str) -> IngestRun | N return None return _row_to_ingest_run(dict(row)) - async def increment_batches_sourced( - self, srn: str, *, set_source_finished: bool = False + async def increment_batches_ingested( + self, srn: str, *, set_ingestion_finished: bool = False ) -> IngestRun: - """Atomically increment batches_sourced.""" + """Atomically increment batches_ingested.""" t = ingest_runs_table values = { - "batches_sourced": t.c.batches_sourced + 1, + "batches_ingested": t.c.batches_ingested + 1, } - if set_source_finished: - values["source_finished"] = True + if set_ingestion_finished: + values["ingestion_finished"] = True stmt = 
update(t).where(t.c.srn == srn).values(**values).returning(*t.c) result = await self._session.execute(stmt) @@ -118,8 +118,8 @@ def _row_to_ingest_run(row: dict) -> IngestRun: srn=row["srn"], convention_srn=row["convention_srn"], status=IngestStatus(row["status"]), - source_finished=row["source_finished"], - batches_sourced=row["batches_sourced"], + ingestion_finished=row["ingestion_finished"], + batches_ingested=row["batches_ingested"], batches_completed=row["batches_completed"], published_count=row["published_count"], batch_size=row["batch_size"], diff --git a/server/osa/infrastructure/persistence/tables.py b/server/osa/infrastructure/persistence/tables.py index a5b073c..4f3545a 100644 --- a/server/osa/infrastructure/persistence/tables.py +++ b/server/osa/infrastructure/persistence/tables.py @@ -313,8 +313,8 @@ Column("srn", String, primary_key=True), Column("convention_srn", String, ForeignKey("conventions.srn"), nullable=False), Column("status", String(32), nullable=False, server_default=text("'pending'")), - Column("source_finished", Boolean, nullable=False, server_default=text("false")), - Column("batches_sourced", Integer, nullable=False, server_default=text("0")), + Column("ingestion_finished", Boolean, nullable=False, server_default=text("false")), + Column("batches_ingested", Integer, nullable=False, server_default=text("0")), Column("batches_completed", Integer, nullable=False, server_default=text("0")), Column("published_count", Integer, nullable=False, server_default=text("0")), Column("batch_size", Integer, nullable=False, server_default=text("1000")), diff --git a/server/tests/unit/domain/ingest/test_ingest_run.py b/server/tests/unit/domain/ingest/test_ingest_run.py index b1e607b..55f5395 100644 --- a/server/tests/unit/domain/ingest/test_ingest_run.py +++ b/server/tests/unit/domain/ingest/test_ingest_run.py @@ -61,8 +61,8 @@ class TestCompletionCondition: def test_not_complete_when_source_not_finished(self) -> None: run = _make_run( 
status=IngestStatus.RUNNING, - source_finished=False, - batches_sourced=3, + ingestion_finished=False, + batches_ingested=3, batches_completed=3, ) assert not run.is_complete @@ -70,8 +70,8 @@ def test_not_complete_when_source_not_finished(self) -> None: def test_not_complete_when_batches_pending(self) -> None: run = _make_run( status=IngestStatus.RUNNING, - source_finished=True, - batches_sourced=3, + ingestion_finished=True, + batches_ingested=3, batches_completed=2, ) assert not run.is_complete @@ -79,8 +79,8 @@ def test_not_complete_when_batches_pending(self) -> None: def test_complete_when_all_batches_done(self) -> None: run = _make_run( status=IngestStatus.RUNNING, - source_finished=True, - batches_sourced=3, + ingestion_finished=True, + batches_ingested=3, batches_completed=3, ) assert run.is_complete @@ -88,8 +88,8 @@ def test_complete_when_all_batches_done(self) -> None: def test_check_completion_transitions_status(self) -> None: run = _make_run( status=IngestStatus.RUNNING, - source_finished=True, - batches_sourced=2, + ingestion_finished=True, + batches_ingested=2, batches_completed=2, ) now = datetime.now(UTC) @@ -101,8 +101,8 @@ def test_check_completion_transitions_status(self) -> None: def test_check_completion_noop_when_not_complete(self) -> None: run = _make_run( status=IngestStatus.RUNNING, - source_finished=True, - batches_sourced=3, + ingestion_finished=True, + batches_ingested=3, batches_completed=2, ) completed = run.check_completion(datetime.now(UTC)) @@ -111,10 +111,10 @@ def test_check_completion_noop_when_not_complete(self) -> None: class TestCounterIncrements: - def test_increment_batches_sourced(self) -> None: + def test_increment_batches_ingested(self) -> None: run = _make_run(status=IngestStatus.RUNNING) - run.increment_batches_sourced() - assert run.batches_sourced == 1 + run.increment_batches_ingested() + assert run.batches_ingested == 1 def test_record_batch_completed(self) -> None: run = _make_run(status=IngestStatus.RUNNING) @@ 
-129,11 +129,11 @@ def test_multiple_batch_completions(self) -> None: assert run.batches_completed == 2 assert run.published_count == 150 - def test_mark_source_finished(self) -> None: + def test_mark_ingestion_finished(self) -> None: run = _make_run(status=IngestStatus.RUNNING) - assert not run.source_finished - run.mark_source_finished() - assert run.source_finished + assert not run.ingestion_finished + run.mark_ingestion_finished() + assert run.ingestion_finished def test_batch_size_default(self) -> None: run = _make_run() From 5bcb944c3a59081861959015a36e478a40755168 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Fri, 3 Apr 2026 14:49:47 +0100 Subject: [PATCH 8/9] refactor: replace string-based status checks with typed enums Replace hardcoded status strings ("passed", "rejected", "errored") with OutcomeStatus enum throughout validation domain. Add HookRecordId type alias for better type safety in batch outcome handling. feat: add OOM status to HookStatus enum with oom_killed property Add HookStatus.OOM variant to distinguish OOM failures from general failures. Include oom_killed property on HookResult for convenient OOM detection. feat: implement HookService with OOM retry and checkpointing Add comprehensive hook execution service that handles OOM conditions by retrying with doubled memory limits up to MAX_OOM_RETRIES times. Implements checkpointing for crash recovery and sorts records by size to maximize progress before potential OOM on large files. refactor: consolidate ingester record parsing into model class Move duplicate JSONL parsing logic from handlers into IngesterRecord.from_jsonl class method. Add IngesterFileRef model and total_file_mb property for better file size tracking. feat: add memory parsing and doubling utilities to hook model Add parse_memory function and with_doubled_memory method to HookDefinition for dynamic memory limit adjustment during OOM retry. 
refactor: replace StringConstraints with BeforeValidator in config Update logging level validation to use BeforeValidator instead of deprecated StringConstraints for Pydantic v2 compatibility. --- server/osa/config.py | 8 +- server/osa/domain/feature/port/storage.py | 4 +- .../domain/ingest/handler/publish_batch.py | 41 +- server/osa/domain/ingest/handler/run_hooks.py | 76 +-- .../domain/ingest/model/ingester_record.py | 45 +- server/osa/domain/shared/model/hook.py | 52 ++ .../domain/validation/model/batch_outcome.py | 17 +- .../osa/domain/validation/model/hook_input.py | 1 + .../domain/validation/model/hook_result.py | 6 + server/osa/domain/validation/port/storage.py | 24 + server/osa/domain/validation/service/hook.py | 239 ++++++++ .../domain/validation/service/validation.py | 13 +- .../osa/domain/validation/util/di/provider.py | 2 + server/osa/infrastructure/k8s/runner.py | 2 +- server/osa/infrastructure/oci/runner.py | 2 +- .../persistence/adapter/storage.py | 130 +++- server/osa/infrastructure/runner_utils.py | 24 +- server/osa/infrastructure/s3/storage.py | 29 +- .../domain/ingest/test_ingester_record.py | 82 +++ .../unit/domain/shared/test_hook_models.py | 54 ++ .../domain/validation/test_hook_result.py | 38 ++ .../domain/validation/test_hook_service.py | 567 ++++++++++++++++++ .../validation/test_validation_service.py | 53 ++ .../k8s/test_k8s_hook_runner.py | 2 +- .../infrastructure/test_oci_hook_runner.py | 4 +- 25 files changed, 1360 insertions(+), 155 deletions(-) create mode 100644 server/osa/domain/validation/service/hook.py create mode 100644 server/tests/unit/domain/ingest/test_ingester_record.py create mode 100644 server/tests/unit/domain/validation/test_hook_service.py diff --git a/server/osa/config.py b/server/osa/config.py index d9b9f9c..ffeda85 100644 --- a/server/osa/config.py +++ b/server/osa/config.py @@ -7,7 +7,7 @@ from typing import Any, Literal, Annotated import yaml -from pydantic import BaseModel, field_validator, model_validator, 
StringConstraints +from pydantic import BaseModel, BeforeValidator, field_validator, model_validator from pydantic_settings import BaseSettings, PydanticBaseSettingsSource from typing_extensions import Self @@ -64,9 +64,9 @@ class DatabaseConfig(BaseModel): class LoggingConfig(BaseModel): """Logging configuration (nested in Config, uses env_nested_delimiter).""" - level: Annotated[LevelName, StringConstraints(to_lower=True)] = ( - "debug" # Root log level (DEBUG for development) - ) + level: Annotated[ + LevelName, BeforeValidator(lambda v: v.lower() if isinstance(v, str) else v) + ] = "debug" # Root log level (DEBUG for development) format: str = "%(asctime)s %(levelname)-8s [%(name)s] %(message)s" date_format: str = "%Y-%m-%d %H:%M:%S" diff --git a/server/osa/domain/feature/port/storage.py b/server/osa/domain/feature/port/storage.py index 4856c17..64caa81 100644 --- a/server/osa/domain/feature/port/storage.py +++ b/server/osa/domain/feature/port/storage.py @@ -4,7 +4,7 @@ from typing import Any, Protocol from osa.domain.shared.port import Port -from osa.domain.validation.model.batch_outcome import BatchRecordOutcome +from osa.domain.validation.model.batch_outcome import BatchRecordOutcome, HookRecordId class FeatureStoragePort(Port, Protocol): @@ -34,7 +34,7 @@ async def hook_features_exist(self, hook_output_dir: str, feature_name: str) -> @abstractmethod async def read_batch_outcomes( self, output_dir: str, hook_name: str - ) -> dict[str, BatchRecordOutcome]: + ) -> dict[HookRecordId, BatchRecordOutcome]: """Read JSONL batch outputs (features/rejections/errors) for a hook. 
Parses features.jsonl, rejections.jsonl, and errors.jsonl from the diff --git a/server/osa/domain/ingest/handler/publish_batch.py b/server/osa/domain/ingest/handler/publish_batch.py index 4dbd500..a363f70 100644 --- a/server/osa/domain/ingest/handler/publish_batch.py +++ b/server/osa/domain/ingest/handler/publish_batch.py @@ -50,7 +50,7 @@ async def handle(self, event: HookBatchCompleted) -> None: ingester_dir = self.layout.ingest_batch_ingester_dir( event.ingest_run_srn, event.batch_index ) - ingester_records = _read_ingester_records(ingester_dir) + ingester_records = IngesterRecord.from_jsonl(ingester_dir / "records.jsonl") # Read hook outcomes for all hooks expected_features = [h.name for h in convention.hooks] @@ -68,9 +68,11 @@ async def handle(self, event: HookBatchCompleted) -> None: total = len(ingester_records) for hook_name in expected_features: outcomes = await self.feature_storage.read_batch_outcomes(str(batch_dir), hook_name) - passed = sum(1 for o in outcomes.values() if o.status == "passed") - rejected = sum(1 for o in outcomes.values() if o.status == "rejected") - errored = sum(1 for o in outcomes.values() if o.status == "errored") + from osa.domain.validation.model.batch_outcome import OutcomeStatus + + passed = sum(1 for o in outcomes.values() if o.status == OutcomeStatus.PASSED) + rejected = sum(1 for o in outcomes.values() if o.status == OutcomeStatus.REJECTED) + errored = sum(1 for o in outcomes.values() if o.status == OutcomeStatus.ERRORED) missing = total - len(outcomes) short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] log.info( @@ -174,33 +176,6 @@ async def handle(self, event: HookBatchCompleted) -> None: ) -def _read_ingester_records(ingester_dir) -> list[IngesterRecord]: - """Read ingester records from JSONL file into typed objects.""" - import json - from pathlib import Path - - records_file = Path(ingester_dir) / "records.jsonl" - records: list[IngesterRecord] = [] - if not records_file.exists(): - return records - for line in 
records_file.open(): - line = line.strip() - if not line: - continue - try: - data = json.loads(line) - records.append( - IngesterRecord( - source_id=data.get("source_id", data.get("id", "")), - metadata=data.get("metadata", {}), - file_paths=data.get("file_paths", []), - ) - ) - except (json.JSONDecodeError, ValueError): - log.warn("Skipping malformed ingester record line") - return records - - async def _get_passed_records( ingester_records: list[IngesterRecord], batch_dir: str, @@ -217,7 +192,9 @@ async def _get_passed_records( outcomes = await feature_storage.read_batch_outcomes(batch_dir, hook_name) if not outcomes: return [] - hook_passed = {rid for rid, o in outcomes.items() if o.status == "passed"} + from osa.domain.validation.model.batch_outcome import OutcomeStatus + + hook_passed = {rid for rid, o in outcomes.items() if o.status == OutcomeStatus.PASSED} if passed_ids is None: passed_ids = hook_passed diff --git a/server/osa/domain/ingest/handler/run_hooks.py b/server/osa/domain/ingest/handler/run_hooks.py index 69979a7..4681962 100644 --- a/server/osa/domain/ingest/handler/run_hooks.py +++ b/server/osa/domain/ingest/handler/run_hooks.py @@ -1,6 +1,5 @@ """RunHooks — runs hook containers on an ingester batch.""" -import json from pathlib import Path from uuid import uuid4 @@ -13,7 +12,8 @@ from osa.domain.shared.model.srn import ConventionSRN from osa.domain.shared.outbox import Outbox from osa.domain.validation.model.hook_input import HookRecord -from osa.domain.validation.port.hook_runner import HookInputs, HookRunner +from osa.domain.validation.port.hook_runner import HookInputs +from osa.domain.validation.service.hook import HookService from osa.infrastructure.logging import get_logger from osa.infrastructure.storage.layout import StorageLayout @@ -27,7 +27,7 @@ class RunHooks(EventHandler[IngesterBatchReady]): ingest_repo: IngestRunRepository convention_service: ConventionService - hook_runner: HookRunner + hook_service: HookService outbox: Outbox 
layout: StorageLayout @@ -44,9 +44,7 @@ async def handle(self, event: IngesterBatchReady) -> None: ingester_dir = self.layout.ingest_batch_ingester_dir( event.ingest_run_srn, event.batch_index ) - records_file = ingester_dir / "records.jsonl" - - records = _read_ingester_records(records_file) + records = IngesterRecord.from_jsonl(ingester_dir / "records.jsonl") if not records: log.warn( @@ -64,28 +62,43 @@ async def handle(self, event: IngesterBatchReady) -> None: if record_files.exists(): files_dirs[record.source_id] = record_files - # Run each hook sequentially + # Convert to HookInputs with size hints and file dirs + inputs = HookInputs( + records=[ + HookRecord( + id=r.source_id, + metadata=r.metadata, + size_hint_mb=r.total_file_mb, + ) + for r in records + ], + run_id=f"{event.ingest_run_srn}_batch{event.batch_index}", + files_dirs=files_dirs, + ) + + # Build work_dirs for each hook + work_dirs: dict[str, Path] = {} for hook in convention.hooks: - hook_output_dir = self.layout.ingest_batch_hook_dir( + hook_dir = self.layout.ingest_batch_hook_dir( event.ingest_run_srn, event.batch_index, hook.name ) - hook_output_dir.mkdir(parents=True, exist_ok=True) - - inputs = HookInputs( - records=[HookRecord(id=r.source_id, metadata=r.metadata) for r in records], - run_id=f"{event.ingest_run_srn}_batch{event.batch_index}", - files_dirs=files_dirs, - config=None, - ) - - result = await self.hook_runner.run(hook, inputs, hook_output_dir) + hook_dir.mkdir(parents=True, exist_ok=True) + work_dirs[hook.name] = hook_dir + + # Run all hooks via HookService + results = await self.hook_service.run_hooks_for_batch( + hooks=convention.hooks, + inputs=inputs, + work_dirs=work_dirs, + ) - short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + for result in results: log.info( "[{short_id}] batch {batch_index} hook={hook_name}: {status} in {duration:.1f}s", short_id=short_id, batch_index=event.batch_index, - 
hook_name=hook.name, + hook_name=result.hook_name, status=result.status.value, duration=result.duration_seconds, ingest_run_srn=event.ingest_run_srn, @@ -99,26 +112,3 @@ async def handle(self, event: IngesterBatchReady) -> None: batch_index=event.batch_index, ) ) - - -def _read_ingester_records(records_file: Path) -> list[IngesterRecord]: - """Read ingester records from JSONL file into typed objects.""" - records: list[IngesterRecord] = [] - if not records_file.exists(): - return records - for line in records_file.open(): - line = line.strip() - if not line: - continue - try: - data = json.loads(line) - records.append( - IngesterRecord( - source_id=data.get("source_id", data.get("id", "")), - metadata=data.get("metadata", {}), - file_paths=data.get("file_paths", []), - ) - ) - except (json.JSONDecodeError, ValueError): - log.warn("Skipping malformed ingester record line") - return records diff --git a/server/osa/domain/ingest/model/ingester_record.py b/server/osa/domain/ingest/model/ingester_record.py index 110d976..ebb7512 100644 --- a/server/osa/domain/ingest/model/ingester_record.py +++ b/server/osa/domain/ingest/model/ingester_record.py @@ -1,9 +1,22 @@ """IngesterRecord — typed representation of a record from an ingester container.""" +import json +import logging +from pathlib import Path from typing import Any from osa.domain.shared.model.value import ValueObject +logger = logging.getLogger(__name__) + + +class IngesterFileRef(ValueObject): + """A reference to a file produced by an ingester container.""" + + name: str + relative_path: str + size_mb: float + class IngesterRecord(ValueObject): """A record produced by an ingester container, parsed from records.jsonl. 
@@ -14,4 +27,34 @@ class IngesterRecord(ValueObject): source_id: str metadata: dict[str, Any] - file_paths: list[str] = [] + files: list[IngesterFileRef] = [] + + @property + def total_file_mb(self) -> float: + """Sum of all file sizes in megabytes.""" + return sum(f.size_mb for f in self.files) + + @classmethod + def from_jsonl(cls, path: Path) -> list["IngesterRecord"]: + """Parse records.jsonl into typed IngesterRecord objects.""" + records: list[IngesterRecord] = [] + if not path.exists(): + return records + for line in path.open(): + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + files_raw = data.get("files", []) + files = [IngesterFileRef.model_validate(f) for f in files_raw] + records.append( + IngesterRecord( + source_id=data.get("source_id", data.get("id", "")), + metadata=data.get("metadata", {}), + files=files, + ) + ) + except (json.JSONDecodeError, ValueError): + logger.warning("Skipping malformed ingester record line") + return records diff --git a/server/osa/domain/shared/model/hook.py b/server/osa/domain/shared/model/hook.py index 1a90db7..346ba34 100644 --- a/server/osa/domain/shared/model/hook.py +++ b/server/osa/domain/shared/model/hook.py @@ -5,6 +5,7 @@ (NextflowConfig, TimeSeriesFeatureSpec, …) slot in without touching existing code. """ +import re from typing import Annotated, Any, Literal from pydantic import Field @@ -15,6 +16,45 @@ # Safe for use as PG identifiers, file path components, and env var values. 
PgIdentifier = Annotated[str, Field(pattern=r"^[a-z][a-z0-9_]{0,62}$")] +_MEMORY_RE = re.compile(r"^(\d+(?:\.\d+)?)(g|m|k)?i?$") + +_GIB = 1024 * 1024 * 1024 +_MIB = 1024 * 1024 +_KIB = 1024 + + +def parse_memory(memory: str) -> int: + """Parse memory string like '2g' or '512m' to bytes.""" + match = _MEMORY_RE.match(memory.strip().lower()) + if not match: + raise ValueError(f"Invalid memory format: {memory}") + + amount = float(match.group(1)) + unit = match.group(2) + + match unit: + case "g": + return int(amount * _GIB) + case "m": + return int(amount * _MIB) + case "k": + return int(amount * _KIB) + case None: + return int(amount) + case _: + raise ValueError(f"Unknown memory unit: {unit}") + + +def _format_memory(byte_count: int) -> str: + """Format bytes to a compact memory string (e.g. '2g', '1536m').""" + if byte_count % _GIB == 0: + return f"{byte_count // _GIB}g" + if byte_count % _MIB == 0: + return f"{byte_count // _MIB}m" + if byte_count % _KIB == 0: + return f"{byte_count // _KIB}k" + return str(byte_count) + class ColumnDef(ValueObject): """Definition of a single column in a feature table.""" @@ -78,3 +118,15 @@ class HookDefinition(ValueObject): name: PgIdentifier runtime: Annotated[OciConfig, Field(discriminator="type")] feature: Annotated[TableFeatureSpec, Field(discriminator="kind")] + + def with_memory(self, memory: str) -> "HookDefinition": + """Return a copy with a different memory limit.""" + new_limits = self.runtime.limits.model_copy(update={"memory": memory}) + new_runtime = self.runtime.model_copy(update={"limits": new_limits}) + return self.model_copy(update={"runtime": new_runtime}) + + def with_doubled_memory(self) -> "HookDefinition": + """Return a copy with 2x the current memory limit.""" + current_bytes = parse_memory(self.runtime.limits.memory) + doubled = _format_memory(current_bytes * 2) + return self.with_memory(doubled) diff --git a/server/osa/domain/validation/model/batch_outcome.py 
b/server/osa/domain/validation/model/batch_outcome.py index 4206a02..93f8e0d 100644 --- a/server/osa/domain/validation/model/batch_outcome.py +++ b/server/osa/domain/validation/model/batch_outcome.py @@ -1,9 +1,20 @@ """Per-record outcome from a batch hook run.""" -from typing import Any +from enum import StrEnum +from typing import Any, NewType from osa.domain.shared.model.value import ValueObject +HookRecordId = NewType("HookRecordId", str) + + +class OutcomeStatus(StrEnum): + """Outcome status for a single record in a batch hook execution.""" + + PASSED = "passed" + REJECTED = "rejected" + ERRORED = "errored" + class BatchRecordOutcome(ValueObject): """Per-record outcome from a batch hook execution. @@ -12,8 +23,8 @@ class BatchRecordOutcome(ValueObject): passed (with features), rejected (with reason), or errored. """ - record_id: str - status: str # "passed", "rejected", "errored" + record_id: HookRecordId + status: OutcomeStatus features: list[dict[str, Any]] = [] reason: str | None = None error: str | None = None diff --git a/server/osa/domain/validation/model/hook_input.py b/server/osa/domain/validation/model/hook_input.py index 8c07075..085c74c 100644 --- a/server/osa/domain/validation/model/hook_input.py +++ b/server/osa/domain/validation/model/hook_input.py @@ -13,3 +13,4 @@ class HookRecord(ValueObject): id: str metadata: dict[str, Any] + size_hint_mb: float = 0 diff --git a/server/osa/domain/validation/model/hook_result.py b/server/osa/domain/validation/model/hook_result.py index bd71fcd..ba4b959 100644 --- a/server/osa/domain/validation/model/hook_result.py +++ b/server/osa/domain/validation/model/hook_result.py @@ -11,6 +11,7 @@ class HookStatus(StrEnum): PASSED = "passed" REJECTED = "rejected" FAILED = "failed" + OOM = "oom" class ProgressEntry(ValueObject): @@ -30,3 +31,8 @@ class HookResult(ValueObject): error_message: str | None = None progress: list[ProgressEntry] = Field(default_factory=list) duration_seconds: float + + @property + def 
oom_killed(self) -> bool: + """Whether this hook was killed by an out-of-memory condition.""" + return self.status == HookStatus.OOM diff --git a/server/osa/domain/validation/port/storage.py b/server/osa/domain/validation/port/storage.py index 353578b..eb6d34e 100644 --- a/server/osa/domain/validation/port/storage.py +++ b/server/osa/domain/validation/port/storage.py @@ -6,6 +6,7 @@ from osa.domain.shared.model.srn import DepositionSRN from osa.domain.shared.port import Port +from osa.domain.validation.model.batch_outcome import BatchRecordOutcome, HookRecordId class HookStoragePort(Port, Protocol): @@ -20,3 +21,26 @@ def get_hook_output_dir(self, deposition_srn: DepositionSRN, hook_name: str) -> def get_files_dir(self, deposition_id: DepositionSRN) -> Path: """Return the directory containing data files for a deposition.""" ... + + @abstractmethod + def write_checkpoint( + self, work_dir: Path, outcomes: dict[HookRecordId, BatchRecordOutcome] + ) -> None: + """Atomically write checkpoint JSONL to work_dir/_checkpoint.jsonl.""" + ... + + @abstractmethod + def write_batch_outcomes( + self, + work_dir: Path, + outcomes: dict[HookRecordId, BatchRecordOutcome], + ) -> None: + """Write canonical features.jsonl, rejections.jsonl, errors.jsonl.""" + ... + + @abstractmethod + async def read_batch_outcomes( + self, output_dir: str, hook_name: str + ) -> dict[HookRecordId, BatchRecordOutcome]: + """Read JSONL batch outputs (features/rejections/errors) for a hook.""" + ... diff --git a/server/osa/domain/validation/service/hook.py b/server/osa/domain/validation/service/hook.py new file mode 100644 index 0000000..55d2749 --- /dev/null +++ b/server/osa/domain/validation/service/hook.py @@ -0,0 +1,239 @@ +"""HookService — executes hooks with OOM retry and checkpointing. + +Handles both single-record (deposition) and multi-record (ingestion) batches. +On OOM, retries with doubled memory up to MAX_OOM_RETRIES times. 
+Checkpoints partial progress so crash recovery skips completed records. + +Sorting assumption: hooks process records in input order and write output +incrementally (features.jsonl line by line). Sorting by file size ascending +maximizes checkpoint progress before a potential OOM on a large record. +""" + +import json +from collections.abc import Iterable +from pathlib import Path + +from osa.domain.shared.model.hook import HookDefinition +from osa.domain.shared.service import Service +from osa.domain.validation.model.batch_outcome import ( + BatchRecordOutcome, + HookRecordId, + OutcomeStatus, +) +from osa.domain.validation.model.hook_input import HookRecord +from osa.domain.validation.model.hook_result import HookResult, HookStatus +from osa.domain.validation.port.hook_runner import HookInputs, HookRunner +from osa.domain.validation.port.storage import HookStoragePort +from osa.infrastructure.logging import get_logger + +log = get_logger(__name__) + +MAX_OOM_RETRIES = 3 + + +class HookService(Service): + """Executes a hook with OOM retry, checkpointing, and finalization.""" + + hook_runner: HookRunner + hook_storage: HookStoragePort + + async def run_hook( + self, + hook: HookDefinition, + inputs: HookInputs, + work_dir: Path, + ) -> HookResult: + """Run a single hook against a batch of records, retrying on OOM. + + Returns the final HookResult. On success or non-OOM failure, returns + after the first attempt. On OOM, retries with doubled memory up to + MAX_OOM_RETRIES times, checkpointing partial progress between attempts. 
+ """ + records = inputs.records + if not records: + return HookResult( + hook_name=hook.name, + status=HookStatus.PASSED, + duration_seconds=0.0, + ) + + # Load checkpoint (crash recovery) + outcomes = _load_checkpoint(work_dir) + remaining = _sort_by_size(r for r in records if r.id not in outcomes) + + if not remaining: + # All records already checkpointed + self.hook_storage.write_batch_outcomes(work_dir, outcomes) + return HookResult( + hook_name=hook.name, + status=HookStatus.PASSED, + duration_seconds=0.0, + ) + + current_hook = hook + total_duration = 0.0 + + for attempt in range(1 + MAX_OOM_RETRIES): + attempt_inputs = HookInputs( + records=remaining, + run_id=inputs.run_id, + files_dirs=inputs.files_dirs, + config=inputs.config, + ) + + result = await self.hook_runner.run(current_hook, attempt_inputs, work_dir) + total_duration += result.duration_seconds + + # Read any output written by this attempt + new_outcomes = _read_output_dir(work_dir) + for rid, outcome in new_outcomes.items(): + if rid not in outcomes: + outcomes[rid] = outcome + + if result.oom_killed: + # Checkpoint what we have so far + self.hook_storage.write_checkpoint(work_dir, outcomes) + + remaining = _sort_by_size(r for r in records if r.id not in outcomes) + if not remaining: + break + + if attempt < MAX_OOM_RETRIES: + current_hook = current_hook.with_doubled_memory() + log.info( + "OOM retry {attempt}/{max_retries} for hook={hook_name}, memory={memory}, remaining={remaining} records", + attempt=attempt + 1, + max_retries=MAX_OOM_RETRIES, + hook_name=hook.name, + memory=current_hook.runtime.limits.memory, + remaining=len(remaining), + ) + continue + else: + # Exhausted retries — mark remaining as errored + for r in remaining: + outcomes[HookRecordId(r.id)] = BatchRecordOutcome( + record_id=HookRecordId(r.id), + status=OutcomeStatus.ERRORED, + error=f"OOM after {MAX_OOM_RETRIES} retries (last limit: {current_hook.runtime.limits.memory})", + ) + 
self.hook_storage.write_batch_outcomes(work_dir, outcomes) + return HookResult( + hook_name=hook.name, + status=HookStatus.OOM, + error_message=f"OOM exhausted after {MAX_OOM_RETRIES} retries", + duration_seconds=total_duration, + ) + elif result.status == HookStatus.FAILED: + # Non-OOM failure — no retry + self.hook_storage.write_batch_outcomes(work_dir, outcomes) + return result + elif result.status == HookStatus.REJECTED: + # Rejection — no retry, propagate status + self.hook_storage.write_batch_outcomes(work_dir, outcomes) + return HookResult( + hook_name=hook.name, + status=HookStatus.REJECTED, + rejection_reason=result.rejection_reason, + duration_seconds=total_duration, + ) + else: + # Success (PASSED) + break + + # Finalize: write canonical output files + self.hook_storage.write_batch_outcomes(work_dir, outcomes) + _cleanup_checkpoint(work_dir) + + return HookResult( + hook_name=hook.name, + status=HookStatus.PASSED, + duration_seconds=total_duration, + ) + + async def run_hooks_for_batch( + self, + hooks: list[HookDefinition], + inputs: HookInputs, + work_dirs: dict[str, Path], + ) -> list[HookResult]: + """Run multiple hooks sequentially for a batch of records. + + work_dirs maps hook_name → output directory. + """ + results: list[HookResult] = [] + for hook in hooks: + work_dir = work_dirs[hook.name] + result = await self.run_hook(hook, inputs, work_dir) + results.append(result) + return results + + +def _sort_by_size(records: Iterable[HookRecord]) -> list[HookRecord]: + """Sort records by size_hint_mb ascending. Skip sort when all sizes are 0.""" + record_list = list(records) + if any(r.size_hint_mb > 0 for r in record_list): + return sorted(record_list, key=lambda r: r.size_hint_mb) + return record_list + + +def _load_checkpoint(work_dir: Path) -> dict[HookRecordId, BatchRecordOutcome]: + """Load checkpoint from _checkpoint.jsonl. 
Returns empty dict on missing/corrupt.""" + checkpoint_path = work_dir / "_checkpoint.jsonl" + if not checkpoint_path.exists(): + return {} + + outcomes: dict[HookRecordId, BatchRecordOutcome] = {} + for line in checkpoint_path.open(): + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + outcome = BatchRecordOutcome.model_validate(data) + outcomes[outcome.record_id] = outcome + except (json.JSONDecodeError, ValueError): + log.warn("Skipping malformed checkpoint line") + continue + + return outcomes + + +def _read_output_dir(work_dir: Path) -> dict[HookRecordId, BatchRecordOutcome]: + """Read hook output files (features.jsonl, rejections.jsonl, errors.jsonl).""" + output_dir = work_dir / "output" + outcomes: dict[HookRecordId, BatchRecordOutcome] = {} + + for filename, status, field_map in [ + ("features.jsonl", OutcomeStatus.PASSED, {"features": "features"}), + ("rejections.jsonl", OutcomeStatus.REJECTED, {"reason": "reason"}), + ("errors.jsonl", OutcomeStatus.ERRORED, {"error": "error", "retryable": "retryable"}), + ]: + path = output_dir / filename + if not path.exists(): + continue + for line in path.open(): + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + except json.JSONDecodeError: + continue + raw_id = data.get("id") + if not raw_id: + continue + record_id = HookRecordId(raw_id) + kwargs: dict = {"record_id": record_id, "status": status} + for src, dst in field_map.items(): + if src in data: + kwargs[dst] = data[src] + outcomes[record_id] = BatchRecordOutcome(**kwargs) + + return outcomes + + +def _cleanup_checkpoint(work_dir: Path) -> None: + """Remove checkpoint file after successful finalization.""" + checkpoint_path = work_dir / "_checkpoint.jsonl" + checkpoint_path.unlink(missing_ok=True) diff --git a/server/osa/domain/validation/service/validation.py b/server/osa/domain/validation/service/validation.py index 3a37a9c..a2dcdee 100644 --- a/server/osa/domain/validation/service/validation.py 
+++ b/server/osa/domain/validation/service/validation.py @@ -23,6 +23,7 @@ from osa.domain.validation.port.hook_runner import HookInputs, HookRunner from osa.domain.validation.port.repository import ValidationRunRepository from osa.domain.validation.port.storage import HookStoragePort +from osa.domain.validation.service.hook import HookService logger = logging.getLogger(__name__) @@ -64,24 +65,30 @@ async def run_hooks( inputs: HookInputs, hooks: list[HookDefinition], ) -> tuple[ValidationRun, list[HookResult]]: - """Execute hooks sequentially. Halt on reject/fail. + """Execute hooks sequentially with OOM retry. Halt on reject/fail/OOM. Hook outputs are written to durable cold storage under the deposition directory. Feature insertion is deferred to record publication time. + Each hook is executed via HookService which handles OOM retry with memory doubling. """ run.status = RunStatus.RUNNING run.started_at = datetime.now(timezone.utc) await self.run_repo.save(run) + hook_service = HookService( + hook_runner=self.hook_runner, + hook_storage=self.hook_storage, + ) + hook_results: list[HookResult] = [] overall_status: RunStatus = RunStatus.COMPLETED for hook in hooks: work_dir = self.hook_storage.get_hook_output_dir(deposition_srn, hook.name) - result = await self.hook_runner.run(hook, inputs, work_dir) + result = await hook_service.run_hook(hook, inputs, work_dir) hook_results.append(result) - if result.status == HookStatus.FAILED: + if result.status in (HookStatus.FAILED, HookStatus.OOM): overall_status = RunStatus.FAILED break if result.status == HookStatus.REJECTED: diff --git a/server/osa/domain/validation/util/di/provider.py b/server/osa/domain/validation/util/di/provider.py index 1a9235f..707675d 100644 --- a/server/osa/domain/validation/util/di/provider.py +++ b/server/osa/domain/validation/util/di/provider.py @@ -3,12 +3,14 @@ from osa.config import Config from osa.domain.shared.model.srn import Domain from osa.domain.validation.service import 
ValidationService +from osa.domain.validation.service.hook import HookService from osa.util.di.base import Provider from osa.util.di.scope import Scope class ValidationProvider(Provider): service = provide(ValidationService, scope=Scope.UOW) + hook_service = provide(HookService, scope=Scope.UOW) @provide(scope=Scope.UOW) def get_node_domain(self, config: Config) -> Domain: diff --git a/server/osa/infrastructure/k8s/runner.py b/server/osa/infrastructure/k8s/runner.py index 1f0a040..6d5ca8a 100644 --- a/server/osa/infrastructure/k8s/runner.py +++ b/server/osa/infrastructure/k8s/runner.py @@ -477,7 +477,7 @@ async def _diagnose_failure( if getattr(terminated, "reason", None) == "OOMKilled": return HookResult( hook_name=hook.name, - status=HookStatus.FAILED, + status=HookStatus.OOM, error_message="Hook killed by OOM", duration_seconds=duration, ) diff --git a/server/osa/infrastructure/oci/runner.py b/server/osa/infrastructure/oci/runner.py index 362ceb9..05414d3 100644 --- a/server/osa/infrastructure/oci/runner.py +++ b/server/osa/infrastructure/oci/runner.py @@ -192,7 +192,7 @@ async def _run_container( for line in tail_text.splitlines(): print(f" OOM [{hook.name}] {line}", file=sys.stderr, flush=True) return { - "status": HookStatus.FAILED, + "status": HookStatus.OOM, "error_message": f"Hook killed by OOM (limit: {hook.runtime.limits.memory})", } diff --git a/server/osa/infrastructure/persistence/adapter/storage.py b/server/osa/infrastructure/persistence/adapter/storage.py index e45b869..f1e4dcd 100644 --- a/server/osa/infrastructure/persistence/adapter/storage.py +++ b/server/osa/infrastructure/persistence/adapter/storage.py @@ -1,6 +1,7 @@ import hashlib import json import logging +import os import shutil import tempfile from collections.abc import AsyncIterator @@ -12,7 +13,11 @@ from osa.domain.deposition.port.storage import FileStoragePort from osa.domain.shared.error import InfrastructureError from osa.domain.shared.model.srn import ConventionSRN, DepositionSRN 
-from osa.domain.validation.model.batch_outcome import BatchRecordOutcome +from osa.domain.validation.model.batch_outcome import ( + BatchRecordOutcome, + HookRecordId, + OutcomeStatus, +) logger = logging.getLogger(__name__) @@ -195,40 +200,99 @@ async def move_source_files_to_deposition( async def read_batch_outcomes( self, output_dir: str, hook_name: str - ) -> dict[str, BatchRecordOutcome]: + ) -> dict[HookRecordId, BatchRecordOutcome]: """Read JSONL batch outputs from the filesystem, streaming line-by-line.""" hook_output = Path(output_dir) / "hooks" / hook_name / "output" - outcomes: dict[str, BatchRecordOutcome] = {} + outcomes: dict[HookRecordId, BatchRecordOutcome] = {} - for filename, status_key, field_map in [ - ("features.jsonl", "passed", {"features": "features"}), - ("rejections.jsonl", "rejected", {"reason": "reason"}), - ("errors.jsonl", "errored", {"error": "error", "retryable": "retryable"}), - ]: - path = hook_output / filename - if not path.exists(): - continue - with path.open() as f: - for line in f: - line = line.strip() - if not line: - continue - try: - data = json.loads(line) - except json.JSONDecodeError: - logger.warning("Skipping malformed JSON line in %s", filename) - continue - record_id = data.get("id") - if not record_id: - logger.warning("Skipping JSONL line without 'id' in %s", filename) - continue - kwargs: dict[str, Any] = { - "record_id": record_id, - "status": status_key, - } - for src, dst in field_map.items(): - if src in data: - kwargs[dst] = data[src] - outcomes[record_id] = BatchRecordOutcome(**kwargs) + _parse_batch_output_files(hook_output, outcomes) return outcomes + + def write_checkpoint( + self, work_dir: Path, outcomes: dict[HookRecordId, BatchRecordOutcome] + ) -> None: + """Atomically write checkpoint JSONL via os.replace().""" + checkpoint_path = work_dir / "_checkpoint.jsonl" + tmp_path = work_dir / "_checkpoint.jsonl.tmp" + with tmp_path.open("w") as f: + for outcome in outcomes.values(): + 
f.write(outcome.model_dump_json() + "\n") + os.replace(tmp_path, checkpoint_path) + + def write_batch_outcomes( + self, + work_dir: Path, + outcomes: dict[HookRecordId, BatchRecordOutcome], + ) -> None: + """Write canonical features.jsonl, rejections.jsonl, errors.jsonl.""" + output_dir = work_dir / "output" + output_dir.mkdir(parents=True, exist_ok=True) + + features: list[str] = [] + rejections: list[str] = [] + errors: list[str] = [] + + for outcome in outcomes.values(): + row: dict[str, Any] = {"id": outcome.record_id} + if outcome.status == OutcomeStatus.PASSED: + row["features"] = outcome.features + features.append(json.dumps(row)) + elif outcome.status == OutcomeStatus.REJECTED: + row["reason"] = outcome.reason + rejections.append(json.dumps(row)) + elif outcome.status == OutcomeStatus.ERRORED: + row["error"] = outcome.error + row["retryable"] = outcome.retryable + errors.append(json.dumps(row)) + + for filename, lines in [ + ("features.jsonl", features), + ("rejections.jsonl", rejections), + ("errors.jsonl", errors), + ]: + if lines: + (output_dir / filename).write_text("\n".join(lines) + "\n") + + +# ── Shared parsing ────────────────────────────────────────────────────── + + +_FILE_STATUS_MAP: list[tuple[str, OutcomeStatus, dict[str, str]]] = [ + ("features.jsonl", OutcomeStatus.PASSED, {"features": "features"}), + ("rejections.jsonl", OutcomeStatus.REJECTED, {"reason": "reason"}), + ("errors.jsonl", OutcomeStatus.ERRORED, {"error": "error", "retryable": "retryable"}), +] + + +def _parse_batch_output_files( + output_dir: Path, outcomes: dict[HookRecordId, BatchRecordOutcome] +) -> None: + """Parse features/rejections/errors JSONL files into BatchRecordOutcome dict.""" + for filename, status, field_map in _FILE_STATUS_MAP: + path = output_dir / filename + if not path.exists(): + continue + with path.open() as f: + for line in f: + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + except json.JSONDecodeError: + 
logger.warning("Skipping malformed JSON line in %s", filename) + continue + raw_id = data.get("id") + if not raw_id: + logger.warning("Skipping JSONL line without 'id' in %s", filename) + continue + record_id = HookRecordId(raw_id) + kwargs: dict[str, Any] = { + "record_id": record_id, + "status": status, + } + for src, dst in field_map.items(): + if src in data: + kwargs[dst] = data[src] + outcomes[record_id] = BatchRecordOutcome(**kwargs) diff --git a/server/osa/infrastructure/runner_utils.py b/server/osa/infrastructure/runner_utils.py index bf58525..b7340be 100644 --- a/server/osa/infrastructure/runner_utils.py +++ b/server/osa/infrastructure/runner_utils.py @@ -49,26 +49,14 @@ def detect_rejection(progress: list[ProgressEntry]) -> tuple[bool, str | None]: def parse_memory(memory: str) -> int: - """Parse memory string like '2g' or '512m' to bytes.""" - memory = memory.strip().lower() - match = re.match(r"^(\d+(?:\.\d+)?)(g|m|k)?i?$", memory) - if not match: - raise ValueError(f"Invalid memory format: {memory}") + """Parse memory string like '2g' or '512m' to bytes. - amount = float(match.group(1)) - unit = match.group(2) + .. deprecated:: + Use ``osa.domain.shared.model.hook.parse_memory`` instead. 
+ """ + from osa.domain.shared.model.hook import parse_memory as _parse_memory - match unit: - case "g": - return int(amount * 1024 * 1024 * 1024) - case "m": - return int(amount * 1024 * 1024) - case "k": - return int(amount * 1024) - case None: - return int(amount) - case _: - raise ValueError(f"Unknown memory unit: {unit}") + return _parse_memory(memory) _MEMORY_RE = re.compile(r"^(\d+(?:\.\d+)?)(g|m|k)?i?$") diff --git a/server/osa/infrastructure/s3/storage.py b/server/osa/infrastructure/s3/storage.py index 5b340fa..fc1ebd8 100644 --- a/server/osa/infrastructure/s3/storage.py +++ b/server/osa/infrastructure/s3/storage.py @@ -16,7 +16,11 @@ from osa.domain.deposition.port.storage import FileStoragePort from osa.domain.shared.error import InfrastructureError, NotFoundError from osa.domain.shared.model.srn import ConventionSRN, DepositionSRN -from osa.domain.validation.model.batch_outcome import BatchRecordOutcome +from osa.domain.validation.model.batch_outcome import ( + BatchRecordOutcome, + HookRecordId, + OutcomeStatus, +) from osa.infrastructure.runner_utils import relative_path from osa.infrastructure.s3.client import S3Client @@ -206,17 +210,19 @@ async def hook_features_exist(self, hook_output_dir: str, feature_name: str) -> async def read_batch_outcomes( self, output_dir: str, hook_name: str - ) -> dict[str, BatchRecordOutcome]: + ) -> dict[HookRecordId, BatchRecordOutcome]: """Read JSONL batch outputs from S3.""" prefix = relative_path(Path(output_dir), self._data_mount_path) hook_prefix = f"{prefix}/hooks/{hook_name}/output" - outcomes: dict[str, BatchRecordOutcome] = {} + outcomes: dict[HookRecordId, BatchRecordOutcome] = {} - for filename, status_key, field_map in [ - ("features.jsonl", "passed", {"features": "features"}), - ("rejections.jsonl", "rejected", {"reason": "reason"}), - ("errors.jsonl", "errored", {"error": "error", "retryable": "retryable"}), - ]: + file_status_map: list[tuple[str, OutcomeStatus, dict[str, str]]] = [ + ("features.jsonl", 
OutcomeStatus.PASSED, {"features": "features"}), + ("rejections.jsonl", OutcomeStatus.REJECTED, {"reason": "reason"}), + ("errors.jsonl", OutcomeStatus.ERRORED, {"error": "error", "retryable": "retryable"}), + ] + + for filename, status, field_map in file_status_map: key = f"{hook_prefix}/{filename}" try: data_bytes = await self._s3.get_object(key) @@ -232,13 +238,14 @@ async def read_batch_outcomes( except json.JSONDecodeError: logger.warning("Skipping malformed JSON line in %s", filename) continue - record_id = data.get("id") - if not record_id: + raw_id = data.get("id") + if not raw_id: logger.warning("Skipping JSONL line without 'id' in %s", filename) continue + record_id = HookRecordId(raw_id) kwargs: dict[str, Any] = { "record_id": record_id, - "status": status_key, + "status": status, } for src, dst in field_map.items(): if src in data: diff --git a/server/tests/unit/domain/ingest/test_ingester_record.py b/server/tests/unit/domain/ingest/test_ingester_record.py new file mode 100644 index 0000000..59ca361 --- /dev/null +++ b/server/tests/unit/domain/ingest/test_ingester_record.py @@ -0,0 +1,82 @@ +"""Tests for IngesterRecord model — from_jsonl parsing and IngesterFileRef.""" + +import json +from pathlib import Path + + +def test_from_jsonl_happy_path(tmp_path: Path): + from osa.domain.ingest.model.ingester_record import IngesterRecord + + records_file = tmp_path / "records.jsonl" + records_file.write_text( + json.dumps( + { + "source_id": "rec1", + "metadata": {"title": "Test"}, + "files": [{"name": "f.cif", "relative_path": "rec1/f.cif", "size_mb": 10.5}], + } + ) + + "\n" + + json.dumps({"source_id": "rec2", "metadata": {"title": "Test 2"}}) + + "\n" + ) + + records = IngesterRecord.from_jsonl(records_file) + assert len(records) == 2 + assert records[0].source_id == "rec1" + assert records[0].metadata == {"title": "Test"} + assert len(records[0].files) == 1 + assert records[0].files[0].name == "f.cif" + assert records[0].files[0].size_mb == 10.5 + assert 
records[1].source_id == "rec2" + assert records[1].files == [] + + +def test_from_jsonl_malformed_lines_skipped(tmp_path: Path): + from osa.domain.ingest.model.ingester_record import IngesterRecord + + records_file = tmp_path / "records.jsonl" + records_file.write_text( + "NOT VALID JSON\n" + json.dumps({"source_id": "good", "metadata": {}}) + "\n" + "{broken\n" + ) + + records = IngesterRecord.from_jsonl(records_file) + assert len(records) == 1 + assert records[0].source_id == "good" + + +def test_from_jsonl_empty_file(tmp_path: Path): + from osa.domain.ingest.model.ingester_record import IngesterRecord + + records_file = tmp_path / "records.jsonl" + records_file.write_text("") + records = IngesterRecord.from_jsonl(records_file) + assert records == [] + + +def test_from_jsonl_nonexistent_file(tmp_path: Path): + from osa.domain.ingest.model.ingester_record import IngesterRecord + + records = IngesterRecord.from_jsonl(tmp_path / "does_not_exist.jsonl") + assert records == [] + + +def test_total_file_mb_property(): + from osa.domain.ingest.model.ingester_record import IngesterFileRef, IngesterRecord + + record = IngesterRecord( + source_id="rec1", + metadata={}, + files=[ + IngesterFileRef(name="a.cif", relative_path="rec1/a.cif", size_mb=10.0), + IngesterFileRef(name="b.cif", relative_path="rec1/b.cif", size_mb=28.5), + ], + ) + assert record.total_file_mb == 38.5 + + +def test_total_file_mb_empty(): + from osa.domain.ingest.model.ingester_record import IngesterRecord + + record = IngesterRecord(source_id="rec1", metadata={}) + assert record.total_file_mb == 0 diff --git a/server/tests/unit/domain/shared/test_hook_models.py b/server/tests/unit/domain/shared/test_hook_models.py index cd44924..cf669b0 100644 --- a/server/tests/unit/domain/shared/test_hook_models.py +++ b/server/tests/unit/domain/shared/test_hook_models.py @@ -211,6 +211,60 @@ def test_hook_definition_serialization_roundtrip(): assert restored.feature.columns[1].required is False +class 
TestMemoryDoubling: + """Tests for HookDefinition.with_memory() and with_doubled_memory().""" + + def _make_hook(self, memory: str = "1g"): + from osa.domain.shared.model.hook import ( + HookDefinition, + OciConfig, + OciLimits, + TableFeatureSpec, + ) + + return HookDefinition( + name="detect_pockets", + runtime=OciConfig( + image="img:v1", + digest="sha256:abc", + limits=OciLimits(memory=memory), + ), + feature=TableFeatureSpec(cardinality="one", columns=[]), + ) + + def test_hook_definition_with_memory(self): + hook = self._make_hook("1g") + updated = hook.with_memory("2g") + assert updated.runtime.limits.memory == "2g" + # original unchanged (frozen) + assert hook.runtime.limits.memory == "1g" + + def test_hook_definition_with_doubled_memory_1g(self): + hook = self._make_hook("1g") + doubled = hook.with_doubled_memory() + assert doubled.runtime.limits.memory == "2g" + + def test_hook_definition_with_doubled_memory_512m(self): + hook = self._make_hook("512m") + doubled = hook.with_doubled_memory() + assert doubled.runtime.limits.memory == "1g" + + def test_hook_definition_with_doubled_memory_768m(self): + hook = self._make_hook("768m") + doubled = hook.with_doubled_memory() + assert doubled.runtime.limits.memory == "1536m" + + def test_hook_definition_with_doubled_memory_preserves_other_fields(self): + hook = self._make_hook("1g") + doubled = hook.with_doubled_memory() + assert doubled.name == hook.name + assert doubled.runtime.image == hook.runtime.image + assert doubled.runtime.digest == hook.runtime.digest + assert doubled.runtime.limits.timeout_seconds == hook.runtime.limits.timeout_seconds + assert doubled.runtime.limits.cpu == hook.runtime.limits.cpu + assert doubled.feature == hook.feature + + class TestNameValidation: """Hook and column names must be safe PG identifiers.""" diff --git a/server/tests/unit/domain/validation/test_hook_result.py b/server/tests/unit/domain/validation/test_hook_result.py index 8d2ba3d..8465ff5 100644 --- 
a/server/tests/unit/domain/validation/test_hook_result.py +++ b/server/tests/unit/domain/validation/test_hook_result.py @@ -105,6 +105,44 @@ def test_hook_result_default_progress_empty(): assert result.progress == [] +def test_hook_status_oom_value(): + from osa.domain.validation.model.hook_result import HookStatus + + assert HookStatus.OOM == "oom" + assert HookStatus.OOM.value == "oom" + + +def test_hook_result_oom_killed_true(): + from osa.domain.validation.model.hook_result import HookResult, HookStatus + + result = HookResult( + hook_name="detect_pockets", + status=HookStatus.OOM, + error_message="Hook killed by OOM (limit: 1g)", + duration_seconds=30.0, + ) + assert result.oom_killed is True + + +def test_hook_result_oom_killed_false(): + from osa.domain.validation.model.hook_result import HookResult, HookStatus + + result = HookResult( + hook_name="detect_pockets", + status=HookStatus.FAILED, + error_message="Some other error", + duration_seconds=10.0, + ) + assert result.oom_killed is False + + passed = HookResult( + hook_name="detect_pockets", + status=HookStatus.PASSED, + duration_seconds=5.0, + ) + assert passed.oom_killed is False + + def test_hook_result_serialization_roundtrip(): from osa.domain.validation.model.hook_result import ( HookResult, diff --git a/server/tests/unit/domain/validation/test_hook_service.py b/server/tests/unit/domain/validation/test_hook_service.py new file mode 100644 index 0000000..db45e46 --- /dev/null +++ b/server/tests/unit/domain/validation/test_hook_service.py @@ -0,0 +1,567 @@ +"""Tests for HookService — OOM retry with checkpointing.""" + +from pathlib import Path +from typing import Any +from unittest.mock import AsyncMock + +import pytest + +from osa.domain.shared.model.hook import ( + ColumnDef, + HookDefinition, + OciConfig, + OciLimits, + TableFeatureSpec, +) +from osa.domain.validation.model.batch_outcome import ( + BatchRecordOutcome, + HookRecordId, + OutcomeStatus, +) +from osa.domain.validation.model.hook_input 
import HookRecord +from osa.domain.validation.model.hook_result import HookResult, HookStatus +from osa.domain.validation.port.hook_runner import HookInputs + + +def _make_hook(name: str = "detect_pockets", memory: str = "1g") -> HookDefinition: + return HookDefinition( + name=name, + runtime=OciConfig( + image="img:v1", + digest="sha256:abc", + limits=OciLimits(memory=memory), + ), + feature=TableFeatureSpec( + cardinality="one", + columns=[ColumnDef(name="score", json_type="number", required=True)], + ), + ) + + +def _inputs(records: list[HookRecord]) -> HookInputs: + return HookInputs(records=records, run_id="test-run") + + +def _make_records(count: int = 3) -> list[HookRecord]: + return [HookRecord(id=f"rec{i}", metadata={"title": f"Record {i}"}) for i in range(count)] + + +def _passed_result(hook_name: str = "detect_pockets", duration: float = 5.0) -> HookResult: + return HookResult(hook_name=hook_name, status=HookStatus.PASSED, duration_seconds=duration) + + +def _oom_result(hook_name: str = "detect_pockets", duration: float = 30.0) -> HookResult: + return HookResult( + hook_name=hook_name, + status=HookStatus.OOM, + error_message="Hook killed by OOM", + duration_seconds=duration, + ) + + +def _failed_result(hook_name: str = "detect_pockets", duration: float = 10.0) -> HookResult: + return HookResult( + hook_name=hook_name, + status=HookStatus.FAILED, + error_message="Some error", + duration_seconds=duration, + ) + + +class FakeHookStorage: + """Fake HookStoragePort for testing — stores checkpoints and outcomes in memory.""" + + def __init__(self) -> None: + self.checkpoints: dict[str, dict[HookRecordId, BatchRecordOutcome]] = {} + self.written_outcomes: dict[str, dict[HookRecordId, BatchRecordOutcome]] = {} + self._batch_outcomes: dict[str, dict[HookRecordId, BatchRecordOutcome]] = {} + + def get_hook_output_dir(self, deposition_srn: Any, hook_name: str) -> Path: + return Path(f"/fake/hooks/{hook_name}") + + def get_files_dir(self, deposition_id: Any) -> 
Path: + return Path("/fake/files") + + def write_checkpoint( + self, work_dir: Path, outcomes: dict[HookRecordId, BatchRecordOutcome] + ) -> None: + self.checkpoints[str(work_dir)] = dict(outcomes) + + def write_batch_outcomes( + self, work_dir: Path, outcomes: dict[HookRecordId, BatchRecordOutcome] + ) -> None: + self.written_outcomes[str(work_dir)] = dict(outcomes) + + async def read_batch_outcomes( + self, output_dir: str, hook_name: str + ) -> dict[HookRecordId, BatchRecordOutcome]: + key = f"{output_dir}/{hook_name}" + return self._batch_outcomes.get(key, {}) + + def read_checkpoint(self, work_dir: Path) -> dict[HookRecordId, BatchRecordOutcome]: + return self.checkpoints.get(str(work_dir), {}) + + +class TestHookServiceNoOOM: + """T015: No OOM — hook runs once, correct output.""" + + @pytest.mark.asyncio + async def test_no_oom_runs_once(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + records = _make_records(2) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + + runner = AsyncMock() + runner.run.return_value = _passed_result() + storage = FakeHookStorage() + + # Simulate runner writing outcomes (features.jsonl) + # After run, HookService reads output dir for outcomes + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + import json + + features_file = output_dir / "features.jsonl" + features_file.write_text( + "\n".join(json.dumps({"id": r.id, "features": [{"score": 0.9}]}) for r in records) + + "\n" + ) + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs(records), work_dir) + + assert result.status == HookStatus.PASSED + runner.run.assert_called_once() + + +class TestHookServiceOOMRetry: + """T016: OOM retry doubles memory.""" + + @pytest.mark.asyncio + async def test_oom_retry_doubles_memory(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook(memory="1g") + 
records = _make_records(2) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + + import json + + call_count = 0 + + async def mock_run(h, inputs, wd): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call: write partial output then OOM + features_file = output_dir / "features.jsonl" + features_file.write_text( + json.dumps({"id": records[0].id, "features": [{"score": 0.5}]}) + "\n" + ) + return _oom_result() + else: + # Second call: succeed with remaining + features_file = output_dir / "features.jsonl" + # Append the second record + with features_file.open("a") as f: + f.write(json.dumps({"id": records[1].id, "features": [{"score": 0.8}]}) + "\n") + return _passed_result() + + runner = AsyncMock() + runner.run.side_effect = mock_run + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs(records), work_dir) + + assert result.status == HookStatus.PASSED + assert runner.run.call_count == 2 + # Second call should have doubled memory + second_call_hook = runner.run.call_args_list[1][0][0] + assert second_call_hook.runtime.limits.memory == "2g" + + +class TestHookServiceOOMExhaustion: + """T017: OOM exhaustion marks remaining records as errored.""" + + @pytest.mark.asyncio + async def test_oom_exhaustion_marks_errored(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook(memory="1g") + records = _make_records(1) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + + runner = AsyncMock() + runner.run.return_value = _oom_result() + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs(records), work_dir) + + assert result.status == HookStatus.OOM + # Should have retried 
MAX_OOM_RETRIES times + assert runner.run.call_count == 4 # 1 initial + 3 retries + + # Check outcomes written with error + assert str(work_dir) in storage.written_outcomes + outcomes = storage.written_outcomes[str(work_dir)] + assert HookRecordId("rec0") in outcomes + assert outcomes[HookRecordId("rec0")].status == OutcomeStatus.ERRORED + assert "OOM" in (outcomes[HookRecordId("rec0")].error or "") + + +class TestHookServiceNonOOMFailure: + """T018: Non-OOM failure does NOT trigger retry.""" + + @pytest.mark.asyncio + async def test_non_oom_failure_no_retry(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + records = _make_records(1) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + (work_dir / "output").mkdir(parents=True) + + runner = AsyncMock() + runner.run.return_value = _failed_result() + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs(records), work_dir) + + assert result.status == HookStatus.FAILED + runner.run.assert_called_once() + + +class TestHookServiceFinalize: + """T019: Finalize writes canonical files.""" + + @pytest.mark.asyncio + async def test_finalize_writes_canonical_files(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + records = _make_records(2) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + + import json + + features_file = output_dir / "features.jsonl" + features_file.write_text( + "\n".join(json.dumps({"id": r.id, "features": [{"score": 0.9}]}) for r in records) + + "\n" + ) + + runner = AsyncMock() + runner.run.return_value = _passed_result() + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + await service.run_hook(hook, _inputs(records), work_dir) + + assert str(work_dir) in 
storage.written_outcomes + outcomes = storage.written_outcomes[str(work_dir)] + assert len(outcomes) == 2 + for r in records: + assert HookRecordId(r.id) in outcomes + assert outcomes[HookRecordId(r.id)].status == OutcomeStatus.PASSED + + +class TestHookServiceEmptyRecords: + """T020: Empty records list — no container launched.""" + + @pytest.mark.asyncio + async def test_empty_records_noop(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + + runner = AsyncMock() + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs([]), work_dir) + + assert result.status == HookStatus.PASSED + runner.run.assert_not_called() + + +class TestHookServiceMultiHook: + """T021: Multi-hook — second OOMs, first not re-run.""" + + @pytest.mark.asyncio + async def test_multi_hook_second_ooms(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook1 = _make_hook(name="hook_one") + hook2 = _make_hook(name="hook_two", memory="512m") + records = _make_records(1) + + work_dir1 = tmp_path / "hook_one" + work_dir1.mkdir() + (work_dir1 / "output").mkdir(parents=True) + work_dir2 = tmp_path / "hook_two" + work_dir2.mkdir() + (work_dir2 / "output").mkdir(parents=True) + + import json + + # Hook 1 succeeds + (work_dir1 / "output" / "features.jsonl").write_text( + json.dumps({"id": "rec0", "features": [{"score": 0.9}]}) + "\n" + ) + + runner = AsyncMock() + storage = FakeHookStorage() + + call_index = 0 + + async def side_effect(h, inputs, wd): + nonlocal call_index + call_index += 1 + if h.name == "hook_one": + return _passed_result(hook_name="hook_one") + else: + return _oom_result(hook_name="hook_two") + + runner.run.side_effect = side_effect + + service = HookService(hook_runner=runner, hook_storage=storage) + + # Run hook 1 — should pass + r1 = await 
service.run_hook(hook1, _inputs(records), work_dir1) + assert r1.status == HookStatus.PASSED + + # Run hook 2 — should OOM and exhaust retries + r2 = await service.run_hook(hook2, _inputs(records), work_dir2) + assert r2.status == HookStatus.OOM + + # Hook 1 was called once, hook 2 was called 4 times (1 + 3 retries) + hook1_calls = [c for c in runner.run.call_args_list if c[0][0].name == "hook_one"] + hook2_calls = [c for c in runner.run.call_args_list if c[0][0].name == "hook_two"] + assert len(hook1_calls) == 1 + assert len(hook2_calls) == 4 + + +class TestHookServiceCheckpointRecovery: + """T022: Crash recovery from checkpoint — skips completed records.""" + + @pytest.mark.asyncio + async def test_checkpoint_crash_recovery(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + records = _make_records(3) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + + # Pre-populate checkpoint: rec0 already done + import json + + checkpoint_file = work_dir / "_checkpoint.jsonl" + checkpoint_file.write_text( + json.dumps( + { + "record_id": "rec0", + "status": "passed", + "features": [{"score": 0.5}], + } + ) + + "\n" + ) + + runner = AsyncMock() + storage = FakeHookStorage() + + async def mock_run(h, inputs, wd): + # Should only receive rec1 and rec2, not rec0 + input_ids = [r.id for r in inputs.records] + assert "rec0" not in input_ids + # Write output for remaining + features_file = output_dir / "features.jsonl" + with features_file.open("a") as f: + for r in inputs.records: + f.write(json.dumps({"id": r.id, "features": [{"score": 0.9}]}) + "\n") + return _passed_result() + + runner.run.side_effect = mock_run + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs(records), work_dir) + + assert result.status == HookStatus.PASSED + runner.run.assert_called_once() + # Final outcomes should 
contain all 3 records + outcomes = storage.written_outcomes[str(work_dir)] + assert len(outcomes) == 3 + + +class TestHookServiceCheckpointAllComplete: + """T023: All records in checkpoint — hook never called.""" + + @pytest.mark.asyncio + async def test_checkpoint_all_complete_skips_hook(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + records = _make_records(2) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + + import json + + checkpoint_file = work_dir / "_checkpoint.jsonl" + lines = [] + for r in records: + lines.append( + json.dumps({"record_id": r.id, "status": "passed", "features": [{"score": 0.9}]}) + ) + checkpoint_file.write_text("\n".join(lines) + "\n") + + runner = AsyncMock() + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs(records), work_dir) + + assert result.status == HookStatus.PASSED + runner.run.assert_not_called() + + +class TestHookServiceSorting: + """T034-T035: Records sorted by size_hint_mb ascending.""" + + @pytest.mark.asyncio + async def test_records_sorted_by_file_size(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + # Create records with different sizes — large first to test reordering + records = [ + HookRecord(id="large", metadata={}, size_hint_mb=100.0), + HookRecord(id="small", metadata={}, size_hint_mb=1.0), + HookRecord(id="medium", metadata={}, size_hint_mb=50.0), + ] + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + + import json + + captured_order: list[str] = [] + + async def mock_run(h, inputs, wd): + for r in inputs.records: + captured_order.append(r.id) + features_file = output_dir / "features.jsonl" + with features_file.open("w") as f: + for r in inputs.records: + f.write(json.dumps({"id": r.id, "features": [{"score": 0.9}]}) + 
"\n") + return _passed_result() + + runner = AsyncMock() + runner.run.side_effect = mock_run + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + await service.run_hook(hook, _inputs(records), work_dir) + + assert captured_order == ["small", "medium", "large"] + + @pytest.mark.asyncio + async def test_sorting_skipped_when_no_sizes(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + # All records have default size_hint_mb=0 — original order preserved + records = [ + HookRecord(id="a", metadata={}), + HookRecord(id="b", metadata={}), + HookRecord(id="c", metadata={}), + ] + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + + import json + + captured_order: list[str] = [] + + async def mock_run(h, inputs, wd): + for r in inputs.records: + captured_order.append(r.id) + features_file = output_dir / "features.jsonl" + with features_file.open("w") as f: + for r in inputs.records: + f.write(json.dumps({"id": r.id, "features": [{"score": 0.9}]}) + "\n") + return _passed_result() + + runner = AsyncMock() + runner.run.side_effect = mock_run + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + await service.run_hook(hook, _inputs(records), work_dir) + + assert captured_order == ["a", "b", "c"] + + +class TestHookServiceCorruptedCheckpoint: + """T024: Corrupted checkpoint treated as empty — all records reprocessed.""" + + @pytest.mark.asyncio + async def test_corrupted_checkpoint_treated_as_empty(self, tmp_path: Path): + from osa.domain.validation.service.hook import HookService + + hook = _make_hook() + records = _make_records(2) + work_dir = tmp_path / "hook_out" + work_dir.mkdir() + output_dir = work_dir / "output" + output_dir.mkdir(parents=True) + + # Write corrupted checkpoint + checkpoint_file = work_dir / "_checkpoint.jsonl" + 
checkpoint_file.write_text("NOT VALID JSON\n{also broken\n") + + import json + + runner = AsyncMock() + + async def mock_run(h, inputs, wd): + # Should receive ALL records since checkpoint is corrupted + assert len(inputs.records) == 2 + features_file = output_dir / "features.jsonl" + with features_file.open("w") as f: + for r in inputs.records: + f.write(json.dumps({"id": r.id, "features": [{"score": 0.9}]}) + "\n") + return _passed_result() + + runner.run.side_effect = mock_run + storage = FakeHookStorage() + + service = HookService(hook_runner=runner, hook_storage=storage) + result = await service.run_hook(hook, _inputs(records), work_dir) + + assert result.status == HookStatus.PASSED + runner.run.assert_called_once() diff --git a/server/tests/unit/domain/validation/test_validation_service.py b/server/tests/unit/domain/validation/test_validation_service.py index 55abf33..cd2ffd7 100644 --- a/server/tests/unit/domain/validation/test_validation_service.py +++ b/server/tests/unit/domain/validation/test_validation_service.py @@ -196,3 +196,56 @@ async def run_hook(hook, inputs, output_dir): assert call_order == ["hook_a", "hook_b"] assert len(results) == 2 + + @pytest.mark.asyncio + async def test_validation_service_halts_on_oom(self): + """REGRESSION: OOM with exhausted retries should halt pipeline as FAILED.""" + hook_runner = AsyncMock() + hook_runner.run.return_value = _make_hook_result(status=HookStatus.OOM) + service = _make_service(hook_runner=hook_runner) + run = await service.create_run(inputs=_make_inputs()) + + run, results = await service.run_hooks( + run=run, + deposition_srn=_make_dep_srn(), + inputs=_make_inputs(), + hooks=[_make_hook_definition()], + ) + + assert run.status == RunStatus.FAILED + + @pytest.mark.asyncio + async def test_validation_service_retries_on_oom(self): + """OOM should be retried via HookService; PASSED on retry → COMPLETED.""" + call_count = 0 + + async def run_hook(hook, inputs, output_dir): + nonlocal call_count + call_count 
+= 1 + if call_count == 1: + return HookResult( + hook_name=hook.name, + status=HookStatus.OOM, + error_message="OOM", + duration_seconds=30.0, + ) + return HookResult( + hook_name=hook.name, + status=HookStatus.PASSED, + duration_seconds=5.0, + ) + + hook_runner = AsyncMock() + hook_runner.run.side_effect = run_hook + service = _make_service(hook_runner=hook_runner) + run = await service.create_run(inputs=_make_inputs()) + + run, results = await service.run_hooks( + run=run, + deposition_srn=_make_dep_srn(), + inputs=_make_inputs(), + hooks=[_make_hook_definition()], + ) + + assert run.status == RunStatus.COMPLETED + assert call_count == 2 diff --git a/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py b/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py index 4802709..9e63c87 100644 --- a/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py +++ b/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py @@ -585,7 +585,7 @@ async def test_oom_exit_137(self, tmp_path: Path): work_dir, ) - assert result.status == HookStatus.FAILED + assert result.status == HookStatus.OOM assert "oom" in result.error_message.lower() @pytest.mark.asyncio diff --git a/server/tests/unit/infrastructure/test_oci_hook_runner.py b/server/tests/unit/infrastructure/test_oci_hook_runner.py index 3220f94..35352bb 100644 --- a/server/tests/unit/infrastructure/test_oci_hook_runner.py +++ b/server/tests/unit/infrastructure/test_oci_hook_runner.py @@ -239,7 +239,7 @@ async def test_nonzero_exit_returns_failed(self, tmp_path: Path): assert "exit" in (result.error_message or "").lower() @pytest.mark.asyncio - async def test_oom_killed_returns_failed(self, tmp_path: Path): + async def test_oom_killed_returns_oom(self, tmp_path: Path): docker = AsyncMock() container = AsyncMock() docker.containers.create.return_value = container @@ -258,7 +258,7 @@ async def test_oom_killed_returns_failed(self, tmp_path: Path): result = await runner.run(hook, inputs, output_dir) - assert 
result.status == HookStatus.FAILED + assert result.status == HookStatus.OOM assert "oom" in (result.error_message or "").lower() @pytest.mark.asyncio From 0d63431e53c5679c9e9d4589198bbbe79e3b0a02 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Fri, 3 Apr 2026 15:29:28 +0100 Subject: [PATCH 9/9] feat(image): add PR-specific Docker image tags for pull request builds fix(ingest): add warning log for redelivered IngestStarted events when limit already met --- .github/workflows/ci.yml | 8 ++++++-- .github/workflows/image.yml | 3 +++ server/osa/domain/ingest/handler/run_ingester.py | 7 +++++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 72f8391..a27823c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -203,11 +203,15 @@ jobs: OSA_AUTH__JWT__SECRET: test-secret-for-integration-tests-minimum-32-chars TEST: "1" - # Build & push Docker image (only on main, gated on all server checks) + # Build & push Docker image (main + PRs onto main, gated on all server checks) image: name: Server - Image needs: [changes, server-lint, server-typecheck, server-test, server-contract, server-integration] - if: github.ref == 'refs/heads/main' && needs.changes.outputs.server == 'true' + if: >- + needs.changes.outputs.server == 'true' && ( + github.ref == 'refs/heads/main' || + (github.event_name == 'pull_request' && github.base_ref == 'main') + ) uses: ./.github/workflows/image.yml permissions: contents: read diff --git a/.github/workflows/image.yml b/.github/workflows/image.yml index 37efb00..e965840 100644 --- a/.github/workflows/image.yml +++ b/.github/workflows/image.yml @@ -35,6 +35,9 @@ jobs: if [[ "${{ github.event_name }}" == "release" ]]; then TAGS="${TAGS},${{ env.IMAGE }}:${{ github.event.release.tag_name }}" fi + if [[ -n "${{ github.event.pull_request.number }}" ]]; then + TAGS="${TAGS},${{ env.IMAGE }}:pr-${{ github.event.pull_request.number }}" + fi echo "tags=${TAGS}" >> 
"$GITHUB_OUTPUT" - uses: docker/build-push-action@v6 diff --git a/server/osa/domain/ingest/handler/run_ingester.py b/server/osa/domain/ingest/handler/run_ingester.py index 00eb8c5..4952eda 100644 --- a/server/osa/domain/ingest/handler/run_ingester.py +++ b/server/osa/domain/ingest/handler/run_ingester.py @@ -63,8 +63,11 @@ async def handle(self, event: IngestStarted) -> None: ingested_so_far = ingest_run.batches_ingested * ingest_run.batch_size remaining = ingest_run.limit - ingested_so_far if remaining <= 0: - await self.ingest_repo.increment_batches_ingested( - event.ingest_run_srn, set_ingestion_finished=True + log.warn( + "Ignoring redelivered IngestStarted — limit already met (batches_ingested={batches_ingested}, limit={limit})", + batches_ingested=ingest_run.batches_ingested, + limit=ingest_run.limit, + ingest_run_srn=event.ingest_run_srn, ) return effective_batch_limit = min(ingest_run.batch_size, remaining)