diff --git a/server/Justfile b/server/Justfile index eb93bc9..6f96d22 100644 --- a/server/Justfile +++ b/server/Justfile @@ -37,17 +37,17 @@ test-cov: # Run linter and type checker lint: @just fix - uv run ruff check osa + uv run ruff check osa tests uv run ty check osa # Fix formatting and lint issues fix: - uv run ruff format osa - uv run ruff check --fix osa + uv run ruff format osa tests + uv run ruff check --fix osa tests # Format code format: - uv run ruff format osa + uv run ruff format osa tests # === Database === diff --git a/server/migrations/versions/add_deliver_after.py b/server/migrations/versions/add_deliver_after.py new file mode 100644 index 0000000..c7a7b6b --- /dev/null +++ b/server/migrations/versions/add_deliver_after.py @@ -0,0 +1,45 @@ +"""add_deliver_after_and_batches_failed + +Add deliver_after column to deliveries table for explicit backoff scheduling. +Add batches_failed column to ingest_runs table for batch failure accounting. + +Revision ID: add_deliver_after +Revises: add_ingest_runs +Create Date: 2026-04-04 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "add_deliver_after" +down_revision: Union[str, Sequence[str], None] = "add_ingest_runs" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "deliveries", + sa.Column("deliver_after", sa.DateTime(timezone=True), nullable=True), + ) + op.create_index( + "idx_deliveries_deliver_after", + "deliveries", + ["deliver_after"], + postgresql_where=sa.text("status = 'pending'"), + ) + + op.add_column( + "ingest_runs", + sa.Column("batches_failed", sa.Integer, nullable=False, server_default=sa.text("0")), + ) + + +def downgrade() -> None: + op.drop_column("ingest_runs", "batches_failed") + op.drop_index("idx_deliveries_deliver_after", table_name="deliveries") + op.drop_column("deliveries", "deliver_after") diff --git a/server/migrations/versions/add_ingest_runs.py b/server/migrations/versions/add_ingest_runs.py index 6a23b7c..d4dedf5 100644 --- a/server/migrations/versions/add_ingest_runs.py +++ b/server/migrations/versions/add_ingest_runs.py @@ -23,11 +23,10 @@ def upgrade() -> None: op.create_table( "ingest_runs", - sa.Column("srn", sa.String(), primary_key=True), + sa.Column("id", sa.String(), primary_key=True), sa.Column( "convention_srn", sa.String(), - sa.ForeignKey("conventions.srn"), nullable=False, ), sa.Column( diff --git a/server/osa/domain/feature/handler/insert_batch_features.py b/server/osa/domain/feature/handler/insert_batch_features.py index cf7f366..39896d7 100644 --- a/server/osa/domain/feature/handler/insert_batch_features.py +++ b/server/osa/domain/feature/handler/insert_batch_features.py @@ -26,9 +26,7 @@ async def handle(self, event: IngestBatchPublished) -> None: if not event.expected_features or not event.published_srns: return - batch_output_dir = str( - self.layout.ingest_batch_dir(event.ingest_run_srn, event.batch_index) - ) + batch_output_dir = str(self.layout.ingest_batch_dir(event.ingest_run_id, event.batch_index)) total_inserted = 0 skipped_dupes = 0 @@ -59,7 +57,7 @@ async def handle(self, event: IngestBatchPublished) -> None: ) total_inserted += count - short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + short_id = event.ingest_run_id[:8] dupe_msg = f", {skipped_dupes} duplicates skipped" if skipped_dupes else "" log.info( "[{short_id}] batch {batch_index}: inserted 
{total_inserted} feature rows ({hook_count} hooks{dupe_msg})", @@ -68,5 +66,5 @@ async def handle(self, event: IngestBatchPublished) -> None: total_inserted=total_inserted, hook_count=len(event.expected_features), dupe_msg=dupe_msg, - ingest_run_srn=event.ingest_run_srn, + ingest_run_id=event.ingest_run_id, ) diff --git a/server/osa/domain/ingest/command/start_ingest.py b/server/osa/domain/ingest/command/start_ingest.py index 1b65bbd..ddc9716 100644 --- a/server/osa/domain/ingest/command/start_ingest.py +++ b/server/osa/domain/ingest/command/start_ingest.py @@ -34,13 +34,19 @@ class StartIngestHandler(CommandHandler[StartIngest, IngestRunCreated]): service: IngestService async def run(self, cmd: StartIngest) -> IngestRunCreated: + from osa.domain.shared.model.srn import Domain + ingest_run = await self.service.start_ingest( convention_srn=cmd.convention_srn, batch_size=cmd.batch_size, limit=cmd.limit, ) + + node_domain: Domain = self.service.node_domain + srn = f"urn:osa:{node_domain.root}:ing:{ingest_run.id}" + return IngestRunCreated( - srn=ingest_run.srn, + srn=srn, convention_srn=ingest_run.convention_srn, status=ingest_run.status, started_at=ingest_run.started_at.isoformat(), diff --git a/server/osa/domain/ingest/event/__init__.py b/server/osa/domain/ingest/event/__init__.py index c8b3101..b10432b 100644 --- a/server/osa/domain/ingest/event/__init__.py +++ b/server/osa/domain/ingest/event/__init__.py @@ -4,12 +4,14 @@ HookBatchCompleted, IngestBatchPublished, IngestCompleted, - IngestStarted, + IngestRunStarted, IngesterBatchReady, + NextBatchRequested, ) __all__ = [ - "IngestStarted", + "IngestRunStarted", + "NextBatchRequested", "IngesterBatchReady", "HookBatchCompleted", "IngestBatchPublished", diff --git a/server/osa/domain/ingest/event/events.py b/server/osa/domain/ingest/event/events.py index ee0eaab..49e74f5 100644 --- a/server/osa/domain/ingest/event/events.py +++ b/server/osa/domain/ingest/event/events.py @@ -1,13 +1,27 @@ """Ingest domain events — payloads carry path references, not inline data (AD-1).""" +from osa.domain.ingest.model.ingest_run import IngestRunId from osa.domain.shared.event import Event, EventId -class IngestStarted(Event): - """Emitted when an ingest run is created. Triggers first ingester pull.""" +class IngestRunStarted(Event): + """Emitted once when an ingest run is created. Observability/audit only.""" id: EventId - ingest_run_srn: str + ingest_run_id: IngestRunId + convention_srn: str + batch_size: int + + +class NextBatchRequested(Event): + """Emitted to trigger the next ingester batch pull. + + Emitted by StartIngest (first batch) and by RunIngester (continuation). + RunIngester is the only handler that listens to this event. + """ + + id: EventId + ingest_run_id: IngestRunId convention_srn: str batch_size: int @@ -15,11 +29,11 @@ class IngestStarted(Event): class IngesterBatchReady(Event): """Emitted when an ingester container produces a batch of records. - Batch data is on disk at the path derived from {ingest_run_srn, batch_index}. + Batch data is on disk at the path derived from {ingest_run_id, batch_index}. 
""" id: EventId - ingest_run_srn: str + ingest_run_id: IngestRunId batch_index: int has_more: bool @@ -31,7 +45,7 @@ class HookBatchCompleted(Event): """ id: EventId - ingest_run_srn: str + ingest_run_id: IngestRunId batch_index: int @@ -43,7 +57,7 @@ class IngestBatchPublished(Event): """ id: EventId - ingest_run_srn: str + ingest_run_id: IngestRunId convention_srn: str batch_index: int published_srns: list[str] @@ -56,5 +70,5 @@ class IngestCompleted(Event): """Emitted when all batches are processed and the ingest run is complete.""" id: EventId - ingest_run_srn: str + ingest_run_id: IngestRunId total_published: int diff --git a/server/osa/domain/ingest/handler/publish_batch.py b/server/osa/domain/ingest/handler/publish_batch.py index 40b94d7..e179bd5 100644 --- a/server/osa/domain/ingest/handler/publish_batch.py +++ b/server/osa/domain/ingest/handler/publish_batch.py @@ -1,6 +1,5 @@ """PublishBatch — reads hook outputs, bulk-publishes passing records.""" -from datetime import UTC, datetime from uuid import uuid4 from osa.domain.deposition.service.convention import ConventionService @@ -8,12 +7,11 @@ from osa.domain.ingest.event.events import ( HookBatchCompleted, IngestBatchPublished, - IngestCompleted, ) -from osa.domain.ingest.model.ingest_run import IngestStatus from osa.domain.ingest.model.ingester_record import IngesterRecord from osa.domain.ingest.port.repository import IngestRunRepository from osa.domain.ingest.port.storage import IngestStoragePort +from osa.domain.ingest.service.ingest import IngestService from osa.domain.record.model.draft import RecordDraft from osa.domain.record.service import RecordService from osa.domain.shared.error import NotFoundError @@ -35,24 +33,23 @@ class PublishBatch(EventHandler[HookBatchCompleted]): feature_storage: FeatureStoragePort outbox: Outbox ingest_storage: IngestStoragePort + ingest_service: IngestService async def handle(self, event: HookBatchCompleted) -> None: - ingest_run = await self.ingest_repo.get(event.ingest_run_srn) + ingest_run = await self.ingest_repo.get(event.ingest_run_id) if ingest_run is None: - raise NotFoundError(f"Ingest run not found: {event.ingest_run_srn}") + raise NotFoundError(f"Ingest run not found: {event.ingest_run_id}") convention = await self.convention_service.get_convention( ConventionSRN.parse(ingest_run.convention_srn) ) # Read ingester records via storage port (filesystem or S3) - raw_records = await self.ingest_storage.read_records( - event.ingest_run_srn, event.batch_index - ) + raw_records = await self.ingest_storage.read_records(event.ingest_run_id, event.batch_index) ingester_records = IngesterRecord.from_dicts(raw_records) # batch_dir used as locator for hook outcome reads - batch_dir = str(self.ingest_storage.batch_dir(event.ingest_run_srn, event.batch_index)) + batch_dir = str(self.ingest_storage.batch_dir(event.ingest_run_id, event.batch_index)) # Read hook outcomes for all hooks expected_features = [h.name for h in convention.hooks] @@ -67,7 +64,7 @@ async def handle(self, event: HookBatchCompleted) -> None: ) # Log outcome breakdown per hook - short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + short_id = event.ingest_run_id[:8] total = len(ingester_records) for hook_name in expected_features: outcomes = await self.feature_storage.read_batch_outcomes(str(batch_dir), hook_name) @@ -88,7 +85,7 @@ async def handle(self, event: HookBatchCompleted) -> None: rejected=rejected, errored=errored, missing=missing, - ingest_run_srn=event.ingest_run_srn, + ingest_run_id=event.ingest_run_id, ) 
published_count = 0 @@ -100,7 +97,7 @@ async def handle(self, event: HookBatchCompleted) -> None: RecordDraft( source=IngestSource( id=f"{ingest_run.convention_srn}:{record.source_id}", - ingest_run_srn=ingest_run.srn, + ingest_run_id=ingest_run.id, upstream_source=record.source_id, ), metadata=record.metadata, @@ -133,7 +130,7 @@ async def handle(self, event: HookBatchCompleted) -> None: published=published_count, passed=len(passed_records), duplicates=len(drafts) - published_count, - ingest_run_srn=event.ingest_run_srn, + ingest_run_id=event.ingest_run_id, ) # Emit IngestBatchPublished for feature insertion @@ -141,7 +138,7 @@ async def handle(self, event: HookBatchCompleted) -> None: await self.outbox.append( IngestBatchPublished( id=EventId(uuid4()), - ingest_run_srn=event.ingest_run_srn, + ingest_run_id=event.ingest_run_id, convention_srn=ingest_run.convention_srn, batch_index=event.batch_index, published_srns=published_srns, @@ -151,31 +148,17 @@ async def handle(self, event: HookBatchCompleted) -> None: ) ) - # Update counters atomically - updated = await self.ingest_repo.increment_completed( - event.ingest_run_srn, - published_count=published_count, - ) - - # Check completion condition - if updated.is_complete and updated.status == IngestStatus.RUNNING: - updated.check_completion(datetime.now(UTC)) - await self.ingest_repo.save(updated) + # Update counters and check completion via service + await self.ingest_service.complete_batch(event.ingest_run_id, published_count) - await self.outbox.append( - IngestCompleted( - id=EventId(uuid4()), - ingest_run_srn=event.ingest_run_srn, - total_published=updated.published_count, - ) - ) - short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] - log.info( - "[{short_id}] COMPLETE: {total_published} records published", - short_id=short_id, - total_published=updated.published_count, - ingest_run_srn=event.ingest_run_srn, - ) + async def on_exhausted(self, event: HookBatchCompleted) -> None: + """Called when publish retries are exhausted — account for the failed batch.""" + log.error( + "batch {batch_index} publish retries exhausted", + batch_index=event.batch_index, + ingest_run_id=event.ingest_run_id, + ) + await self.ingest_service.fail_batch(event.ingest_run_id) async def _get_passed_records( diff --git a/server/osa/domain/ingest/handler/run_hooks.py b/server/osa/domain/ingest/handler/run_hooks.py index 63eb205..4794092 100644 --- a/server/osa/domain/ingest/handler/run_hooks.py +++ b/server/osa/domain/ingest/handler/run_hooks.py @@ -1,5 +1,7 @@ """RunHooks — runs hook containers on an ingester batch.""" +from osa.domain.validation.model import HookResult + from pathlib import Path from uuid import uuid4 @@ -8,7 +10,8 @@ from osa.domain.ingest.model.ingester_record import IngesterRecord from osa.domain.ingest.port.repository import IngestRunRepository from osa.domain.ingest.port.storage import IngestStoragePort -from osa.domain.shared.error import NotFoundError +from osa.domain.ingest.service.ingest import IngestService +from osa.domain.shared.error import NotFoundError, OOMError, PermanentError from osa.domain.shared.event import EventHandler, EventId from osa.domain.shared.model.srn import ConventionSRN from osa.domain.shared.outbox import Outbox @@ -24,37 +27,37 @@ class RunHooks(EventHandler[IngesterBatchReady]): """Runs hook containers on an ingester batch and emits HookBatchCompleted.""" __claim_timeout__ = 3600.0 + __max_retries__ = 100 ingest_repo: IngestRunRepository + ingest_service: IngestService convention_service: ConventionService 
hook_service: HookService outbox: Outbox ingest_storage: IngestStoragePort async def handle(self, event: IngesterBatchReady) -> None: - ingest_run = await self.ingest_repo.get(event.ingest_run_srn) + ingest_run = await self.ingest_repo.get(event.ingest_run_id) if ingest_run is None: - raise NotFoundError(f"Ingest run not found: {event.ingest_run_srn}") + raise NotFoundError(f"Ingest run not found: {event.ingest_run_id}") convention = await self.convention_service.get_convention( ConventionSRN.parse(ingest_run.convention_srn) ) # Read records via storage port (filesystem or S3) - raw_records = await self.ingest_storage.read_records( - event.ingest_run_srn, event.batch_index - ) + raw_records = await self.ingest_storage.read_records(event.ingest_run_id, event.batch_index) records = IngesterRecord.from_dicts(raw_records) if not records: log.warn( "ingest batch {batch_index}: no records to process", batch_index=event.batch_index, - ingest_run_srn=event.ingest_run_srn, + ingest_run_id=event.ingest_run_id, ) # Build files_dirs from ingester files (Path locators for runner volume mounts) - files_base = self.ingest_storage.batch_files_dir(event.ingest_run_srn, event.batch_index) + files_base = self.ingest_storage.batch_files_dir(event.ingest_run_id, event.batch_index) files_dirs: dict[str, Path] = {} for record in records: if record.files: @@ -70,7 +73,7 @@ async def handle(self, event: IngesterBatchReady) -> None: ) for r in records ], - run_id=f"{event.ingest_run_srn}_batch{event.batch_index}", + run_id=f"{event.ingest_run_id}_b{event.batch_index}", files_dirs=files_dirs, ) @@ -78,17 +81,41 @@ async def handle(self, event: IngesterBatchReady) -> None: work_dirs: dict[str, Path] = {} for hook in convention.hooks: work_dirs[hook.name] = self.ingest_storage.hook_work_dir( - event.ingest_run_srn, event.batch_index, hook.name + event.ingest_run_id, event.batch_index, hook.name ) # Run all hooks via HookService - results = await self.hook_service.run_hooks_for_batch( - hooks=convention.hooks, - inputs=inputs, - work_dirs=work_dirs, - ) + results: list[HookResult] = [] + try: + results = await self.hook_service.run_hooks_for_batch( + hooks=convention.hooks, + inputs=inputs, + work_dirs=work_dirs, + ) + except OOMError as e: + # OOM exhaustion after retries — HookService already wrote outcomes + # (passed + errored) to disk. Fall through to emit HookBatchCompleted + # so PublishBatch can publish the records that DID pass. 
+ log.warn( + "[{short_id}] batch {batch_index} OOM exhausted, publishing partial results: {error}", + short_id=event.ingest_run_id[:8], + batch_index=event.batch_index, + error=str(e), + ingest_run_id=event.ingest_run_id, + ) + except PermanentError as e: + log.error( + "[{short_id}] batch {batch_index} permanently failed: {error}", + short_id=event.ingest_run_id[:8], + batch_index=event.batch_index, + error=str(e), + container_logs=e.container_logs or "", + ingest_run_id=event.ingest_run_id, + ) + await self._fail_batch(event) + return - short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + short_id = event.ingest_run_id[:8] for result in results: log.info( "[{short_id}] batch {batch_index} hook={hook_name}: {status} in {duration:.1f}s", @@ -97,14 +124,27 @@ async def handle(self, event: IngesterBatchReady) -> None: hook_name=result.hook_name, status=result.status.value, duration=result.duration_seconds, - ingest_run_srn=event.ingest_run_srn, + ingest_run_id=event.ingest_run_id, ) # Emit HookBatchCompleted await self.outbox.append( HookBatchCompleted( id=EventId(uuid4()), - ingest_run_srn=event.ingest_run_srn, + ingest_run_id=event.ingest_run_id, batch_index=event.batch_index, ) ) + + async def on_exhausted(self, event: IngesterBatchReady) -> None: + """Called when transient retries are exhausted — account for the failed batch.""" + log.error( + "batch {batch_index} retries exhausted", + batch_index=event.batch_index, + ingest_run_id=event.ingest_run_id, + ) + await self._fail_batch(event) + + async def _fail_batch(self, event: IngesterBatchReady) -> None: + """Account for a permanently failed batch.""" + await self.ingest_service.fail_batch(event.ingest_run_id) diff --git a/server/osa/domain/ingest/handler/run_ingester.py b/server/osa/domain/ingest/handler/run_ingester.py index 4c3130e..9c9d369 100644 --- a/server/osa/domain/ingest/handler/run_ingester.py +++ b/server/osa/domain/ingest/handler/run_ingester.py @@ -1,37 +1,62 @@ -"""RunIngester — runs ingester container on IngestStarted or continuation.""" +"""RunIngester — runs ingester container on NextBatchRequested.""" +from datetime import UTC, datetime, timedelta from uuid import uuid4 from osa.domain.deposition.service.convention import ConventionService -from osa.domain.ingest.event.events import IngestStarted, IngesterBatchReady +from osa.domain.ingest.event.events import IngesterBatchReady, NextBatchRequested from osa.domain.ingest.model.ingest_run import IngestStatus from osa.domain.ingest.port.repository import IngestRunRepository from osa.domain.ingest.port.storage import IngestStoragePort -from osa.domain.shared.error import NotFoundError +from osa.domain.ingest.service.ingest import IngestService +from osa.domain.shared.error import NotFoundError, PermanentError from osa.domain.shared.event import EventHandler, EventId from osa.domain.shared.model.srn import ConventionSRN from osa.domain.shared.outbox import Outbox from osa.domain.shared.port.ingester_runner import IngesterInputs, IngesterRunner from osa.infrastructure.logging import get_logger +BACKPRESSURE_DELAY = timedelta(seconds=60) + log = get_logger(__name__) -class RunIngester(EventHandler[IngestStarted]): +class RunIngester(EventHandler[NextBatchRequested]): """Runs ingester container and emits IngesterBatchReady per batch.""" __claim_timeout__ = 3600.0 + __max_retries__ = 20 ingest_repo: IngestRunRepository + ingest_service: IngestService convention_service: ConventionService ingester_runner: IngesterRunner outbox: Outbox ingest_storage: IngestStoragePort - 
async def handle(self, event: IngestStarted) -> None: - ingest_run = await self.ingest_repo.get(event.ingest_run_srn) + async def handle(self, event: NextBatchRequested) -> None: + ingest_run = await self.ingest_repo.get(event.ingest_run_id) if ingest_run is None: - raise NotFoundError(f"Ingest run not found: {event.ingest_run_srn}") + raise NotFoundError(f"Ingest run not found: {event.ingest_run_id}") + + # Backpressure: don't ingest if the cluster can't schedule more Jobs + if not await self.ingester_runner.has_capacity(): + log.info( + "[{short_id}] backpressure: cluster has pending Jobs, deferring next pull +{delay}s", + short_id=event.ingest_run_id[:8], + delay=int(BACKPRESSURE_DELAY.total_seconds()), + ingest_run_id=event.ingest_run_id, + ) + await self.outbox.append( + NextBatchRequested( + id=EventId(uuid4()), + ingest_run_id=event.ingest_run_id, + convention_srn=event.convention_srn, + batch_size=event.batch_size, + ), + deliver_after=datetime.now(UTC) + BACKPRESSURE_DELAY, + ) + return if ingest_run.status == IngestStatus.PENDING: ingest_run.mark_running() @@ -45,7 +70,7 @@ async def handle(self, event: IngestStarted) -> None: batch_index = ingest_run.batches_ingested - session = await self.ingest_storage.read_session(event.ingest_run_srn) + session = await self.ingest_storage.read_session(event.ingest_run_id) effective_batch_limit = ingest_run.batch_size if ingest_run.limit is not None: @@ -53,13 +78,13 @@ async def handle(self, event: IngestStarted) -> None: remaining = ingest_run.limit - ingested_so_far if remaining <= 0: log.warn( - "Ignoring redelivered IngestStarted — limit already met (batches_ingested={batches_ingested}, limit={limit})", + "Ignoring redelivered NextBatchRequested — limit already met (batches_ingested={batches_ingested}, limit={limit})", batches_ingested=ingest_run.batches_ingested, limit=ingest_run.limit, - ingest_run_srn=event.ingest_run_srn, + ingest_run_id=event.ingest_run_id, ) await self.ingest_repo.increment_batches_ingested( - event.ingest_run_srn, + event.ingest_run_id, set_ingestion_finished=True, ) return @@ -67,24 +92,37 @@ async def handle(self, event: IngestStarted) -> None: inputs = IngesterInputs( convention_srn=convention.srn, + ingest_run_id=event.ingest_run_id, + batch_index=batch_index, config=convention.ingester.config, limit=effective_batch_limit, session=session, ) - work_dir = self.ingest_storage.batch_work_dir(event.ingest_run_srn, batch_index) - files_dir = self.ingest_storage.batch_files_dir(event.ingest_run_srn, batch_index) - - output = await self.ingester_runner.run( - ingester=convention.ingester, - inputs=inputs, - files_dir=files_dir, - work_dir=work_dir, - ) + work_dir = self.ingest_storage.batch_work_dir(event.ingest_run_id, batch_index) + files_dir = self.ingest_storage.batch_files_dir(event.ingest_run_id, batch_index) + + try: + output = await self.ingester_runner.run( + ingester=convention.ingester, + inputs=inputs, + files_dir=files_dir, + work_dir=work_dir, + ) + except PermanentError as e: + log.error( + "[{short_id}] ingester permanently failed: {error}", + short_id=event.ingest_run_id[:8], + error=str(e), + container_logs=e.container_logs or "", + ingest_run_id=event.ingest_run_id, + ) + await self._fail_ingestion(event) + return - await self.ingest_storage.write_records(event.ingest_run_srn, batch_index, output.records) + await self.ingest_storage.write_records(event.ingest_run_id, batch_index, output.records) if output.session: - await self.ingest_storage.write_session(event.ingest_run_srn, output.session) + 
await self.ingest_storage.write_session(event.ingest_run_id, output.session) has_more = output.session is not None and len(output.records) > 0 @@ -94,35 +132,47 @@ async def handle(self, event: IngestStarted) -> None: has_more = False await self.ingest_repo.increment_batches_ingested( - event.ingest_run_srn, + event.ingest_run_id, set_ingestion_finished=not has_more, ) await self.outbox.append( IngesterBatchReady( id=EventId(uuid4()), - ingest_run_srn=event.ingest_run_srn, + ingest_run_id=event.ingest_run_id, batch_index=batch_index, has_more=has_more, ) ) - short_id = event.ingest_run_srn.rsplit(":", 1)[-1][:8] + short_id = event.ingest_run_id[:8] log.info( "[{short_id}] batch {batch_index}: pulled {record_count} records (has_more={has_more})", short_id=short_id, batch_index=batch_index, record_count=len(output.records), has_more=has_more, - ingest_run_srn=event.ingest_run_srn, + ingest_run_id=event.ingest_run_id, ) if has_more: await self.outbox.append( - IngestStarted( + NextBatchRequested( id=EventId(uuid4()), - ingest_run_srn=event.ingest_run_srn, + ingest_run_id=event.ingest_run_id, convention_srn=event.convention_srn, batch_size=ingest_run.batch_size, ) ) + + async def on_exhausted(self, event: NextBatchRequested) -> None: + """Transient retries exhausted — stop ingestion and check completion.""" + log.error( + "ingester retries exhausted", + ingest_run_id=event.ingest_run_id, + ) + await self._fail_ingestion(event) + + async def _fail_ingestion(self, event: NextBatchRequested) -> None: + """Account for a permanently failed ingester pull.""" + await self.ingest_service.fail_ingestion(event.ingest_run_id) diff --git a/server/osa/domain/ingest/model/ingest_run.py b/server/osa/domain/ingest/model/ingest_run.py index c95f604..53a899b 100644 --- a/server/osa/domain/ingest/model/ingest_run.py +++ b/server/osa/domain/ingest/model/ingest_run.py @@ -2,10 +2,13 @@ from datetime import datetime from enum import StrEnum +from typing import NewType from osa.domain.shared.error import InvalidStateError from osa.domain.shared.model.aggregate import Aggregate +IngestRunId = NewType("IngestRunId", str) + class IngestStatus(StrEnum): PENDING = "pending" @@ -29,13 +32,14 @@ class IngestRun(Aggregate): Counter updates use atomic SQL increments in the repository. """ - srn: str + id: IngestRunId convention_srn: str status: IngestStatus = IngestStatus.PENDING ingestion_finished: bool = False batches_ingested: int = 0 batches_completed: int = 0 published_count: int = 0 + batches_failed: int = 0 batch_size: int = 1000 limit: int | None = None # Max total records (None = unlimited) started_at: datetime @@ -71,8 +75,11 @@ def record_batch_completed(self, published_count: int) -> None: @property def is_complete(self) -> bool: - """Check the completion condition: all sourced batches are completed.""" - return self.ingestion_finished and self.batches_ingested == self.batches_completed + """Check the completion condition: all sourced batches are accounted for.""" + return ( + self.ingestion_finished + and (self.batches_completed + self.batches_failed) >= self.batches_ingested + ) def check_completion(self, completed_at: datetime) -> bool: """Check completion condition and transition if met. 
diff --git a/server/osa/domain/ingest/port/repository.py b/server/osa/domain/ingest/port/repository.py index f5839d6..3b6a376 100644 --- a/server/osa/domain/ingest/port/repository.py +++ b/server/osa/domain/ingest/port/repository.py @@ -3,7 +3,7 @@ from abc import abstractmethod from typing import Protocol -from osa.domain.ingest.model.ingest_run import IngestRun +from osa.domain.ingest.model.ingest_run import IngestRun, IngestRunId from osa.domain.shared.port import Port @@ -21,8 +21,8 @@ async def save(self, ingest_run: IngestRun) -> None: ... @abstractmethod - async def get(self, srn: str) -> IngestRun | None: - """Get an ingest run by SRN.""" + async def get(self, id: IngestRunId) -> IngestRun | None: + """Get an ingest run by ID.""" ... @abstractmethod @@ -32,7 +32,7 @@ async def get_running_for_convention(self, convention_srn: str) -> IngestRun | N @abstractmethod async def increment_batches_ingested( - self, srn: str, *, set_ingestion_finished: bool = False + self, id: IngestRunId, *, set_ingestion_finished: bool = False ) -> IngestRun: """Atomically increment batches_ingested and optionally set ingestion_finished. @@ -41,7 +41,15 @@ async def increment_batches_ingested( ... @abstractmethod - async def increment_completed(self, srn: str, published_count: int) -> IngestRun: + async def increment_failed(self, id: IngestRunId) -> IngestRun: + """Atomically increment batches_failed. + + Returns the updated IngestRun for completion condition checking. + """ + ... + + @abstractmethod + async def increment_completed(self, id: IngestRunId, published_count: int) -> IngestRun: """Atomically increment batches_completed and published_count. Returns the updated IngestRun with DB-authoritative counter values diff --git a/server/osa/domain/ingest/port/storage.py b/server/osa/domain/ingest/port/storage.py index e66a5d0..3099c46 100644 --- a/server/osa/domain/ingest/port/storage.py +++ b/server/osa/domain/ingest/port/storage.py @@ -20,43 +20,43 @@ class IngestStoragePort(Port, Protocol): """ @abstractmethod - async def read_session(self, ingest_run_srn: str) -> dict[str, Any] | None: + async def read_session(self, ingest_run_id: str) -> dict[str, Any] | None: """Read session state for ingester continuation. Returns None if no session.""" ... @abstractmethod - async def write_session(self, ingest_run_srn: str, session: dict[str, Any]) -> None: + async def write_session(self, ingest_run_id: str, session: dict[str, Any]) -> None: """Persist session state between batches.""" ... @abstractmethod async def write_records( - self, ingest_run_srn: str, batch_index: int, records: list[dict[str, Any]] + self, ingest_run_id: str, batch_index: int, records: list[dict[str, Any]] ) -> None: """Write ingester output records for a batch as JSONL.""" ... @abstractmethod - async def read_records(self, ingest_run_srn: str, batch_index: int) -> list[dict[str, Any]]: + async def read_records(self, ingest_run_id: str, batch_index: int) -> list[dict[str, Any]]: """Read raw ingester output records for a batch.""" ... @abstractmethod - def batch_dir(self, ingest_run_srn: str, batch_index: int) -> Path: + def batch_dir(self, ingest_run_id: str, batch_index: int) -> Path: """Return the batch-level directory (parent of ingester/ and hooks/).""" ... @abstractmethod - def batch_work_dir(self, ingest_run_srn: str, batch_index: int) -> Path: + def batch_work_dir(self, ingest_run_id: str, batch_index: int) -> Path: """Return the ingester work directory for a batch.""" ... 
@abstractmethod - def batch_files_dir(self, ingest_run_srn: str, batch_index: int) -> Path: + def batch_files_dir(self, ingest_run_id: str, batch_index: int) -> Path: """Return the files directory for a batch.""" ... @abstractmethod - def hook_work_dir(self, ingest_run_srn: str, batch_index: int, hook_name: str) -> Path: + def hook_work_dir(self, ingest_run_id: str, batch_index: int, hook_name: str) -> Path: """Return the hook output directory for a batch.""" ... diff --git a/server/osa/domain/ingest/service/ingest.py b/server/osa/domain/ingest/service/ingest.py index 35c338a..fd091af 100644 --- a/server/osa/domain/ingest/service/ingest.py +++ b/server/osa/domain/ingest/service/ingest.py @@ -4,8 +4,8 @@ from uuid import uuid4 from osa.domain.deposition.service.convention import ConventionService -from osa.domain.ingest.event.events import IngestStarted -from osa.domain.ingest.model.ingest_run import IngestRun, IngestStatus +from osa.domain.ingest.event.events import IngestCompleted, IngestRunStarted, NextBatchRequested +from osa.domain.ingest.model.ingest_run import IngestRun, IngestRunId, IngestStatus from osa.domain.ingest.port.repository import IngestRunRepository from osa.domain.shared.error import ConflictError, NotFoundError from osa.domain.shared.event import EventId @@ -54,11 +54,11 @@ async def start_ingest( code="ingest_already_running", ) - srn = f"urn:osa:{self.node_domain.root}:ing:{uuid4()}" + run_id = IngestRunId(str(uuid4())) now = datetime.now(UTC) ingest_run = IngestRun( - srn=srn, + id=run_id, convention_srn=convention_srn, status=IngestStatus.PENDING, batch_size=batch_size, @@ -69,14 +69,24 @@ async def start_ingest( await self.ingest_repo.save(ingest_run) await self.outbox.append( - IngestStarted( + IngestRunStarted( id=EventId(uuid4()), - ingest_run_srn=srn, + ingest_run_id=run_id, convention_srn=convention_srn, batch_size=batch_size, ) ) + await self.outbox.append( + NextBatchRequested( + id=EventId(uuid4()), + ingest_run_id=run_id, + convention_srn=convention_srn, + batch_size=batch_size, + ) + ) + + srn = f"urn:osa:{self.node_domain.root}:ing:{run_id}" log.info( "ingest started for {convention_srn}", ingest_run_srn=srn, @@ -85,3 +95,57 @@ async def start_ingest( limit=limit, ) return ingest_run + + async def complete_batch(self, ingest_run_id: IngestRunId, published_count: int) -> None: + """Account for a successfully processed batch. + + Increments batches_completed and published_count atomically, + then checks the completion condition. + """ + ingest_run = await self.ingest_repo.increment_completed( + ingest_run_id, published_count=published_count + ) + await self._check_completion(ingest_run) + + async def fail_batch(self, ingest_run_id: IngestRunId) -> None: + """Account for a batch that permanently failed hook processing. + + Increments batches_failed and completes the run if all batches + are now accounted for (completed + failed >= ingested). + """ + ingest_run = await self.ingest_repo.increment_failed(ingest_run_id) + await self._check_completion(ingest_run) + + async def fail_ingestion(self, ingest_run_id: IngestRunId) -> None: + """Account for a failed ingester pull. + + The batch was never sourced, so we mark ingestion as finished + (no more batches coming) and increment batches_failed. The + completion condition can then fire based on whatever batches + were already sourced. 
+ """ + await self.ingest_repo.increment_batches_ingested( + ingest_run_id, + set_ingestion_finished=True, + ) + ingest_run = await self.ingest_repo.increment_failed(ingest_run_id) + await self._check_completion(ingest_run) + + async def _check_completion(self, ingest_run: IngestRun) -> None: + """Transition to COMPLETED and emit IngestCompleted if all batches are accounted for.""" + if not ingest_run.check_completion(datetime.now(UTC)): + return + await self.ingest_repo.save(ingest_run) + await self.outbox.append( + IngestCompleted( + id=EventId(uuid4()), + ingest_run_id=ingest_run.id, + total_published=ingest_run.published_count, + ) + ) + log.info( + "[{short_id}] COMPLETE: {total_published} records published", + short_id=str(ingest_run.id)[:8], + total_published=ingest_run.published_count, + ingest_run_id=ingest_run.id, + ) diff --git a/server/osa/domain/shared/error.py b/server/osa/domain/shared/error.py index ab50e83..a26c0d6 100644 --- a/server/osa/domain/shared/error.py +++ b/server/osa/domain/shared/error.py @@ -59,6 +59,8 @@ class AuthorizationError(DomainError): class InfrastructureError(OSAError): """Base class for infrastructure/system errors.""" + container_logs: str | None = None + class StorageUnavailableError(InfrastructureError): """Storage backend (database, object store) is unavailable.""" @@ -72,6 +74,18 @@ class ConfigurationError(InfrastructureError): """System misconfiguration detected.""" +class TransientError(InfrastructureError): + """Temporary failure — worker retries with backoff.""" + + +class PermanentError(InfrastructureError): + """Unrecoverable failure — worker gives up.""" + + +class OOMError(PermanentError): + """Container killed by out-of-memory. HookService intercepts for memory retry.""" + + # ============================================================================= # Event Processing Errors # ============================================================================= diff --git a/server/osa/domain/shared/event.py b/server/osa/domain/shared/event.py index 0fc19b3..de5ef7c 100644 --- a/server/osa/domain/shared/event.py +++ b/server/osa/domain/shared/event.py @@ -130,6 +130,7 @@ class Delivery: id: str event: "Event" + retry_count: int = 0 @dataclass(frozen=True) @@ -260,6 +261,16 @@ async def handle_batch(self, events: list[E]) -> None: for event in events: await self.handle(event) + async def on_exhausted(self, event: E) -> None: + """Called when delivery retries are exhausted or failure is permanent. + + Override to perform cleanup or accounting when an event will never + be successfully processed. Default: no-op. + + Args: + event: The event that could not be processed. 
+ """ + # --- Schedule --- diff --git a/server/osa/domain/shared/model/source.py b/server/osa/domain/shared/model/source.py index 9d4b1ef..47de498 100644 --- a/server/osa/domain/shared/model/source.py +++ b/server/osa/domain/shared/model/source.py @@ -55,7 +55,7 @@ class IngestSource(_RecordSourceBase): """Record originated from an automated ingest run.""" type: Literal["ingest"] = "ingest" - ingest_run_srn: str + ingest_run_id: str upstream_source: str diff --git a/server/osa/domain/shared/outbox.py b/server/osa/domain/shared/outbox.py index eb508fc..8659f3c 100644 --- a/server/osa/domain/shared/outbox.py +++ b/server/osa/domain/shared/outbox.py @@ -1,5 +1,6 @@ """Outbox - domain service for reliable event delivery.""" +from datetime import datetime from typing import TypeVar from osa.domain.shared.event import ClaimResult, Event @@ -24,7 +25,7 @@ class Outbox(Service): _repo: EventRepository _registry: SubscriptionRegistry - async def append(self, event: Event) -> None: + async def append(self, event: Event, *, deliver_after: datetime | None = None) -> None: """Add an event to the outbox for delivery. Creates one delivery row per consumer group subscribed to this event type. @@ -32,10 +33,13 @@ async def append(self, event: Event) -> None: Args: event: The event to append. + deliver_after: If set, deliveries won't be claimed until this time. """ event_type_name = type(event).__name__ consumer_groups = self._registry.get(event_type_name, set()) - await self._repo.save_with_deliveries(event, consumer_groups=consumer_groups) + await self._repo.save_with_deliveries( + event, consumer_groups=consumer_groups, deliver_after=deliver_after + ) async def claim( self, @@ -80,6 +84,7 @@ async def mark_failed_with_retry( delivery_id: str, error: str, max_retries: int, + deliver_after: datetime | None = None, ) -> None: """Mark a delivery as failed, with retry logic. @@ -90,8 +95,12 @@ async def mark_failed_with_retry( delivery_id: The delivery row ID. error: Error message. max_retries: Maximum retry attempts before marking as failed. + deliver_after: If set, the delivery won't be eligible for claiming + until this timestamp. Used for transient resource backoff. """ - await self._repo.mark_failed_with_retry(delivery_id, error=error, max_retries=max_retries) + await self._repo.mark_failed_with_retry( + delivery_id, error=error, max_retries=max_retries, deliver_after=deliver_after + ) async def reset_stale_claims(self, timeout_seconds: float) -> int: """Reset deliveries that have been claimed for too long. diff --git a/server/osa/domain/shared/port/event_repository.py b/server/osa/domain/shared/port/event_repository.py index 8a222d1..f963f3c 100644 --- a/server/osa/domain/shared/port/event_repository.py +++ b/server/osa/domain/shared/port/event_repository.py @@ -1,5 +1,6 @@ """EventRepository port - pure CRUD for event persistence.""" +from datetime import datetime from typing import Protocol, TypeVar from osa.domain.shared.event import ClaimResult, Event, EventId @@ -18,6 +19,7 @@ async def save_with_deliveries( self, event: Event, consumer_groups: set[str], + deliver_after: datetime | None = None, ) -> None: """Save event to the append-only log and create delivery rows. @@ -25,6 +27,7 @@ async def save_with_deliveries( event: The event to persist. consumer_groups: Set of consumer group names to create deliveries for. If empty, the event is saved without any delivery rows (audit-only). + deliver_after: If set, deliveries won't be claimed until this time. """ ... 
@@ -123,6 +126,7 @@ async def mark_failed_with_retry( delivery_id: str, error: str, max_retries: int, + deliver_after: datetime | None = None, ) -> None: """Mark a delivery as failed with retry logic. diff --git a/server/osa/domain/shared/port/ingester_runner.py b/server/osa/domain/shared/port/ingester_runner.py index c4bb7a0..4b51d14 100644 --- a/server/osa/domain/shared/port/ingester_runner.py +++ b/server/osa/domain/shared/port/ingester_runner.py @@ -20,6 +20,8 @@ class IngesterInputs: """Inputs for an ingester container run.""" convention_srn: ConventionSRN + ingest_run_id: str = "" + batch_index: int = 0 config: dict[str, Any] | None = None since: datetime | None = None limit: int | None = None @@ -46,3 +48,19 @@ async def run( files_dir: Path, work_dir: Path, ) -> IngesterOutput: ... + + async def capture_logs(self, run_id: str) -> str: + """Capture recent container logs for a run. + + Returns the last few lines of container/pod output, or empty string + if logs are unavailable. Used for failure diagnostics. + """ + ... + + async def has_capacity(self) -> bool: + """Check whether the cluster can schedule more Jobs. + + Returns False if there are pending (unschedulable) Jobs in the namespace. + Used by the ingester to avoid submitting work that will just timeout. + """ + ... diff --git a/server/osa/domain/validation/model/entity.py b/server/osa/domain/validation/model/entity.py index fdd35ca..9436895 100644 --- a/server/osa/domain/validation/model/entity.py +++ b/server/osa/domain/validation/model/entity.py @@ -24,8 +24,6 @@ def summary(self) -> HookStatus | None: if not self.results: return None statuses = [r.status for r in self.results] - if HookStatus.FAILED in statuses: - return HookStatus.FAILED if HookStatus.REJECTED in statuses: return HookStatus.REJECTED return HookStatus.PASSED diff --git a/server/osa/domain/validation/model/hook_result.py b/server/osa/domain/validation/model/hook_result.py index ba4b959..7841d6f 100644 --- a/server/osa/domain/validation/model/hook_result.py +++ b/server/osa/domain/validation/model/hook_result.py @@ -10,8 +10,6 @@ class HookStatus(StrEnum): PASSED = "passed" REJECTED = "rejected" - FAILED = "failed" - OOM = "oom" class ProgressEntry(ValueObject): @@ -31,8 +29,3 @@ class HookResult(ValueObject): error_message: str | None = None progress: list[ProgressEntry] = Field(default_factory=list) duration_seconds: float - - @property - def oom_killed(self) -> bool: - """Whether this hook was killed by an out-of-memory condition.""" - return self.status == HookStatus.OOM diff --git a/server/osa/domain/validation/port/hook_runner.py b/server/osa/domain/validation/port/hook_runner.py index 55ff484..6ee2c89 100644 --- a/server/osa/domain/validation/port/hook_runner.py +++ b/server/osa/domain/validation/port/hook_runner.py @@ -43,3 +43,12 @@ async def run( ``input/`` is ephemeral (cleaned after run); ``output/`` persists for later reading. """ ... + + @abstractmethod + async def capture_logs(self, run_id: str) -> str: + """Capture recent container logs for a run. + + Returns the last few lines of container/pod output, or empty string + if logs are unavailable. Used for failure diagnostics. + """ + ... 
diff --git a/server/osa/domain/validation/service/hook.py b/server/osa/domain/validation/service/hook.py index 3c770a4..4718578 100644 --- a/server/osa/domain/validation/service/hook.py +++ b/server/osa/domain/validation/service/hook.py @@ -13,6 +13,7 @@ from collections.abc import Iterable from pathlib import Path +from osa.domain.shared.error import OOMError from osa.domain.shared.model.hook import HookDefinition from osa.domain.shared.service import Service from osa.domain.validation.model.batch_outcome import ( @@ -81,16 +82,15 @@ async def run_hook( config=inputs.config, ) - result = await self.hook_runner.run(current_hook, attempt_inputs, work_dir) - total_duration += result.duration_seconds - - # Read any output written by this attempt - new_outcomes = _read_output_dir(work_dir) - for rid, outcome in new_outcomes.items(): - if rid not in outcomes: - outcomes[rid] = outcome + try: + result = await self.hook_runner.run(current_hook, attempt_inputs, work_dir) + except OOMError: + # Read any partial output written before OOM + new_outcomes = _read_output_dir(work_dir) + for rid, outcome in new_outcomes.items(): + if rid not in outcomes: + outcomes[rid] = outcome - if result.oom_killed: # Checkpoint what we have so far await self.hook_storage.write_checkpoint(work_dir, outcomes) @@ -118,17 +118,19 @@ async def run_hook( error=f"OOM after {MAX_OOM_RETRIES} retries (last limit: {current_hook.runtime.limits.memory})", ) await self.hook_storage.write_batch_outcomes(work_dir, outcomes) - return HookResult( - hook_name=hook.name, - status=HookStatus.OOM, - error_message=f"OOM exhausted after {MAX_OOM_RETRIES} retries", - duration_seconds=total_duration, - ) - elif result.status == HookStatus.FAILED: - # Non-OOM failure — no retry - await self.hook_storage.write_batch_outcomes(work_dir, outcomes) - return result - elif result.status == HookStatus.REJECTED: + raise + # Non-OOM exceptions (TransientError, PermanentError, etc.) 
+ # propagate uncaught to the worker layer + + total_duration += result.duration_seconds + + # Read any output written by this attempt + new_outcomes = _read_output_dir(work_dir) + for rid, outcome in new_outcomes.items(): + if rid not in outcomes: + outcomes[rid] = outcome + + if result.status == HookStatus.REJECTED: # Rejection — no retry, propagate status await self.hook_storage.write_batch_outcomes(work_dir, outcomes) return HookResult( diff --git a/server/osa/domain/validation/service/validation.py b/server/osa/domain/validation/service/validation.py index a2dcdee..a675f5b 100644 --- a/server/osa/domain/validation/service/validation.py +++ b/server/osa/domain/validation/service/validation.py @@ -85,12 +85,13 @@ async def run_hooks( for hook in hooks: work_dir = self.hook_storage.get_hook_output_dir(deposition_srn, hook.name) - result = await hook_service.run_hook(hook, inputs, work_dir) - hook_results.append(result) - - if result.status in (HookStatus.FAILED, HookStatus.OOM): + try: + result = await hook_service.run_hook(hook, inputs, work_dir) + except Exception: overall_status = RunStatus.FAILED break + hook_results.append(result) + if result.status == HookStatus.REJECTED: overall_status = RunStatus.REJECTED break diff --git a/server/osa/infrastructure/event/worker.py b/server/osa/infrastructure/event/worker.py index 7721c2a..202abf2 100644 --- a/server/osa/infrastructure/event/worker.py +++ b/server/osa/infrastructure/event/worker.py @@ -1,9 +1,9 @@ """Worker and WorkerPool for pull-based event processing.""" import asyncio -import logging from contextlib import AsyncExitStack from dataclasses import dataclass, field +from datetime import UTC, datetime, timedelta from typing import TYPE_CHECKING, Any, NewType if TYPE_CHECKING: @@ -13,7 +13,7 @@ from apscheduler.triggers.cron import CronTrigger from dishka import AsyncContainer from osa.domain.auth.model.identity import Identity, System -from osa.domain.shared.error import SkippedEvents +from osa.domain.shared.error import PermanentError, SkippedEvents, TransientError from osa.domain.shared.event import ( EventHandler, Schedule, @@ -22,9 +22,10 @@ WorkerStatus, ) from osa.domain.shared.outbox import Outbox +from osa.infrastructure.logging import get_logger from osa.util.di.scope import Scope -logger = logging.getLogger(__name__) +logger = get_logger(__name__) @dataclass @@ -136,7 +137,7 @@ async def _run(self) -> None: logger.info(f"Worker '{self.name}' cancelled") raise except Exception as e: - logger.exception(f"Worker '{self.name}' crashed: {e}") + logger.error(f"Worker '{self.name}' crashed: {e}") self._state.error = e raise finally: @@ -171,8 +172,9 @@ async def _poll_once(self) -> bool: self._state.current_batch = result.events self._state.last_claim_at = result.claimed_at + handler = await scope.get(self._handler_type) + try: - handler = await scope.get(self._handler_type) events = result.events if self._batch_size > 1: @@ -187,9 +189,7 @@ async def _poll_once(self) -> bool: self._state.processed_count += len(result.deliveries) except SkippedEvents as e: - logger.warning( - f"Worker '{self.name}' skipping {len(e.event_ids)} events: {e.reason}" - ) + logger.warn(f"Worker '{self.name}' skipping {len(e.event_ids)} events: {e.reason}") skipped_set = set(e.event_ids) for delivery in result.deliveries: if delivery.event.id in skipped_set: @@ -198,16 +198,92 @@ async def _poll_once(self) -> bool: await outbox.mark_delivered(delivery.id) self._state.processed_count += len(result.deliveries) - len(e.event_ids) + except TransientError 
as e: + self._state.failed_count += len(result.deliveries) + self._state.error = e + for delivery in result.deliveries: + exhausted = delivery.retry_count + 1 >= self._max_retries + if exhausted: + logger.warn( + "Worker '{name}' transient retries exhausted: {error}", + name=self.name, + error=str(e), + ) + try: + await handler.on_exhausted(delivery.event) + except Exception as exhausted_err: + logger.error( + "Worker '{name}' on_exhausted failed: {error}", + name=self.name, + error=str(exhausted_err), + ) + await outbox.mark_failed(delivery.id, str(e)) + else: + backoff_seconds = min(300, 60 * (2**delivery.retry_count)) + deliver_after = datetime.now(UTC) + timedelta(seconds=backoff_seconds) + logger.warn( + "Worker '{name}' transient failure: {error} " + "(attempt={attempt}, next_retry=+{backoff}s)", + name=self.name, + error=str(e), + attempt=delivery.retry_count, + backoff=backoff_seconds, + ) + await outbox.mark_failed_with_retry( + delivery.id, + str(e), + max_retries=self._max_retries, + deliver_after=deliver_after, + ) + + except PermanentError as e: + self._state.failed_count += len(result.deliveries) + self._state.error = e + logger.error( + "Worker '{name}' permanent failure: {error}", + name=self.name, + error=str(e), + ) + for delivery in result.deliveries: + try: + await handler.on_exhausted(delivery.event) + except Exception as exhausted_err: + logger.error( + "Worker '{name}' on_exhausted failed: {error}", + name=self.name, + error=str(exhausted_err), + ) + await outbox.mark_failed(delivery.id, str(e)) + except Exception as e: self._state.failed_count += len(result.deliveries) self._state.error = e - logger.error(f"Worker '{self.name}' batch failed: {e}") + logger.error( + "Worker '{name}' batch failed: {error}", + name=self.name, + error=str(e), + ) for delivery in result.deliveries: - await outbox.mark_failed_with_retry( - delivery.id, - str(e), - max_retries=self._max_retries, - ) + exhausted = delivery.retry_count + 1 >= self._max_retries + if exhausted: + try: + await handler.on_exhausted(delivery.event) + except Exception as exhausted_err: + logger.error( + "Worker '{name}' on_exhausted failed: {error}", + name=self.name, + error=str(exhausted_err), + ) + await outbox.mark_failed(delivery.id, str(e)) + else: + backoff_seconds = min(30, 5 ** (delivery.retry_count + 1)) + deliver_after = datetime.now(UTC) + timedelta(seconds=backoff_seconds) + await outbox.mark_failed_with_retry( + delivery.id, + str(e), + max_retries=self._max_retries, + deliver_after=deliver_after, + ) finally: self._state.current_batch = [] @@ -406,7 +482,9 @@ async def _run_schedule(self, config: "ScheduleConfig") -> None: self._schedule_failures[config.id] = failures logger.error(f"Failed to run schedule {config.id} (failures: {failures}): {e}") if failures >= 5: - logger.critical(f"Schedule {config.id} has failed {failures} consecutive times") + logger.error( + f"CRITICAL: Schedule {config.id} has failed {failures} consecutive times" + ) async def __aenter__(self) -> "WorkerPool": """Start the pool as async context manager.""" diff --git a/server/osa/infrastructure/ingest/di.py b/server/osa/infrastructure/ingest/di.py index 5992072..fbc4c8f 100644 --- a/server/osa/infrastructure/ingest/di.py +++ b/server/osa/infrastructure/ingest/di.py @@ -50,7 +50,7 @@ def get_ingest_service( # Ingest storage — default (filesystem, for local/Docker) @provide(scope=Scope.APP) def get_ingest_storage(self, layout: StorageLayout) -> IngestStoragePort: - return FilesystemIngestStorage(layout=layout) # type: 
ignore[return-value] + return FilesystemIngestStorage(layout=layout) # Ingest storage — K8s (S3 via aioboto3, reuses S3Client from RunnerProvider) @provide(when=K8S, scope=Scope.APP) @@ -59,7 +59,7 @@ def get_ingest_storage_s3( ) -> IngestStoragePort: from osa.infrastructure.s3.ingest_storage import S3IngestStorage - return S3IngestStorage( # type: ignore[return-value] + return S3IngestStorage( s3=s3, layout=layout, data_mount_path=config.runner.k8s.data_mount_path, diff --git a/server/osa/infrastructure/k8s/errors.py b/server/osa/infrastructure/k8s/errors.py index 3387478..1d74523 100644 --- a/server/osa/infrastructure/k8s/errors.py +++ b/server/osa/infrastructure/k8s/errors.py @@ -3,27 +3,27 @@ Maps kubernetes-asyncio ApiException status codes to OSA error types. """ -from osa.domain.shared.error import ConfigurationError, InfrastructureError, OSAError +from osa.domain.shared.error import OSAError, PermanentError, TransientError def classify_api_error(exc: Exception) -> OSAError: """Classify a K8s API error by HTTP status code. - - 403 → ConfigurationError (RBAC misconfiguration, not retried) - - 404 → ConfigurationError (namespace/resource missing, not retried) - - 500, 503 → InfrastructureError (transient, retried by outbox) - - Other → InfrastructureError + - 403 → PermanentError (RBAC misconfiguration, not retried) + - 404 → PermanentError (namespace/resource missing, not retried) + - 500, 503 → TransientError (cluster pressure, retried with backoff) + - Other → TransientError """ status = getattr(exc, "status", 0) reason = getattr(exc, "reason", str(exc)) if status == 403: - return ConfigurationError( + return PermanentError( f"K8s RBAC permission denied: {reason}. " "Check ServiceAccount permissions for the OSA namespace." ) if status == 404: - return ConfigurationError( + return PermanentError( f"K8s resource not found: {reason}. Check that the namespace and resources exist." 
) - return InfrastructureError(f"K8s API error ({status}): {reason}") + return TransientError(f"K8s API error ({status}): {reason}") diff --git a/server/osa/infrastructure/k8s/ingester_runner.py b/server/osa/infrastructure/k8s/ingester_runner.py index fa29ec1..721739c 100644 --- a/server/osa/infrastructure/k8s/ingester_runner.py +++ b/server/osa/infrastructure/k8s/ingester_runner.py @@ -4,17 +4,22 @@ import asyncio import json -import logging import time from pathlib import Path from typing import TYPE_CHECKING from osa.config import K8sConfig -from osa.domain.shared.error import ExternalServiceError, InfrastructureError +from osa.domain.shared.error import ( + InfrastructureError, + OOMError, + PermanentError, + TransientError, +) from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.model.srn import ConventionSRN from osa.domain.shared.port.ingester_runner import IngesterInputs, IngesterOutput, IngesterRunner from osa.infrastructure.k8s.errors import classify_api_error +from osa.infrastructure.logging import get_logger from osa.infrastructure.k8s.naming import job_name, label_value, sanitize_label from osa.infrastructure.runner_utils import ( relative_path, @@ -22,11 +27,11 @@ ) if TYPE_CHECKING: - from kubernetes_asyncio.client import ApiClient, BatchV1Api, CoreV1Api, V1Job + from kubernetes_asyncio.client import ApiClient, V1Job from osa.infrastructure.s3.client import S3Client -logger = logging.getLogger(__name__) +logger = get_logger(__name__) SCHEDULING_TIMEOUT = 120 @@ -40,11 +45,13 @@ class K8sIngesterRunner(IngesterRunner): - Three volume mounts: input (ro), output (rw), files (rw) - Higher resource defaults (3600s, 4g) - Source-specific env vars (OSA_FILES, OSA_SINCE, etc.) - - Errors raise ExternalServiceError (not returned as result values) """ def __init__(self, api_client: ApiClient, config: K8sConfig, s3: S3Client) -> None: - self._api_client = api_client + from kubernetes_asyncio.client import BatchV1Api, CoreV1Api + + self._batch_api = BatchV1Api(api_client) + self._core_api = CoreV1Api(api_client) self._config = config self._s3 = s3 @@ -52,6 +59,47 @@ def _s3_prefix(self, work_dir: Path, subdir: str) -> str: """Convert a PVC path + subdir to an S3 key prefix.""" return f"{relative_path(work_dir, self._config.data_mount_path)}/{subdir}" + async def has_capacity(self) -> bool: + """Check for unschedulable pods in the namespace. + + Only triggers backpressure when a pod has PodScheduled=False with + reason=Unschedulable, meaning the cluster genuinely can't place it. + Pods that are Pending but actively scheduling (image pull, node + assignment) are normal and should not block ingestion. 
+ """ + namespace = self._config.namespace + try: + pod_list = await self._core_api.list_namespaced_pod( + namespace, field_selector="status.phase=Pending" + ) + for pod in pod_list.items: + for condition in pod.status.conditions or []: + if condition.type == "PodScheduled" and condition.reason == "Unschedulable": + return False + return True + except Exception as e: + logger.warn( + "Failed to check cluster capacity: {error} — assuming capacity", error=str(e) + ) + return True + + async def capture_logs(self, run_id: str) -> str: + """Capture recent pod logs for an ingester Job identified by run_id.""" + namespace = self._config.namespace + label_selector = f"osa.io/role=ingester,osa.io/ingest-run-id={run_id}" + try: + pod_list = await self._core_api.list_namespaced_pod( + namespace, label_selector=label_selector + ) + for pod in pod_list.items: + log_str = await self._core_api.read_namespaced_pod_log( + pod.metadata.name, namespace, tail_lines=10 + ) + return log_str.strip() if log_str else "" + except Exception: + return "" + return "" + async def run( self, ingester: IngesterDefinition, @@ -59,18 +107,6 @@ async def run( files_dir: Path, work_dir: Path, ) -> IngesterOutput: - try: - from kubernetes_asyncio.client import BatchV1Api, CoreV1Api - except ImportError: - from osa.domain.shared.error import ConfigurationError - - raise ConfigurationError( - "kubernetes-asyncio is required for K8s runner. Install with: pip install osa[k8s]" - ) - - batch_api = BatchV1Api(self._api_client) - core_api = CoreV1Api(self._api_client) - # Write input files to S3 (container reads them via PVC/S3 CSI) input_prefix = self._s3_prefix(work_dir, "input") @@ -81,26 +117,14 @@ async def run( if inputs.session: await self._s3.put_object(f"{input_prefix}/session.json", json.dumps(inputs.session)) - return await self._run_job( - batch_api, - core_api, - ingester, - inputs, - work_dir, - files_dir, - convention_srn=inputs.convention_srn, - ) + return await self._run_job(ingester, inputs, work_dir, files_dir) async def _run_job( self, - batch_api: BatchV1Api, - core_api: CoreV1Api, ingester: IngesterDefinition, inputs: IngesterInputs, work_dir: Path, files_dir: Path, - *, - convention_srn: ConventionSRN | None = None, ) -> IngesterOutput: """Core Job lifecycle for ingester execution.""" namespace = self._config.namespace @@ -109,65 +133,78 @@ async def _run_job( try: # Check for existing Jobs existing = await self._check_existing_job( - batch_api, namespace, convention_srn, ingester.digest + namespace, inputs.convention_srn, ingester.digest ) if existing == "succeeded": + logger.info("Reusing output from completed ingester Job") return await self._parse_source_output(work_dir, files_dir) if existing and existing.startswith("active:"): job_name_to_watch = existing.split(":", 1)[1] + logger.info( + "Attaching to running ingester Job: {job_name}", + job_name=job_name_to_watch, + ) else: + # Clear stale output and files from previous failed runs + output_prefix = self._s3_prefix(work_dir, "output") + await self._s3.delete_objects(output_prefix) + files_prefix = relative_path(files_dir, self._config.data_mount_path) + await self._s3.delete_objects(files_prefix) + spec = self._build_job_spec( ingester, work_dir=work_dir, files_dir=files_dir, inputs=inputs, - convention_srn=convention_srn, + convention_srn=inputs.convention_srn, ) job_name_to_watch = spec.metadata.name - await batch_api.create_namespaced_job(namespace, spec) + await self._batch_api.create_namespaced_job(namespace, spec) logger.info( - "Created K8s 
ingester Job", - extra={ - "job_name": job_name_to_watch, - "namespace": namespace, - "image": f"{ingester.image}@{ingester.digest}", - }, + "Created K8s ingester Job: {job_name}", + job_name=job_name_to_watch, + namespace=namespace, + image=f"{ingester.image}@{ingester.digest}", ) # Phase 1: Scheduling - await self._wait_for_scheduling(core_api, job_name_to_watch, namespace) + await self._wait_for_scheduling(job_name_to_watch, namespace) - # Phase 2: Completion - result = await self._wait_for_completion( - batch_api, + # Phase 2: Completion (raises on failure) + await self._wait_for_completion( job_name_to_watch, namespace, timeout_seconds=ingester.limits.timeout_seconds + 30, ) - if result == "succeeded": - output = await self._parse_source_output(work_dir, files_dir) - logger.info( - "Source completed", - extra={ - "job_name": job_name_to_watch, - "record_count": len(output.records), - "has_session": output.session is not None, - }, - ) - return output + output = await self._parse_source_output(work_dir, files_dir) + logger.info( + "Ingester batch completed: {job_name} ({record_count} records)", + job_name=job_name_to_watch, + record_count=len(output.records), + has_session=output.session is not None, + ) + return output - # Failed — diagnose and raise - await self._diagnose_and_raise(core_api, job_name_to_watch, namespace, ingester, result) - # unreachable but satisfies type checker - raise ExternalServiceError("Source failed") + except InfrastructureError as e: + # Capture logs before cleanup destroys the pod + if job_name_to_watch: + logs = await self._capture_pod_logs(job_name_to_watch, namespace) + e.container_logs = logs + if logs: + logger.error( + "Job {job_name} failed — container logs:\n{logs}", + job_name=job_name_to_watch, + logs=logs, + ) + raise finally: if job_name_to_watch: - await self._cleanup_job(batch_api, job_name_to_watch, namespace) + await self._cleanup_job(job_name_to_watch, namespace) async def _parse_source_output(self, work_dir: Path, files_dir: Path) -> IngesterOutput: from osa.infrastructure.runner_utils import ( @@ -182,7 +219,6 @@ async def _parse_source_output(self, work_dir: Path, files_dir: Path) -> Ingeste async def _check_existing_job( self, - batch_api: BatchV1Api, namespace: str, convention_srn: ConventionSRN | None, digest: str = "", @@ -195,7 +231,9 @@ async def _check_existing_job( label_selector = ",".join(label_parts) try: - job_list = await batch_api.list_namespaced_job(namespace, label_selector=label_selector) + job_list = await self._batch_api.list_namespaced_job( + namespace, label_selector=label_selector + ) except Exception as exc: raise classify_api_error(exc) from exc @@ -246,6 +284,9 @@ def _build_job_spec( } if convention_srn is not None: labels["osa.io/convention"] = label_value(convention_srn) + if inputs and inputs.ingest_run_id: + labels["osa.io/ingest-run-id"] = inputs.ingest_run_id + labels["osa.io/ingest-run-batch"] = str(inputs.batch_index) env = [ V1EnvVar(name="OSA_IN", value="/osa/in"), @@ -334,7 +375,6 @@ def _relative_path(self, path: Path) -> str: async def _wait_for_scheduling( self, - core_api: CoreV1Api, job_name: str, namespace: str, *, @@ -346,7 +386,7 @@ async def _wait_for_scheduling( while time.monotonic() < deadline: try: - pod_list = await core_api.list_namespaced_pod( + pod_list = await self._core_api.list_namespaced_pod( namespace, label_selector=label_selector ) except Exception as exc: @@ -356,13 +396,13 @@ async def _wait_for_scheduling( phase = pod.status.phase if phase == "Failed": reason = 
getattr(pod.status, "reason", None) or "Unknown" - raise InfrastructureError(f"Pod failed during scheduling: {reason}") + raise TransientError(f"Pod failed during scheduling: {reason}") if phase == "Pending" and pod.status.container_statuses: for cs in pod.status.container_statuses: waiting = getattr(cs.state, "waiting", None) if waiting and waiting.reason in ("ImagePullBackOff", "ErrImagePull"): - raise InfrastructureError( + raise PermanentError( f"Image pull failed: {waiting.reason}: {getattr(waiting, 'message', '')}" ) @@ -371,92 +411,106 @@ async def _wait_for_scheduling( await asyncio.sleep(poll_interval) - raise InfrastructureError( - f"Pod scheduling timeout after {timeout_seconds}s for Job {job_name}" - ) + raise TransientError(f"Pod scheduling timeout after {timeout_seconds}s for Job {job_name}") async def _wait_for_completion( self, - batch_api: BatchV1Api, job_name: str, namespace: str, *, timeout_seconds: float = 3630, poll_interval: float = 5.0, - ) -> str: + ) -> None: + """Wait for Job to complete. Returns on success, raises on failure.""" deadline = time.monotonic() + timeout_seconds while time.monotonic() < deadline: try: - job = await batch_api.read_namespaced_job(job_name, namespace) + job = await self._batch_api.read_namespaced_job(job_name, namespace) except Exception as exc: raise classify_api_error(exc) from exc if job.status.succeeded: - return "succeeded" + return if job.status.conditions: for condition in job.status.conditions: if condition.type == "Failed" and condition.status == "True": - return f"failed:{getattr(condition, 'reason', 'Unknown')}" + failure_reason = getattr(condition, "reason", "Unknown") + raise await self._diagnose_failure(job_name, namespace, failure_reason) if condition.type == "Complete" and condition.status == "True": - return "succeeded" + return if job.status.failed: - return "failed:BackoffLimitExceeded" + raise await self._diagnose_failure(job_name, namespace, "BackoffLimitExceeded") await asyncio.sleep(poll_interval) # Timed out — poll once more to catch last-millisecond completions try: - job = await batch_api.read_namespaced_job(job_name, namespace) + job = await self._batch_api.read_namespaced_job(job_name, namespace) if job.status.succeeded: - return "succeeded" + return except Exception: pass - return "failed:WatchTimeout" + raise TransientError(f"Watch timeout waiting for ingester Job {job_name} completion") + + async def _capture_pod_logs(self, job_name: str, namespace: str) -> str: + """Capture tail logs from a Job's pod. 
Returns empty if unavailable."""
+        try:
+            pod_list = await self._core_api.list_namespaced_pod(
+                namespace, label_selector=f"job-name={job_name}"
+            )
+            for pod in pod_list.items:
+                log_str = await self._core_api.read_namespaced_pod_log(
+                    pod.metadata.name, namespace, tail_lines=10
+                )
+                return log_str.strip() if log_str else ""
+        except Exception:
+            return ""
+        return ""
 
-    async def _diagnose_and_raise(
+    async def _diagnose_failure(
         self,
-        core_api: CoreV1Api,
         job_name: str,
         namespace: str,
-        ingester: IngesterDefinition,
         failure_info: str,
-    ) -> None:
-        """Determine failure reason and raise appropriate error."""
+    ) -> InfrastructureError:
+        """Inspect pod status and return the appropriate exception."""
         if "DeadlineExceeded" in failure_info:
-            raise ExternalServiceError(
-                f"Ingester timed out after {ingester.limits.timeout_seconds}s"
-            )
+            return TransientError("Ingester timed out (deadline exceeded)")
 
         try:
             label_selector = f"job-name={job_name}"
-            pod_list = await core_api.list_namespaced_pod(namespace, label_selector=label_selector)
+            pod_list = await self._core_api.list_namespaced_pod(
+                namespace, label_selector=label_selector
+            )
             for pod in pod_list.items:
                 if pod.status.container_statuses:
                     for cs in pod.status.container_statuses:
                         terminated = getattr(cs.state, "terminated", None)
                         if terminated:
                             if getattr(terminated, "reason", None) == "OOMKilled":
-                                raise ExternalServiceError("Source killed by OOM")
+                                return OOMError("Source killed by OOM")
                             exit_code = getattr(terminated, "exit_code", -1)
                             if exit_code != 0:
-                                raise ExternalServiceError(f"Source exited with code {exit_code}")
-        except ExternalServiceError:
-            raise
+                                # Transient: ingester non-zero exit is often an upstream
+                                # API failure (500, rate limit), not a code bug.
+                                # Contrast with hooks where non-zero = PermanentError.
+ return TransientError(f"Source exited with code {exit_code}") except Exception: pass - raise ExternalServiceError(f"Source failed: {failure_info}") + return PermanentError(f"Source failed: {failure_info}") - async def _cleanup_job(self, batch_api: BatchV1Api, job_name: str, namespace: str) -> None: + async def _cleanup_job(self, job_name: str, namespace: str) -> None: try: - await batch_api.delete_namespaced_job( + await self._batch_api.delete_namespaced_job( job_name, namespace, propagation_policy="Background", ) + logger.info("Cleaned up K8s ingester Job: {job_name}", job_name=job_name) except Exception as exc: if getattr(exc, "status", None) == 404: return - logger.warning("Failed to clean up K8s ingester Job", extra={"job_name": job_name}) + logger.warn("Failed to clean up K8s ingester Job: {job_name}", job_name=job_name) diff --git a/server/osa/infrastructure/k8s/runner.py b/server/osa/infrastructure/k8s/runner.py index 6d5ca8a..23e59b7 100644 --- a/server/osa/infrastructure/k8s/runner.py +++ b/server/osa/infrastructure/k8s/runner.py @@ -4,18 +4,23 @@ import asyncio import json -import logging import time from pathlib import Path from typing import TYPE_CHECKING from osa.config import K8sConfig -from osa.domain.shared.error import InfrastructureError +from osa.domain.shared.error import ( + InfrastructureError, + OOMError, + PermanentError, + TransientError, +) from osa.domain.shared.model.hook import HookDefinition from osa.domain.validation.model.hook_result import HookResult, HookStatus from osa.domain.validation.port.hook_runner import HookInputs, HookRunner from osa.infrastructure.k8s.errors import classify_api_error -from osa.infrastructure.k8s.naming import job_name, label_value +from osa.infrastructure.logging import get_logger +from osa.infrastructure.k8s.naming import job_name from osa.infrastructure.runner_utils import ( detect_rejection, relative_path, @@ -23,11 +28,11 @@ ) if TYPE_CHECKING: - from kubernetes_asyncio.client import ApiClient, BatchV1Api, CoreV1Api, V1Job + from kubernetes_asyncio.client import ApiClient, V1Job from osa.infrastructure.s3.client import S3Client -logger = logging.getLogger(__name__) +logger = get_logger(__name__) SCHEDULING_TIMEOUT = 120 # seconds to wait for pod to leave Pending @@ -43,7 +48,10 @@ class K8sHookRunner(HookRunner): """ def __init__(self, api_client: ApiClient, config: K8sConfig, s3: S3Client) -> None: - self._api_client = api_client + from kubernetes_asyncio.client import BatchV1Api, CoreV1Api + + self._batch_api = BatchV1Api(api_client) + self._core_api = CoreV1Api(api_client) self._config = config self._s3 = s3 @@ -51,24 +59,34 @@ def _s3_prefix(self, work_dir: Path, subdir: str) -> str: """Convert a PVC path + subdir to an S3 key prefix.""" return f"{relative_path(work_dir, self._config.data_mount_path)}/{subdir}" + async def capture_logs(self, run_id: str) -> str: + """Capture recent pod logs for a hook Job identified by run_id.""" + namespace = self._config.namespace + # Find the Job by label — run_id format is {uuid}_b{batch} + ingest_run_id = run_id.split("_b", 1)[0] + batch_index = run_id.split("_b", 1)[1] if "_b" in run_id else "0" + label_selector = ( + f"osa.io/ingest-run-id={ingest_run_id},osa.io/ingest-run-batch={batch_index}" + ) + try: + pod_list = await self._core_api.list_namespaced_pod( + namespace, label_selector=label_selector + ) + for pod in pod_list.items: + log_str = await self._core_api.read_namespaced_pod_log( + pod.metadata.name, namespace, tail_lines=10 + ) + return log_str.strip() if log_str else 
"" + except Exception: + return "" + return "" + async def run( self, hook: HookDefinition, inputs: HookInputs, work_dir: Path, ) -> HookResult: - try: - from kubernetes_asyncio.client import BatchV1Api, CoreV1Api - except ImportError: - from osa.domain.shared.error import ConfigurationError - - raise ConfigurationError( - "kubernetes-asyncio is required for K8s runner. Install with: pip install osa[k8s]" - ) - - batch_api = BatchV1Api(self._api_client) - core_api = CoreV1Api(self._api_client) - # Write input files to S3 (container reads them via PVC/S3 CSI) input_prefix = self._s3_prefix(work_dir, "input") # Write records.jsonl (unified batch contract) @@ -78,18 +96,10 @@ async def run( config = {**hook.runtime.config, **(inputs.config or {})} await self._s3.put_object(f"{input_prefix}/config.json", json.dumps(config)) - return await self._run_job( - batch_api, - core_api, - hook, - inputs, - work_dir, - ) + return await self._run_job(hook, inputs, work_dir) async def _run_job( self, - batch_api: BatchV1Api, - core_api: CoreV1Api, hook: HookDefinition, inputs: HookInputs, work_dir: Path, @@ -102,17 +112,24 @@ async def _run_job( job_name_to_watch = None try: - existing = await self._check_existing_job( - batch_api, namespace, hook.name, inputs.run_id - ) + existing = await self._check_existing_job(namespace, hook.name, inputs.run_id) if existing == "succeeded": - # Read output from completed Job + logger.info( + "Reusing output from completed hook Job (hook={hook_name}, run_id={run_id})", + hook_name=hook.name, + run_id=inputs.run_id, + ) return await self._parse_hook_result(hook, work_dir, start_time) if existing and existing.startswith("active:"): # Attach to running Job job_name_to_watch = existing.split(":", 1)[1] + logger.info( + "Attaching to running hook Job: {job_name} (hook={hook_name})", + job_name=job_name_to_watch, + hook_name=hook.name, + ) else: # Create new Job (no existing or failed) # Mount the parent of all per-record file dirs — works for @@ -129,40 +146,42 @@ async def _run_job( ) job_name_to_watch = spec.metadata.name - await batch_api.create_namespaced_job(namespace, spec) + await self._batch_api.create_namespaced_job(namespace, spec) logger.info( - "Created K8s Job", - extra={ - "job_name": job_name_to_watch, - "namespace": namespace, - "image": f"{hook.runtime.image}@{hook.runtime.digest}", - "hook_name": hook.name, - "run_id": inputs.run_id, - }, + "Created K8s hook Job: {job_name} (hook={hook_name}, run_id={run_id})", + job_name=job_name_to_watch, + hook_name=hook.name, + run_id=inputs.run_id, ) # Phase 1: Wait for scheduling - await self._wait_for_scheduling(core_api, job_name_to_watch, namespace) + await self._wait_for_scheduling(job_name_to_watch, namespace) - # Phase 2: Wait for completion - result = await self._wait_for_completion( - batch_api, + # Phase 2: Wait for completion (raises on failure) + await self._wait_for_completion( job_name_to_watch, namespace, timeout_seconds=hook.runtime.limits.timeout_seconds + 30, ) - if result == "succeeded": - return await self._parse_hook_result(hook, work_dir, start_time) + return await self._parse_hook_result(hook, work_dir, start_time) - # Job failed — determine why - return await self._diagnose_failure( - core_api, job_name_to_watch, namespace, hook, start_time, result - ) + except InfrastructureError as e: + # Capture logs before cleanup destroys the pod + if job_name_to_watch: + logs = await self._capture_pod_logs(job_name_to_watch, namespace) + e.container_logs = logs + if logs: + logger.error( + "Job {job_name} 
failed — container logs:\n{logs}", + job_name=job_name_to_watch, + logs=logs, + ) + raise finally: if job_name_to_watch: - await self._cleanup_job(batch_api, job_name_to_watch, namespace) + await self._cleanup_job(job_name_to_watch, namespace) async def _parse_hook_result( self, hook: HookDefinition, work_dir: Path, start_time: float @@ -193,7 +212,6 @@ async def _parse_hook_result( async def _check_existing_job( self, - batch_api: BatchV1Api, namespace: str, hook_name: str, run_id: str, @@ -205,9 +223,13 @@ async def _check_existing_job( "active:{job_name}" if a running Job exists None if no Job or only failed Jobs exist """ - label_selector = f"osa.io/hook={hook_name},osa.io/run-id={label_value(run_id)}" + ingest_run_id = run_id.split("_b", 1)[0] + batch_index = run_id.split("_b", 1)[1] if "_b" in run_id else "0" + label_selector = f"osa.io/hook={hook_name},osa.io/ingest-run-id={ingest_run_id},osa.io/ingest-run-batch={batch_index}" try: - job_list = await batch_api.list_namespaced_job(namespace, label_selector=label_selector) + job_list = await self._batch_api.list_namespaced_job( + namespace, label_selector=label_selector + ) except Exception as exc: raise classify_api_error(exc) from exc @@ -254,10 +276,13 @@ def _build_job_spec( input_subpath = f"{relative_work}/input" output_subpath = f"{relative_work}/output" + ingest_run_id = run_id.split("_b", 1)[0] + batch_index = run_id.split("_b", 1)[1] if "_b" in run_id else "0" labels = { "osa.io/role": "hook", "osa.io/hook": hook.name, - "osa.io/run-id": label_value(run_id), + "osa.io/ingest-run-id": ingest_run_id, + "osa.io/ingest-run-batch": batch_index, } mounts = [ @@ -355,7 +380,6 @@ def _relative_path(self, path: Path) -> str: async def _wait_for_scheduling( self, - core_api: CoreV1Api, job_name: str, namespace: str, *, @@ -368,7 +392,7 @@ async def _wait_for_scheduling( while time.monotonic() < deadline: try: - pod_list = await core_api.list_namespaced_pod( + pod_list = await self._core_api.list_namespaced_pod( namespace, label_selector=label_selector ) except Exception as exc: @@ -380,7 +404,7 @@ async def _wait_for_scheduling( # Check for eviction if phase == "Failed": reason = getattr(pod.status, "reason", None) or "Unknown" - raise InfrastructureError(f"Pod evicted or failed during scheduling: {reason}") + raise TransientError(f"Pod evicted or failed during scheduling: {reason}") # Check for image pull errors if phase == "Pending" and pod.status.container_statuses: @@ -388,135 +412,121 @@ async def _wait_for_scheduling( waiting = getattr(cs.state, "waiting", None) if waiting and waiting.reason in ("ImagePullBackOff", "ErrImagePull"): message = getattr(waiting, "message", "") - raise InfrastructureError( - f"Image pull failed: {waiting.reason}: {message}" - ) + raise PermanentError(f"Image pull failed: {waiting.reason}: {message}") if phase in ("Running", "Succeeded", "Failed"): return # Pod scheduled await asyncio.sleep(poll_interval) - raise InfrastructureError( - f"Pod scheduling timeout after {timeout_seconds}s for Job {job_name}" - ) + raise TransientError(f"Pod scheduling timeout after {timeout_seconds}s for Job {job_name}") async def _wait_for_completion( self, - batch_api: BatchV1Api, job_name: str, namespace: str, *, timeout_seconds: float = 330, poll_interval: float = 5.0, - ) -> str: - """Wait for Job to complete (Phase 2). Returns 'succeeded' or 'failed'.""" + ) -> None: + """Wait for Job to complete (Phase 2). 
Returns on success, raises on failure."""
         deadline = time.monotonic() + timeout_seconds
 
         while time.monotonic() < deadline:
             try:
-                job = await batch_api.read_namespaced_job(job_name, namespace)
+                job = await self._batch_api.read_namespaced_job(job_name, namespace)
             except Exception as exc:
                 raise classify_api_error(exc) from exc
 
             if job.status.succeeded:
-                return "succeeded"
+                return
 
             if job.status.conditions:
                 for condition in job.status.conditions:
                     if condition.type == "Failed" and condition.status == "True":
-                        return f"failed:{getattr(condition, 'reason', 'Unknown')}"
+                        failure_reason = getattr(condition, "reason", "Unknown")
+                        raise await self._diagnose_failure(job_name, namespace, failure_reason)
                     if condition.type == "Complete" and condition.status == "True":
-                        return "succeeded"
+                        return
 
             if job.status.failed:
-                return "failed:BackoffLimitExceeded"
+                raise await self._diagnose_failure(job_name, namespace, "BackoffLimitExceeded")
 
             await asyncio.sleep(poll_interval)
 
         # Timed out — poll once more
         try:
-            job = await batch_api.read_namespaced_job(job_name, namespace)
+            job = await self._batch_api.read_namespaced_job(job_name, namespace)
             if job.status.succeeded:
-                return "succeeded"
+                return
         except Exception:
             pass
 
-        return "failed:WatchTimeout"
+        raise TransientError(f"Watch timeout waiting for Job {job_name} completion")
+
+    async def _capture_pod_logs(self, job_name: str, namespace: str) -> str:
+        """Capture tail logs from a Job's pod. Returns empty if unavailable."""
+        try:
+            pod_list = await self._core_api.list_namespaced_pod(
+                namespace, label_selector=f"job-name={job_name}"
+            )
+            for pod in pod_list.items:
+                log_str = await self._core_api.read_namespaced_pod_log(
+                    pod.metadata.name, namespace, tail_lines=10
+                )
+                return log_str.strip() if log_str else ""
+        except Exception:
+            return ""
+        return ""
 
     async def _diagnose_failure(
         self,
-        core_api: CoreV1Api,
         job_name: str,
         namespace: str,
-        hook: HookDefinition,
-        start_time: float,
         failure_info: str,
-    ) -> HookResult:
-        """Determine failure reason from pod status."""
-        duration = time.monotonic() - start_time
-
-        # Check if DeadlineExceeded
+    ) -> InfrastructureError:
+        """Inspect pod status and return the appropriate exception."""
         if "DeadlineExceeded" in failure_info:
-            return HookResult(
-                hook_name=hook.name,
-                status=HookStatus.FAILED,
-                error_message="Hook timed out (deadline exceeded)",
-                duration_seconds=duration,
-            )
+            return TransientError("Hook timed out (deadline exceeded)")
 
-        # Check pod for OOM or exit code
         try:
             label_selector = f"job-name={job_name}"
-            pod_list = await core_api.list_namespaced_pod(namespace, label_selector=label_selector)
+            pod_list = await self._core_api.list_namespaced_pod(
+                namespace, label_selector=label_selector
+            )
             for pod in pod_list.items:
                 if pod.status.container_statuses:
                     for cs in pod.status.container_statuses:
                         terminated = getattr(cs.state, "terminated", None)
                         if terminated:
                             if getattr(terminated, "reason", None) == "OOMKilled":
-                                return HookResult(
-                                    hook_name=hook.name,
-                                    status=HookStatus.OOM,
-                                    error_message="Hook killed by OOM",
-                                    duration_seconds=duration,
-                                )
+                                return OOMError("Hook killed by OOM")
                             exit_code = getattr(terminated, "exit_code", -1)
                             if exit_code != 0:
-                                return HookResult(
-                                    hook_name=hook.name,
-                                    status=HookStatus.FAILED,
-                                    error_message=f"Hook exited with code {exit_code}",
-                                    duration_seconds=duration,
-                                )
+                                return PermanentError(f"Hook exited with code {exit_code}")
         except Exception:
             pass
 
-        return HookResult(
-            hook_name=hook.name,
-            status=HookStatus.FAILED,
-
error_message=f"Hook failed: {failure_info}", - duration_seconds=duration, - ) + return PermanentError(f"Hook failed: {failure_info}") async def _cleanup_job( self, - batch_api: BatchV1Api, job_name: str, namespace: str, ) -> None: """Delete a Job and its pods. Ignores 404 (already cleaned up).""" try: - await batch_api.delete_namespaced_job( + await self._batch_api.delete_namespaced_job( job_name, namespace, propagation_policy="Background", ) - logger.info("Cleaned up K8s Job", extra={"job_name": job_name}) + logger.info("Cleaned up K8s hook Job: {job_name}", job_name=job_name) except Exception as exc: if getattr(exc, "status", None) == 404: return # Already gone - logger.warning( - "Failed to clean up K8s Job", - extra={"job_name": job_name, "error": str(exc)}, + logger.warn( + "Failed to clean up K8s hook Job: {job_name} ({error})", + job_name=job_name, + error=str(exc), ) diff --git a/server/osa/infrastructure/logging.py b/server/osa/infrastructure/logging.py index 90070b0..e0de990 100644 --- a/server/osa/infrastructure/logging.py +++ b/server/osa/infrastructure/logging.py @@ -114,16 +114,13 @@ def _shorten_module(name: str) -> str: ) if len(short) <= _MODULE_WIDTH: return short - # Truncate: keep first and last segment, abbreviate middle + # SLF4J-style: abbreviate all segments except the last to first char parts = short.split(".") - if len(parts) <= 2: + if len(parts) == 1: return short[:_MODULE_WIDTH] - # Keep first and last, drop middle segments until it fits - first, *middle, last = parts - while middle and len(f"{first}.{'.'.join(middle)}.{last}") > _MODULE_WIDTH: - middle.pop(0) - result = f"{first}.{'.'.join(middle)}.{last}" if middle else f"{first}.{last}" - return result[:_MODULE_WIDTH] + *prefixes, last = parts + abbreviated = ".".join(p[0] for p in prefixes) + "." + last + return abbreviated[:_MODULE_WIDTH] class Logger: diff --git a/server/osa/infrastructure/oci/ingester_runner.py b/server/osa/infrastructure/oci/ingester_runner.py index bb926ee..e458d85 100644 --- a/server/osa/infrastructure/oci/ingester_runner.py +++ b/server/osa/infrastructure/oci/ingester_runner.py @@ -8,9 +8,8 @@ from pathlib import Path import aiodocker -import logfire - -from osa.domain.shared.error import ExternalServiceError +from osa.domain.shared.error import OOMError, TransientError +from osa.infrastructure.logging import get_logger from osa.domain.shared.model.source import IngesterDefinition from osa.domain.shared.port.ingester_runner import IngesterInputs, IngesterOutput, IngesterRunner from osa.infrastructure.runner_utils import ( @@ -20,6 +19,9 @@ ) +log = get_logger(__name__) + + class OciIngesterRunner(IngesterRunner): """Executes ingesters in OCI containers via aiodocker. 
@@ -45,6 +47,14 @@ def __init__( self._host_data_dir = host_data_dir self._container_data_dir = container_data_dir + async def has_capacity(self) -> bool: + """Docker doesn't have scheduling contention.""" + return True + + async def capture_logs(self, run_id: str) -> str: + """OCI containers are deleted after run — logs captured inline during execution.""" + return "" + async def run( self, ingester: IngesterDefinition, @@ -92,13 +102,13 @@ async def _resolve_and_run(): return result except asyncio.TimeoutError: duration = time.monotonic() - start_time - logfire.error( - "Ingester timed out", + log.error( + "Ingester timed out after {timeout}s", image=ingester.image, timeout=timeout, duration=duration, ) - raise ExternalServiceError(f"Ingester timed out after {timeout}s") + raise TransientError(f"Ingester timed out after {timeout}s") finally: rmtree(staging_dir, onexc=_force_remove) @@ -159,22 +169,26 @@ async def _run_container( oom_killed = inspect_data.get("State", {}).get("OOMKilled", False) if oom_killed: - raise ExternalServiceError("Ingester killed by OOM") + raise OOMError("Ingester killed by OOM") if exit_code != 0: logs = await container.log(stdout=True, stderr=True) logs_str = "".join(logs) if logs else "" - raise ExternalServiceError( - f"Ingester exited with code {exit_code}: {logs_str[:500]}" + log.error( + "Ingester exited with code {exit_code}", + exit_code=exit_code, + image=ingester.image, + container_logs=logs_str[:2000], ) + raise TransientError(f"Ingester exited with code {exit_code}") records = parse_records_file(output_dir) session = parse_session_file(output_dir) return IngesterOutput(records=records, session=session, files_dir=files_dir) except aiodocker.DockerError as e: - logfire.error("Docker error running ingester", error=str(e)) - raise ExternalServiceError(f"Docker error: {e}") from e + log.error("Docker error running ingester: {error}", error=str(e)) + raise TransientError(f"Docker error: {e}") from e finally: if container is not None: try: @@ -212,6 +226,6 @@ async def _resolve_image(self, image: str, digest: str) -> str: pass # Pull from registry as last resort - logfire.info("Pulling ingester image", image=image) + log.info("Pulling ingester image: {image}", image=image) await self._docker.images.pull(image) return image diff --git a/server/osa/infrastructure/oci/runner.py b/server/osa/infrastructure/oci/runner.py index 05414d3..f114be4 100644 --- a/server/osa/infrastructure/oci/runner.py +++ b/server/osa/infrastructure/oci/runner.py @@ -10,6 +10,7 @@ from shutil import rmtree import aiodocker +from osa.domain.shared.error import OOMError, PermanentError, TransientError from osa.domain.shared.model.hook import HookDefinition from osa.domain.validation.model.hook_result import HookResult, HookStatus from osa.domain.validation.port.hook_runner import HookInputs, HookRunner @@ -43,6 +44,10 @@ def __init__( self._host_data_dir = host_data_dir self._container_data_dir = container_data_dir + async def capture_logs(self, run_id: str) -> str: + """OCI containers are deleted after run — logs captured inline during execution.""" + return "" + async def run( self, hook: HookDefinition, @@ -99,19 +104,13 @@ async def _resolve_and_run(): duration_seconds=result_duration, ) except asyncio.TimeoutError: - duration = time.monotonic() - start_time log.error( "Hook timed out", hook=hook.name, run_id=inputs.run_id, timeout=timeout, ) - return HookResult( - hook_name=hook.name, - status=HookStatus.FAILED, - error_message=f"Hook timed out after {timeout}s", - 
duration_seconds=duration, - ) + raise TransientError(f"Hook timed out after {timeout}s") finally: rmtree(staging_dir, onexc=_force_remove) @@ -191,10 +190,7 @@ async def _run_container( if tail_text: for line in tail_text.splitlines(): print(f" OOM [{hook.name}] {line}", file=sys.stderr, flush=True) - return { - "status": HookStatus.OOM, - "error_message": f"Hook killed by OOM (limit: {hook.runtime.limits.memory})", - } + raise OOMError(f"Hook killed by OOM (limit: {hook.runtime.limits.memory})") # Parse progress file progress = parse_progress_file(output_dir) @@ -211,29 +207,21 @@ async def _run_container( if exit_code != 0: logs = await container.log(stdout=True, stderr=True) logs_str = "".join(logs) if logs else "" - return { - "status": HookStatus.FAILED, - "error_message": f"Hook exited with code {exit_code}: {logs_str[:2000]}", - "progress": progress, - } + raise PermanentError(f"Hook exited with code {exit_code}: {logs_str[:2000]}") return { "status": HookStatus.PASSED, "progress": progress, } + except (OOMError, PermanentError): + raise except aiodocker.DockerError as e: log.error("Docker error running hook", error=str(e)) - return { - "status": HookStatus.FAILED, - "error_message": f"Docker error: {e}", - } + raise TransientError(f"Docker error: {e}") from e except Exception as e: log.error("Unexpected error running hook", error=str(e)) - return { - "status": HookStatus.FAILED, - "error_message": f"Unexpected error: {e}", - } + raise TransientError(f"Unexpected error: {e}") from e finally: if container is not None: try: diff --git a/server/osa/infrastructure/persistence/adapter/ingest_storage.py b/server/osa/infrastructure/persistence/adapter/ingest_storage.py index 57db091..44bf535 100644 --- a/server/osa/infrastructure/persistence/adapter/ingest_storage.py +++ b/server/osa/infrastructure/persistence/adapter/ingest_storage.py @@ -18,14 +18,14 @@ class FilesystemIngestStorage: def __init__(self, layout: StorageLayout) -> None: self._layout = layout - async def read_session(self, ingest_run_srn: str) -> dict[str, Any] | None: - session_file = self._layout.ingest_session_file(ingest_run_srn) + async def read_session(self, ingest_run_id: str) -> dict[str, Any] | None: + session_file = self._layout.ingest_session_file(ingest_run_id) if not session_file.exists(): return None return json.loads(session_file.read_text()) - async def write_session(self, ingest_run_srn: str, session: dict[str, Any]) -> None: - session_file = self._layout.ingest_session_file(ingest_run_srn) + async def write_session(self, ingest_run_id: str, session: dict[str, Any]) -> None: + session_file = self._layout.ingest_session_file(ingest_run_id) session_file.parent.mkdir(parents=True, exist_ok=True) # Atomic write via temp file + os.replace to handle mountpoint-for-s3 tmp = session_file.with_suffix(".tmp") @@ -33,9 +33,9 @@ async def write_session(self, ingest_run_srn: str, session: dict[str, Any]) -> N os.replace(tmp, session_file) async def write_records( - self, ingest_run_srn: str, batch_index: int, records: list[dict[str, Any]] + self, ingest_run_id: str, batch_index: int, records: list[dict[str, Any]] ) -> None: - ingester_dir = self._layout.ingest_batch_ingester_dir(ingest_run_srn, batch_index) + ingester_dir = self._layout.ingest_batch_ingester_dir(ingest_run_id, batch_index) ingester_dir.mkdir(parents=True, exist_ok=True) records_file = ingester_dir / "records.jsonl" tmp = records_file.with_suffix(".tmp") @@ -44,8 +44,8 @@ async def write_records( f.write(json.dumps(record) + "\n") os.replace(tmp, 
records_file) - async def read_records(self, ingest_run_srn: str, batch_index: int) -> list[dict[str, Any]]: - ingester_dir = self._layout.ingest_batch_ingester_dir(ingest_run_srn, batch_index) + async def read_records(self, ingest_run_id: str, batch_index: int) -> list[dict[str, Any]]: + ingester_dir = self._layout.ingest_batch_ingester_dir(ingest_run_id, batch_index) records_file = ingester_dir / "records.jsonl" if not records_file.exists(): return [] @@ -57,22 +57,22 @@ async def read_records(self, ingest_run_srn: str, batch_index: int) -> list[dict records.append(json.loads(line)) return records - def batch_dir(self, ingest_run_srn: str, batch_index: int) -> Path: - d = self._layout.ingest_batch_dir(ingest_run_srn, batch_index) + def batch_dir(self, ingest_run_id: str, batch_index: int) -> Path: + d = self._layout.ingest_batch_dir(ingest_run_id, batch_index) d.mkdir(parents=True, exist_ok=True) return d - def batch_work_dir(self, ingest_run_srn: str, batch_index: int) -> Path: - d = self._layout.ingest_batch_ingester_dir(ingest_run_srn, batch_index) + def batch_work_dir(self, ingest_run_id: str, batch_index: int) -> Path: + d = self._layout.ingest_batch_ingester_dir(ingest_run_id, batch_index) d.mkdir(parents=True, exist_ok=True) return d - def batch_files_dir(self, ingest_run_srn: str, batch_index: int) -> Path: - d = self._layout.ingest_batch_ingester_dir(ingest_run_srn, batch_index) / "files" + def batch_files_dir(self, ingest_run_id: str, batch_index: int) -> Path: + d = self._layout.ingest_batch_ingester_dir(ingest_run_id, batch_index) / "files" d.mkdir(parents=True, exist_ok=True) return d - def hook_work_dir(self, ingest_run_srn: str, batch_index: int, hook_name: str) -> Path: - d = self._layout.ingest_batch_hook_dir(ingest_run_srn, batch_index, hook_name) + def hook_work_dir(self, ingest_run_id: str, batch_index: int, hook_name: str) -> Path: + d = self._layout.ingest_batch_hook_dir(ingest_run_id, batch_index, hook_name) d.mkdir(parents=True, exist_ok=True) return d diff --git a/server/osa/infrastructure/persistence/repository/event.py b/server/osa/infrastructure/persistence/repository/event.py index 3e09df5..91eb200 100644 --- a/server/osa/infrastructure/persistence/repository/event.py +++ b/server/osa/infrastructure/persistence/repository/event.py @@ -6,8 +6,6 @@ from uuid import uuid4 from sqlalchemy import CursorResult, func, insert, or_, select, update -from sqlalchemy.dialects.postgresql import INTERVAL -from sqlalchemy.sql import literal from sqlalchemy.ext.asyncio import AsyncSession from osa.domain.shared.error import InfrastructureError @@ -34,6 +32,7 @@ async def save_with_deliveries( self, event: Event, consumer_groups: set[str], + deliver_after: datetime | None = None, ) -> None: """Save event to append-only log and create delivery rows.""" now = datetime.now(UTC) @@ -55,6 +54,7 @@ async def save_with_deliveries( consumer_group=group, status="pending", retry_count=0, + deliver_after=deliver_after, updated_at=now, ) await self._session.execute(delivery_stmt) @@ -186,21 +186,17 @@ async def claim_delivery( """ now = datetime.now(UTC) - # Backoff formula: min(30, 5^retry_count) seconds - backoff_seconds = func.least( - literal(30), - func.power(literal(5), deliveries_table.c.retry_count), - ) - backoff_interval = func.cast(func.concat(backoff_seconds, literal(" seconds")), INTERVAL) - backoff_eligible = or_( - deliveries_table.c.retry_count == 0, - deliveries_table.c.updated_at <= func.now() - backoff_interval, + # deliver_after: NULL means immediately eligible, 
otherwise wait + deliver_after_eligible = or_( + deliveries_table.c.deliver_after.is_(None), + deliveries_table.c.deliver_after <= func.now(), ) # Select deliveries joined with events stmt = ( select( deliveries_table.c.id, + deliveries_table.c.retry_count, events_table.c.event_type, events_table.c.payload, ) @@ -209,7 +205,7 @@ async def claim_delivery( deliveries_table.c.consumer_group == consumer_group, deliveries_table.c.status == "pending", events_table.c.event_type.in_(event_types), - backoff_eligible, + deliver_after_eligible, ) .order_by(events_table.c.created_at.asc()) .limit(limit) @@ -234,10 +230,10 @@ async def claim_delivery( # Deserialize events and wrap in Delivery envelopes deliveries: list[Delivery] = [] for row in rows: - delivery_id, event_type, payload = row + delivery_id, retry_count, event_type, payload = row event = self._deserialize(event_type, payload) if event is not None: - deliveries.append(Delivery(id=delivery_id, event=event)) + deliveries.append(Delivery(id=delivery_id, event=event, retry_count=retry_count)) return ClaimResult(deliveries=deliveries, claimed_at=now) @@ -291,8 +287,14 @@ async def mark_failed_with_retry( delivery_id: str, error: str, max_retries: int, + deliver_after: datetime | None = None, ) -> None: - """Mark a delivery as failed with retry logic.""" + """Mark a delivery as failed with retry logic. + + Args: + deliver_after: If set, the delivery won't be eligible for claiming + until this timestamp. Used for transient resource backoff. + """ now = datetime.now(UTC) # Get current retry_count @@ -318,6 +320,7 @@ async def mark_failed_with_retry( status="failed", delivery_error=error, retry_count=new_retry_count, + deliver_after=None, updated_at=now, delivered_at=now, ) @@ -331,6 +334,7 @@ async def mark_failed_with_retry( status="pending", delivery_error=error, retry_count=new_retry_count, + deliver_after=deliver_after, claimed_at=None, updated_at=now, ) diff --git a/server/osa/infrastructure/persistence/repository/ingest.py b/server/osa/infrastructure/persistence/repository/ingest.py index 94d9beb..5d330ac 100644 --- a/server/osa/infrastructure/persistence/repository/ingest.py +++ b/server/osa/infrastructure/persistence/repository/ingest.py @@ -22,13 +22,14 @@ def __init__(self, session: AsyncSession) -> None: async def save(self, ingest_run: IngestRun) -> None: """Insert or update an ingest run.""" values = { - "srn": ingest_run.srn, + "id": ingest_run.id, "convention_srn": ingest_run.convention_srn, "status": ingest_run.status.value, "ingestion_finished": ingest_run.ingestion_finished, "batches_ingested": ingest_run.batches_ingested, "batches_completed": ingest_run.batches_completed, "published_count": ingest_run.published_count, + "batches_failed": ingest_run.batches_failed, "batch_size": ingest_run.batch_size, "record_limit": ingest_run.limit, "started_at": ingest_run.started_at, @@ -38,15 +39,15 @@ async def save(self, ingest_run: IngestRun) -> None: insert(ingest_runs_table) .values(**values) .on_conflict_do_update( - index_elements=["srn"], + index_elements=["id"], set_=values, ) ) await self._session.execute(stmt) await self._session.flush() - async def get(self, srn: str) -> IngestRun | None: - stmt = select(ingest_runs_table).where(ingest_runs_table.c.srn == srn) + async def get(self, id: str) -> IngestRun | None: + stmt = select(ingest_runs_table).where(ingest_runs_table.c.id == id) result = await self._session.execute(stmt) row = result.mappings().first() if row is None: @@ -71,7 +72,7 @@ async def get_running_for_convention(self, 
convention_srn: str) -> IngestRun | N return _row_to_ingest_run(dict(row)) async def increment_batches_ingested( - self, srn: str, *, set_ingestion_finished: bool = False + self, id: str, *, set_ingestion_finished: bool = False ) -> IngestRun: """Atomically increment batches_ingested.""" t = ingest_runs_table @@ -81,22 +82,40 @@ async def increment_batches_ingested( if set_ingestion_finished: values["ingestion_finished"] = True - stmt = update(t).where(t.c.srn == srn).values(**values).returning(*t.c) + stmt = update(t).where(t.c.id == id).values(**values).returning(*t.c) result = await self._session.execute(stmt) await self._session.flush() row = result.mappings().first() if row is None: from osa.domain.shared.error import NotFoundError - raise NotFoundError(f"Ingest run not found: {srn}") + raise NotFoundError(f"Ingest run not found: {id}") return _row_to_ingest_run(dict(row)) - async def increment_completed(self, srn: str, published_count: int) -> IngestRun: + async def increment_failed(self, id: str) -> IngestRun: + """Atomically increment batches_failed.""" + t = ingest_runs_table + stmt = ( + update(t) + .where(t.c.id == id) + .values(batches_failed=t.c.batches_failed + 1) + .returning(*t.c) + ) + result = await self._session.execute(stmt) + await self._session.flush() + row = result.mappings().first() + if row is None: + from osa.domain.shared.error import NotFoundError + + raise NotFoundError(f"Ingest run not found: {id}") + return _row_to_ingest_run(dict(row)) + + async def increment_completed(self, id: str, published_count: int) -> IngestRun: """Atomically increment batches_completed and published_count.""" t = ingest_runs_table stmt = ( update(t) - .where(t.c.srn == srn) + .where(t.c.id == id) .values( batches_completed=t.c.batches_completed + 1, published_count=t.c.published_count + published_count, @@ -109,19 +128,20 @@ async def increment_completed(self, srn: str, published_count: int) -> IngestRun if row is None: from osa.domain.shared.error import NotFoundError - raise NotFoundError(f"Ingest run not found: {srn}") + raise NotFoundError(f"Ingest run not found: {id}") return _row_to_ingest_run(dict(row)) def _row_to_ingest_run(row: dict) -> IngestRun: return IngestRun( - srn=row["srn"], + id=row["id"], convention_srn=row["convention_srn"], status=IngestStatus(row["status"]), ingestion_finished=row["ingestion_finished"], batches_ingested=row["batches_ingested"], batches_completed=row["batches_completed"], published_count=row["published_count"], + batches_failed=row.get("batches_failed", 0), batch_size=row["batch_size"], limit=row.get("record_limit"), started_at=row["started_at"], diff --git a/server/osa/infrastructure/persistence/tables.py b/server/osa/infrastructure/persistence/tables.py index 4f3545a..315f97c 100644 --- a/server/osa/infrastructure/persistence/tables.py +++ b/server/osa/infrastructure/persistence/tables.py @@ -120,6 +120,7 @@ Column("delivered_at", DateTime(timezone=True), nullable=True), Column("delivery_error", Text, nullable=True), Column("retry_count", Integer, nullable=False, server_default=text("0")), + Column("deliver_after", DateTime(timezone=True), nullable=True), Column("updated_at", DateTime(timezone=True), nullable=False), UniqueConstraint("event_id", "consumer_group", name="uq_delivery_event_consumer"), ) @@ -133,6 +134,13 @@ postgresql_where=text("status IN ('pending', 'claimed')"), ) +# Deferred delivery filtering +Index( + "idx_deliveries_deliver_after", + deliveries_table.c.deliver_after, + postgresql_where=text("status = 'pending'"), +) + # 
For joining back to events Index("idx_deliveries_event", deliveries_table.c.event_id) @@ -310,8 +318,8 @@ ingest_runs_table = Table( "ingest_runs", metadata, - Column("srn", String, primary_key=True), - Column("convention_srn", String, ForeignKey("conventions.srn"), nullable=False), + Column("id", String, primary_key=True), + Column("convention_srn", String, nullable=False), Column("status", String(32), nullable=False, server_default=text("'pending'")), Column("ingestion_finished", Boolean, nullable=False, server_default=text("false")), Column("batches_ingested", Integer, nullable=False, server_default=text("0")), @@ -319,6 +327,7 @@ Column("published_count", Integer, nullable=False, server_default=text("0")), Column("batch_size", Integer, nullable=False, server_default=text("1000")), Column("record_limit", Integer, nullable=True), + Column("batches_failed", Integer, nullable=False, server_default=text("0")), Column("started_at", DateTime(timezone=True), nullable=False), Column("completed_at", DateTime(timezone=True), nullable=True), ) diff --git a/server/osa/infrastructure/s3/ingest_storage.py b/server/osa/infrastructure/s3/ingest_storage.py index b46645c..12820a4 100644 --- a/server/osa/infrastructure/s3/ingest_storage.py +++ b/server/osa/infrastructure/s3/ingest_storage.py @@ -37,8 +37,8 @@ def _key(self, path: Path) -> str: """Convert a StorageLayout path to an S3 key.""" return relative_path(path, self._data_mount_path) - async def read_session(self, ingest_run_srn: str) -> dict[str, Any] | None: - key = self._key(self._layout.ingest_session_file(ingest_run_srn)) + async def read_session(self, ingest_run_id: str) -> dict[str, Any] | None: + key = self._key(self._layout.ingest_session_file(ingest_run_id)) try: data = await self._s3.get_object(key) return json.loads(data) @@ -47,20 +47,20 @@ async def read_session(self, ingest_run_srn: str) -> dict[str, Any] | None: return None raise - async def write_session(self, ingest_run_srn: str, session: dict[str, Any]) -> None: - key = self._key(self._layout.ingest_session_file(ingest_run_srn)) + async def write_session(self, ingest_run_id: str, session: dict[str, Any]) -> None: + key = self._key(self._layout.ingest_session_file(ingest_run_id)) await self._s3.put_object(key, json.dumps(session)) async def write_records( - self, ingest_run_srn: str, batch_index: int, records: list[dict[str, Any]] + self, ingest_run_id: str, batch_index: int, records: list[dict[str, Any]] ) -> None: - ingester_dir = self._layout.ingest_batch_ingester_dir(ingest_run_srn, batch_index) + ingester_dir = self._layout.ingest_batch_ingester_dir(ingest_run_id, batch_index) key = f"{self._key(ingester_dir)}/records.jsonl" content = "".join(json.dumps(r) + "\n" for r in records) await self._s3.put_object(key, content) - async def read_records(self, ingest_run_srn: str, batch_index: int) -> list[dict[str, Any]]: - ingester_dir = self._layout.ingest_batch_ingester_dir(ingest_run_srn, batch_index) + async def read_records(self, ingest_run_id: str, batch_index: int) -> list[dict[str, Any]]: + ingester_dir = self._layout.ingest_batch_ingester_dir(ingest_run_id, batch_index) key = f"{self._key(ingester_dir)}/records.jsonl" try: data = await self._s3.get_object(key) @@ -76,14 +76,14 @@ async def read_records(self, ingest_run_srn: str, batch_index: int) -> list[dict records.append(json.loads(line)) return records - def batch_dir(self, ingest_run_srn: str, batch_index: int) -> Path: - return self._layout.ingest_batch_dir(ingest_run_srn, batch_index) + def batch_dir(self, 
ingest_run_id: str, batch_index: int) -> Path: + return self._layout.ingest_batch_dir(ingest_run_id, batch_index) - def batch_work_dir(self, ingest_run_srn: str, batch_index: int) -> Path: - return self._layout.ingest_batch_ingester_dir(ingest_run_srn, batch_index) + def batch_work_dir(self, ingest_run_id: str, batch_index: int) -> Path: + return self._layout.ingest_batch_ingester_dir(ingest_run_id, batch_index) - def batch_files_dir(self, ingest_run_srn: str, batch_index: int) -> Path: - return self._layout.ingest_batch_ingester_dir(ingest_run_srn, batch_index) / "files" + def batch_files_dir(self, ingest_run_id: str, batch_index: int) -> Path: + return self._layout.ingest_batch_ingester_dir(ingest_run_id, batch_index) / "files" - def hook_work_dir(self, ingest_run_srn: str, batch_index: int, hook_name: str) -> Path: - return self._layout.ingest_batch_hook_dir(ingest_run_srn, batch_index, hook_name) + def hook_work_dir(self, ingest_run_id: str, batch_index: int, hook_name: str) -> Path: + return self._layout.ingest_batch_hook_dir(ingest_run_id, batch_index, hook_name) diff --git a/server/osa/infrastructure/storage/layout.py b/server/osa/infrastructure/storage/layout.py index a00e734..32572a1 100644 --- a/server/osa/infrastructure/storage/layout.py +++ b/server/osa/infrastructure/storage/layout.py @@ -10,11 +10,6 @@ from pathlib import Path -def _safe_srn(srn: str) -> str: - """Convert an SRN to a filesystem-safe string.""" - return srn.replace(":", "_").replace("@", "_") - - class StorageLayout: """Computes storage paths relative to a data root. @@ -27,22 +22,22 @@ def __init__(self, data_dir: Path) -> None: # ── Ingest paths ───────────────────────────────────────────────── - def ingest_run_dir(self, ingest_run_srn: str) -> Path: + def ingest_run_dir(self, ingest_run_id: str) -> Path: """Root directory for an ingest run.""" - return self._data_dir / "ingests" / _safe_srn(ingest_run_srn) + return self._data_dir / "ingests" / ingest_run_id - def ingest_batch_dir(self, ingest_run_srn: str, batch_index: int) -> Path: + def ingest_batch_dir(self, ingest_run_id: str, batch_index: int) -> Path: """Directory for a specific batch within an ingest run.""" - return self.ingest_run_dir(ingest_run_srn) / "batches" / str(batch_index) + return self.ingest_run_dir(ingest_run_id) / "batches" / str(batch_index) - def ingest_batch_ingester_dir(self, ingest_run_srn: str, batch_index: int) -> Path: + def ingest_batch_ingester_dir(self, ingest_run_id: str, batch_index: int) -> Path: """Ingester output directory (records.jsonl, files/) for a batch.""" - return self.ingest_batch_dir(ingest_run_srn, batch_index) / "ingester" + return self.ingest_batch_dir(ingest_run_id, batch_index) / "ingester" - def ingest_batch_hook_dir(self, ingest_run_srn: str, batch_index: int, hook_name: str) -> Path: + def ingest_batch_hook_dir(self, ingest_run_id: str, batch_index: int, hook_name: str) -> Path: """Hook output directory for a batch.""" - return self.ingest_batch_dir(ingest_run_srn, batch_index) / "hooks" / hook_name + return self.ingest_batch_dir(ingest_run_id, batch_index) / "hooks" / hook_name - def ingest_session_file(self, ingest_run_srn: str) -> Path: + def ingest_session_file(self, ingest_run_id: str) -> Path: """Session state file for ingester continuation.""" - return self.ingest_run_dir(ingest_run_srn) / "session.json" + return self.ingest_run_dir(ingest_run_id) / "session.json" diff --git a/server/tests/unit/domain/feature/test_insert_record_features.py 
b/server/tests/unit/domain/feature/test_insert_record_features.py index dbd489e..9bb2a52 100644 --- a/server/tests/unit/domain/feature/test_insert_record_features.py +++ b/server/tests/unit/domain/feature/test_insert_record_features.py @@ -214,7 +214,7 @@ async def test_ingest_source_uses_source_fields(self): record_srn=_make_record_srn(), source=IngestSource( id="run-123-pdb-456", - ingest_run_srn="urn:osa:localhost:val:run123", + ingest_run_id="run123", upstream_source="pdb", ), metadata={"title": "Ingested"}, diff --git a/server/tests/unit/domain/ingest/test_ingest_run.py b/server/tests/unit/domain/ingest/test_ingest_run.py index 55f5395..34ab0dd 100644 --- a/server/tests/unit/domain/ingest/test_ingest_run.py +++ b/server/tests/unit/domain/ingest/test_ingest_run.py @@ -10,7 +10,7 @@ def _make_run(**overrides) -> IngestRun: defaults = { - "srn": "urn:osa:localhost:ing:test-run", + "id": "test-run-id", "convention_srn": "urn:osa:localhost:conv:test-conv@1.0.0", "status": IngestStatus.PENDING, "started_at": datetime.now(UTC), @@ -142,3 +142,74 @@ def test_batch_size_default(self) -> None: def test_custom_batch_size(self) -> None: run = _make_run(batch_size=500) assert run.batch_size == 500 + + +class TestBatchFailureAccounting: + def test_batches_failed_defaults_to_zero(self) -> None: + run = _make_run() + assert run.batches_failed == 0 + + def test_complete_with_all_batches_succeeded(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + ingestion_finished=True, + batches_ingested=3, + batches_completed=3, + batches_failed=0, + ) + assert run.is_complete + + def test_complete_with_some_batches_failed(self) -> None: + """A run completes when all batches are accounted for, even if some failed.""" + run = _make_run( + status=IngestStatus.RUNNING, + ingestion_finished=True, + batches_ingested=3, + batches_completed=2, + batches_failed=1, + ) + assert run.is_complete + + def test_complete_with_all_batches_failed(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + ingestion_finished=True, + batches_ingested=3, + batches_completed=0, + batches_failed=3, + ) + assert run.is_complete + + def test_not_complete_when_batches_still_pending(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + ingestion_finished=True, + batches_ingested=3, + batches_completed=1, + batches_failed=1, + ) + assert not run.is_complete + + def test_not_complete_when_ingestion_not_finished(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + ingestion_finished=False, + batches_ingested=3, + batches_completed=0, + batches_failed=3, + ) + assert not run.is_complete + + def test_check_completion_transitions_with_failures(self) -> None: + run = _make_run( + status=IngestStatus.RUNNING, + ingestion_finished=True, + batches_ingested=3, + batches_completed=2, + batches_failed=1, + ) + now = datetime.now(UTC) + completed = run.check_completion(now) + assert completed is True + assert run.status == IngestStatus.COMPLETED + assert run.completed_at == now diff --git a/server/tests/unit/domain/ingest/test_ingest_service.py b/server/tests/unit/domain/ingest/test_ingest_service.py index 4742ec6..7e1c8df 100644 --- a/server/tests/unit/domain/ingest/test_ingest_service.py +++ b/server/tests/unit/domain/ingest/test_ingest_service.py @@ -63,19 +63,25 @@ async def test_creates_pending_ingest(self) -> None: assert run.batch_size == 1000 @pytest.mark.asyncio - async def test_saves_and_emits_event(self) -> None: + async def test_saves_and_emits_events(self) -> None: service = _make_service() 
run = await service.start_ingest( convention_srn="urn:osa:localhost:conv:test-conv@1.0.0", ) service.ingest_repo.save.assert_called_once() - service.outbox.append.assert_called_once() - - # Verify the event is IngestStarted - event = service.outbox.append.call_args[0][0] - assert event.__class__.__name__ == "IngestStarted" - assert event.ingest_run_srn == run.srn - assert event.convention_srn == run.convention_srn + assert service.outbox.append.call_count == 2 + + # First event: IngestRunStarted (observability) + first_event = service.outbox.append.call_args_list[0][0][0] + assert first_event.__class__.__name__ == "IngestRunStarted" + assert first_event.ingest_run_id == run.id + assert first_event.convention_srn == run.convention_srn + + # Second event: NextBatchRequested (triggers first batch) + second_event = service.outbox.append.call_args_list[1][0][0] + assert second_event.__class__.__name__ == "NextBatchRequested" + assert second_event.ingest_run_id == run.id + assert second_event.convention_srn == run.convention_srn @pytest.mark.asyncio async def test_custom_batch_size(self) -> None: diff --git a/server/tests/unit/domain/ingest/test_publish_batch.py b/server/tests/unit/domain/ingest/test_publish_batch.py new file mode 100644 index 0000000..2f1479e --- /dev/null +++ b/server/tests/unit/domain/ingest/test_publish_batch.py @@ -0,0 +1,59 @@ +"""Tests for PublishBatch — exhaustion handling and completion delegation.""" + +from unittest.mock import AsyncMock +from uuid import uuid4 + +import pytest + +from osa.domain.ingest.event.events import HookBatchCompleted +from osa.domain.ingest.handler.publish_batch import PublishBatch +from osa.domain.ingest.model.ingest_run import IngestRunId +from osa.domain.shared.event import EventId + + +def _make_event( + ingest_run_id: str = "run-1", + batch_index: int = 0, +) -> HookBatchCompleted: + return HookBatchCompleted( + id=EventId(uuid4()), + ingest_run_id=IngestRunId(ingest_run_id), + batch_index=batch_index, + ) + + +def _make_handler() -> PublishBatch: + ingest_service = AsyncMock() + ingest_service.fail_batch = AsyncMock() + ingest_service.complete_batch = AsyncMock() + + return PublishBatch( + ingest_repo=AsyncMock(), + convention_service=AsyncMock(), + record_service=AsyncMock(), + feature_storage=AsyncMock(), + outbox=AsyncMock(), + ingest_storage=AsyncMock(), + ingest_service=ingest_service, + ) + + +class TestPublishBatchOnExhausted: + @pytest.mark.asyncio + async def test_on_exhausted_calls_fail_batch(self) -> None: + """When retries are exhausted, the batch must be accounted for as failed.""" + handler = _make_handler() + event = _make_event() + + await handler.on_exhausted(event) + + handler.ingest_service.fail_batch.assert_called_once_with( + IngestRunId("run-1"), + ) + + @pytest.mark.asyncio + async def test_on_exhausted_exists(self) -> None: + """PublishBatch must override on_exhausted (not rely on base class no-op).""" + from osa.domain.shared.event import EventHandler + + assert PublishBatch.on_exhausted is not EventHandler.on_exhausted diff --git a/server/tests/unit/domain/ingest/test_run_hooks.py b/server/tests/unit/domain/ingest/test_run_hooks.py new file mode 100644 index 0000000..fbea714 --- /dev/null +++ b/server/tests/unit/domain/ingest/test_run_hooks.py @@ -0,0 +1,118 @@ +"""Tests for RunHooks — OOM exhaustion should still emit HookBatchCompleted.""" + +from unittest.mock import AsyncMock +from uuid import uuid4 + +import pytest + +from osa.domain.ingest.event.events import HookBatchCompleted, IngesterBatchReady +from 
osa.domain.ingest.handler.run_hooks import RunHooks +from osa.domain.ingest.model.ingest_run import IngestRun, IngestRunId, IngestStatus +from osa.domain.shared.error import OOMError, PermanentError +from osa.domain.shared.event import EventId +from osa.domain.shared.model.hook import HookDefinition, OciConfig, OciLimits, TableFeatureSpec + + +def _make_hook(name: str = "pockets") -> HookDefinition: + return HookDefinition( + name=name, + runtime=OciConfig( + image="ghcr.io/test/pockets:v1", + digest="sha256:abc123", + limits=OciLimits(memory="1g"), + ), + feature=TableFeatureSpec(cardinality="one", columns=[]), + ) + + +def _make_event( + ingest_run_id: str = "run-1", + batch_index: int = 0, +) -> IngesterBatchReady: + return IngesterBatchReady( + id=EventId(uuid4()), + ingest_run_id=IngestRunId(ingest_run_id), + batch_index=batch_index, + has_more=False, + ) + + +def _make_convention(): + conv = AsyncMock() + conv.hooks = [_make_hook()] + return conv + + +def _make_handler(*, hook_service_side_effect=None) -> RunHooks: + ingest_repo = AsyncMock() + ingest_repo.get.return_value = IngestRun( + id=IngestRunId("run-1"), + convention_srn="urn:osa:localhost:conv:test@1.0.0", + status=IngestStatus.RUNNING, + batch_size=100, + started_at=__import__("datetime").datetime.now(__import__("datetime").UTC), + ) + + convention_service = AsyncMock() + convention_service.get_convention.return_value = _make_convention() + + ingest_storage = AsyncMock() + ingest_storage.read_records.return_value = [ + {"source_id": "rec-1", "metadata": {}, "files": []}, + ] + ingest_storage.batch_files_dir.return_value = __import__("pathlib").Path("/tmp/files") + ingest_storage.hook_work_dir.return_value = __import__("pathlib").Path("/tmp/work") + + hook_service = AsyncMock() + if hook_service_side_effect: + hook_service.run_hooks_for_batch.side_effect = hook_service_side_effect + + return RunHooks( + ingest_repo=ingest_repo, + ingest_service=AsyncMock(), + convention_service=convention_service, + hook_service=hook_service, + outbox=AsyncMock(), + ingest_storage=ingest_storage, + ) + + +class TestRunHooksOOMExhaustion: + @pytest.mark.asyncio + async def test_oom_exhaustion_emits_hook_batch_completed(self) -> None: + """OOM exhaustion should still emit HookBatchCompleted so passed records get published.""" + handler = _make_handler(hook_service_side_effect=OOMError("OOM after 3 retries")) + event = _make_event() + + await handler.handle(event) + + # HookBatchCompleted should be emitted (not swallowed) + emitted_events = [call[0][0] for call in handler.outbox.append.call_args_list] + assert any(isinstance(e, HookBatchCompleted) for e in emitted_events), ( + "HookBatchCompleted should be emitted on OOM exhaustion " + "so PublishBatch can publish records that passed" + ) + + @pytest.mark.asyncio + async def test_oom_exhaustion_does_not_fail_batch(self) -> None: + """OOM exhaustion should not call _fail_batch — the batch has partial results.""" + handler = _make_handler(hook_service_side_effect=OOMError("OOM after 3 retries")) + event = _make_event() + + await handler.handle(event) + + handler.ingest_service.fail_batch.assert_not_called() + + @pytest.mark.asyncio + async def test_permanent_error_still_fails_batch(self) -> None: + """Non-OOM PermanentError should still fail the batch (no partial results).""" + error = PermanentError("image pull failed") + handler = _make_handler(hook_service_side_effect=error) + event = _make_event() + + await handler.handle(event) + + handler.ingest_service.fail_batch.assert_called_once() + # 
HookBatchCompleted should NOT be emitted + emitted_events = [call[0][0] for call in handler.outbox.append.call_args_list] + assert not any(isinstance(e, HookBatchCompleted) for e in emitted_events) diff --git a/server/tests/unit/domain/record/test_record_service.py b/server/tests/unit/domain/record/test_record_service.py index 9dff7c1..e62c2cd 100644 --- a/server/tests/unit/domain/record/test_record_service.py +++ b/server/tests/unit/domain/record/test_record_service.py @@ -138,7 +138,7 @@ async def test_publish_with_ingest_source( draft = RecordDraft( source=IngestSource( id="run-123-pdb-456", - ingest_run_srn="urn:osa:localhost:val:run123", + ingest_run_id="run123", upstream_source="pdb", ), metadata={"title": "Ingested Protein"}, diff --git a/server/tests/unit/domain/shared/test_error_types.py b/server/tests/unit/domain/shared/test_error_types.py new file mode 100644 index 0000000..05aca9f --- /dev/null +++ b/server/tests/unit/domain/shared/test_error_types.py @@ -0,0 +1,45 @@ +"""Tests for runner-specific error types.""" + +from osa.domain.shared.error import ( + InfrastructureError, + OOMError, + PermanentError, + TransientError, +) + + +class TestOOMError: + def test_is_infrastructure_error(self): + err = OOMError("Hook killed by OOM") + assert isinstance(err, InfrastructureError) + + def test_is_permanent_error(self): + err = OOMError("Hook killed by OOM") + assert isinstance(err, PermanentError) + + def test_message_and_code(self): + err = OOMError("Hook killed by OOM") + assert err.message == "Hook killed by OOM" + assert err.code == "OOMError" + + +class TestTransientError: + def test_is_infrastructure_error(self): + err = TransientError("Pod scheduling timeout") + assert isinstance(err, InfrastructureError) + + def test_message_and_code(self): + err = TransientError("Pod scheduling timeout after 120s") + assert err.message == "Pod scheduling timeout after 120s" + assert err.code == "TransientError" + + +class TestPermanentError: + def test_is_infrastructure_error(self): + err = PermanentError("Image pull failed") + assert isinstance(err, InfrastructureError) + + def test_message_and_code(self): + err = PermanentError("Image pull failed: ImagePullBackOff") + assert err.message == "Image pull failed: ImagePullBackOff" + assert err.code == "PermanentError" diff --git a/server/tests/unit/domain/shared/test_outbox_claim.py b/server/tests/unit/domain/shared/test_outbox_claim.py index 85b90c7..7326bb2 100644 --- a/server/tests/unit/domain/shared/test_outbox_claim.py +++ b/server/tests/unit/domain/shared/test_outbox_claim.py @@ -147,5 +147,5 @@ async def test_mark_failed_with_retry_delegates_to_repo( await outbox.mark_failed_with_retry(delivery_id, error, max_retries=3) mock_repo.mark_failed_with_retry.assert_called_once_with( - delivery_id, error=error, max_retries=3 + delivery_id, error=error, max_retries=3, deliver_after=None ) diff --git a/server/tests/unit/domain/shared/test_outbox_consumer_groups.py b/server/tests/unit/domain/shared/test_outbox_consumer_groups.py index 9964494..322d76d 100644 --- a/server/tests/unit/domain/shared/test_outbox_consumer_groups.py +++ b/server/tests/unit/domain/shared/test_outbox_consumer_groups.py @@ -57,7 +57,7 @@ async def test_append_creates_deliveries_for_subscribed_groups( await outbox.append(event) mock_repo.save_with_deliveries.assert_called_once_with( - event, consumer_groups={"HandlerA", "HandlerB"} + event, consumer_groups={"HandlerA", "HandlerB"}, deliver_after=None ) async def test_append_audit_only_event_creates_zero_deliveries( @@ -73,7 
+73,9 @@ class AuditEvent(Event): await outbox.append(event) - mock_repo.save_with_deliveries.assert_called_once_with(event, consumer_groups=set()) + mock_repo.save_with_deliveries.assert_called_once_with( + event, consumer_groups=set(), deliver_after=None + ) class TestOutboxClaimByConsumerGroup: @@ -186,7 +188,7 @@ async def test_mark_failed_with_retry_uses_delivery_id( await outbox.mark_failed_with_retry(delivery_id, error="Timeout", max_retries=3) mock_repo.mark_failed_with_retry.assert_called_once_with( - delivery_id, error="Timeout", max_retries=3 + delivery_id, error="Timeout", max_retries=3, deliver_after=None ) diff --git a/server/tests/unit/domain/shared/test_record_source.py b/server/tests/unit/domain/shared/test_record_source.py index c1fe20d..64497e1 100644 --- a/server/tests/unit/domain/shared/test_record_source.py +++ b/server/tests/unit/domain/shared/test_record_source.py @@ -31,23 +31,23 @@ class TestIngestSource: def test_type_is_ingest(self): src = IngestSource( id="run-123-source-456", - ingest_run_srn="urn:osa:localhost:val:run123", + ingest_run_id="run123", upstream_source="pdb", ) assert src.type == "ingest" - def test_requires_ingest_run_srn(self): + def test_requires_ingest_run_id(self): with pytest.raises(ValidationError): IngestSource(id="run-123", upstream_source="pdb") def test_requires_upstream_source(self): with pytest.raises(ValidationError): - IngestSource(id="run-123", ingest_run_srn="urn:osa:localhost:val:run123") + IngestSource(id="run-123", ingest_run_id="run123") def test_serialization_roundtrip(self): src = IngestSource( id="run-123-source-456", - ingest_run_srn="urn:osa:localhost:val:run123", + ingest_run_id="run123", upstream_source="pdb", ) data = src.model_dump() @@ -65,7 +65,7 @@ def test_deserializes_ingest(self): data = { "type": "ingest", "id": "run-123", - "ingest_run_srn": "urn:osa:localhost:val:run1", + "ingest_run_id": "run1", "upstream_source": "geo", } adapter = TypeAdapter(RecordSource) @@ -83,7 +83,7 @@ def test_json_roundtrip(self): adapter = TypeAdapter(RecordSource) src = IngestSource( id="run-1", - ingest_run_srn="urn:osa:localhost:val:run1", + ingest_run_id="run1", upstream_source="pdb", ) json_str = adapter.dump_json(src) diff --git a/server/tests/unit/domain/validation/test_hook_result.py b/server/tests/unit/domain/validation/test_hook_result.py index 8465ff5..d245e74 100644 --- a/server/tests/unit/domain/validation/test_hook_result.py +++ b/server/tests/unit/domain/validation/test_hook_result.py @@ -6,7 +6,14 @@ def test_hook_status_values(): assert HookStatus.PASSED == "passed" assert HookStatus.REJECTED == "rejected" - assert HookStatus.FAILED == "failed" + + +def test_hook_status_only_business_outcomes(): + """HookStatus should only carry business outcomes, not failure modes.""" + from osa.domain.validation.model.hook_result import HookStatus + + members = set(HookStatus) + assert members == {HookStatus.PASSED, HookStatus.REJECTED} def test_progress_entry_full(): @@ -61,19 +68,6 @@ def test_hook_result_rejected(): assert result.rejection_reason == "Missing coordinates" -def test_hook_result_failed(): - from osa.domain.validation.model.hook_result import HookResult, HookStatus - - result = HookResult( - hook_name="detect_pockets", - status=HookStatus.FAILED, - error_message="OOM killed", - duration_seconds=300.0, - ) - assert result.status == HookStatus.FAILED - assert result.error_message == "OOM killed" - - def test_hook_result_with_progress(): from osa.domain.validation.model.hook_result import ( 
HookResult, @@ -105,44 +99,6 @@ def test_hook_result_default_progress_empty(): assert result.progress == [] -def test_hook_status_oom_value(): - from osa.domain.validation.model.hook_result import HookStatus - - assert HookStatus.OOM == "oom" - assert HookStatus.OOM.value == "oom" - - -def test_hook_result_oom_killed_true(): - from osa.domain.validation.model.hook_result import HookResult, HookStatus - - result = HookResult( - hook_name="detect_pockets", - status=HookStatus.OOM, - error_message="Hook killed by OOM (limit: 1g)", - duration_seconds=30.0, - ) - assert result.oom_killed is True - - -def test_hook_result_oom_killed_false(): - from osa.domain.validation.model.hook_result import HookResult, HookStatus - - result = HookResult( - hook_name="detect_pockets", - status=HookStatus.FAILED, - error_message="Some other error", - duration_seconds=10.0, - ) - assert result.oom_killed is False - - passed = HookResult( - hook_name="detect_pockets", - status=HookStatus.PASSED, - duration_seconds=5.0, - ) - assert passed.oom_killed is False - - def test_hook_result_serialization_roundtrip(): from osa.domain.validation.model.hook_result import ( HookResult, diff --git a/server/tests/unit/domain/validation/test_hook_runner.py b/server/tests/unit/domain/validation/test_hook_runner.py index 5b7052c..2536ff0 100644 --- a/server/tests/unit/domain/validation/test_hook_runner.py +++ b/server/tests/unit/domain/validation/test_hook_runner.py @@ -89,6 +89,9 @@ async def run( duration_seconds=0.1, ) + async def capture_logs(self, run_id: str) -> str: + return "" + assert isinstance(FakeRunner(), HookRunner) def test_incomplete_class_does_not_satisfy_protocol(self): @@ -106,5 +109,8 @@ class LaxRunner: async def run(self, *args, **kwargs): pass + async def capture_logs(self, *args, **kwargs): + pass + # runtime_checkable only checks method names exist, not signatures assert isinstance(LaxRunner(), HookRunner) diff --git a/server/tests/unit/domain/validation/test_hook_service.py b/server/tests/unit/domain/validation/test_hook_service.py index a24b741..7b1b4b9 100644 --- a/server/tests/unit/domain/validation/test_hook_service.py +++ b/server/tests/unit/domain/validation/test_hook_service.py @@ -19,6 +19,7 @@ OutcomeStatus, ) from osa.domain.validation.model.hook_input import HookRecord +from osa.domain.shared.error import OOMError from osa.domain.validation.model.hook_result import HookResult, HookStatus from osa.domain.validation.port.hook_runner import HookInputs @@ -50,22 +51,8 @@ def _passed_result(hook_name: str = "detect_pockets", duration: float = 5.0) -> return HookResult(hook_name=hook_name, status=HookStatus.PASSED, duration_seconds=duration) -def _oom_result(hook_name: str = "detect_pockets", duration: float = 30.0) -> HookResult: - return HookResult( - hook_name=hook_name, - status=HookStatus.OOM, - error_message="Hook killed by OOM", - duration_seconds=duration, - ) - - -def _failed_result(hook_name: str = "detect_pockets", duration: float = 10.0) -> HookResult: - return HookResult( - hook_name=hook_name, - status=HookStatus.FAILED, - error_message="Some error", - duration_seconds=duration, - ) +def _oom_error() -> OOMError: + return OOMError("Hook killed by OOM") class FakeHookStorage: @@ -164,7 +151,7 @@ async def mock_run(h, inputs, wd): features_file.write_text( json.dumps({"id": records[0].id, "features": [{"score": 0.5}]}) + "\n" ) - return _oom_result() + raise _oom_error() else: # Second call: succeed with remaining features_file = output_dir / "features.jsonl" @@ -188,7 +175,7 @@ async 
def mock_run(h, inputs, wd): class TestHookServiceOOMExhaustion: - """T017: OOM exhaustion marks remaining records as errored.""" + """T017: OOM exhaustion marks remaining records as errored and re-raises.""" @pytest.mark.asyncio async def test_oom_exhaustion_marks_errored(self, tmp_path: Path): @@ -202,13 +189,13 @@ async def test_oom_exhaustion_marks_errored(self, tmp_path: Path): output_dir.mkdir(parents=True) runner = AsyncMock() - runner.run.return_value = _oom_result() + runner.run.side_effect = _oom_error() storage = FakeHookStorage() service = HookService(hook_runner=runner, hook_storage=storage) - result = await service.run_hook(hook, _inputs(records), work_dir) + with pytest.raises(OOMError): + await service.run_hook(hook, _inputs(records), work_dir) - assert result.status == HookStatus.OOM # Should have retried MAX_OOM_RETRIES times assert runner.run.call_count == 4 # 1 initial + 3 retries @@ -221,10 +208,11 @@ async def test_oom_exhaustion_marks_errored(self, tmp_path: Path): class TestHookServiceNonOOMFailure: - """T018: Non-OOM failure does NOT trigger retry.""" + """T018: Non-OOM failure propagates without retry.""" @pytest.mark.asyncio async def test_non_oom_failure_no_retry(self, tmp_path: Path): + from osa.domain.shared.error import PermanentError from osa.domain.validation.service.hook import HookService hook = _make_hook() @@ -234,13 +222,13 @@ async def test_non_oom_failure_no_retry(self, tmp_path: Path): (work_dir / "output").mkdir(parents=True) runner = AsyncMock() - runner.run.return_value = _failed_result() + runner.run.side_effect = PermanentError("Hook exited with code 1") storage = FakeHookStorage() service = HookService(hook_runner=runner, hook_storage=storage) - result = await service.run_hook(hook, _inputs(records), work_dir) + with pytest.raises(PermanentError): + await service.run_hook(hook, _inputs(records), work_dir) - assert result.status == HookStatus.FAILED runner.run.assert_called_once() @@ -338,7 +326,7 @@ async def side_effect(h, inputs, wd): if h.name == "hook_one": return _passed_result(hook_name="hook_one") else: - return _oom_result(hook_name="hook_two") + raise _oom_error() runner.run.side_effect = side_effect @@ -348,13 +336,13 @@ async def side_effect(h, inputs, wd): r1 = await service.run_hook(hook1, _inputs(records), work_dir1) assert r1.status == HookStatus.PASSED - # Run hook 2 — should OOM and exhaust retries - r2 = await service.run_hook(hook2, _inputs(records), work_dir2) - assert r2.status == HookStatus.OOM + # Run hook 2 — should OOM and exhaust retries, then raise + with pytest.raises(OOMError): + await service.run_hook(hook2, _inputs(records), work_dir2) # Hook 1 was called once, hook 2 was called 4 times (1 + 3 retries) hook1_calls = [c for c in runner.run.call_args_list if c[0][0].name == "hook_one"] - hook2_calls = [c for c in runner.run.call_args_list if c[0][0].name == "hook_two"] + hook2_calls = [c for c in runner.run.call_args_list if c[0][0].name != "hook_one"] assert len(hook1_calls) == 1 assert len(hook2_calls) == 4 diff --git a/server/tests/unit/domain/validation/test_validation_service.py b/server/tests/unit/domain/validation/test_validation_service.py index 2589b7d..2b73881 100644 --- a/server/tests/unit/domain/validation/test_validation_service.py +++ b/server/tests/unit/domain/validation/test_validation_service.py @@ -13,6 +13,7 @@ ) from osa.domain.shared.model.srn import DepositionSRN, Domain from osa.domain.validation.model import RunStatus +from osa.domain.shared.error import OOMError, PermanentError from 
osa.domain.validation.model.hook_result import HookResult, HookStatus from osa.domain.validation.model.hook_input import HookRecord from osa.domain.validation.port.hook_runner import HookInputs @@ -134,9 +135,7 @@ async def test_hook_rejected_halts_pipeline(self): @pytest.mark.asyncio async def test_hook_failed_halts_pipeline(self): hook_runner = AsyncMock() - hook_runner.run.return_value = _make_hook_result( - status=HookStatus.FAILED, - ) + hook_runner.run.side_effect = PermanentError("Hook exited with code 1") service = _make_service(hook_runner=hook_runner) run = await service.create_run(inputs=_make_inputs()) @@ -204,7 +203,7 @@ async def run_hook(hook, inputs, output_dir): async def test_validation_service_halts_on_oom(self): """REGRESSION: OOM with exhausted retries should halt pipeline as FAILED.""" hook_runner = AsyncMock() - hook_runner.run.return_value = _make_hook_result(status=HookStatus.OOM) + hook_runner.run.side_effect = OOMError("Hook killed by OOM") service = _make_service(hook_runner=hook_runner) run = await service.create_run(inputs=_make_inputs()) @@ -226,12 +225,7 @@ async def run_hook(hook, inputs, output_dir): nonlocal call_count call_count += 1 if call_count == 1: - return HookResult( - hook_name=hook.name, - status=HookStatus.OOM, - error_message="OOM", - duration_seconds=30.0, - ) + raise OOMError("Hook killed by OOM") return HookResult( hook_name=hook.name, status=HookStatus.PASSED, diff --git a/server/tests/unit/infrastructure/event/test_worker.py b/server/tests/unit/infrastructure/event/test_worker.py index 4723ab0..84aaac3 100644 --- a/server/tests/unit/infrastructure/event/test_worker.py +++ b/server/tests/unit/infrastructure/event/test_worker.py @@ -358,9 +358,12 @@ async def test_worker_handles_handler_error(self): # Act - Run one poll cycle await worker._poll_once() - # Assert - Event should be marked as failed using delivery_id - outbox.mark_failed_with_retry.assert_called_once_with( - delivery_id, "Processing failed", max_retries=3 - ) + # Assert - Event should be marked as failed using delivery_id with backoff + outbox.mark_failed_with_retry.assert_called_once() + call_args = outbox.mark_failed_with_retry.call_args + assert call_args[0][0] == delivery_id + assert call_args[0][1] == "Processing failed" + assert call_args[1]["max_retries"] == 3 + assert call_args[1]["deliver_after"] is not None assert worker.state.failed_count == 1 assert worker.state.error is not None diff --git a/server/tests/unit/infrastructure/event/test_worker_exhaustion.py b/server/tests/unit/infrastructure/event/test_worker_exhaustion.py new file mode 100644 index 0000000..ef10915 --- /dev/null +++ b/server/tests/unit/infrastructure/event/test_worker_exhaustion.py @@ -0,0 +1,72 @@ +"""Tests for worker on_exhausted error safety. + +Verifies that mark_failed is always called even when on_exhausted raises, +by testing the Worker._poll_once flow end-to-end with a handler that +raises PermanentError from handle() and RuntimeError from on_exhausted(). 
+""" + +from datetime import UTC, datetime +from typing import Any +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 + +import pytest + +from osa.domain.ingest.event.events import NextBatchRequested +from osa.domain.ingest.handler.run_ingester import RunIngester +from osa.domain.ingest.model.ingest_run import IngestRunId +from osa.domain.shared.event import EventId +from osa.infrastructure.event.worker import Worker + + +class TestWorkerOnExhaustedErrorSafety: + @pytest.mark.asyncio + async def test_mark_failed_runs_even_if_on_exhausted_raises(self) -> None: + """If on_exhausted throws, mark_failed must still run.""" + worker = Worker(RunIngester) + + # Mock handler whose on_exhausted raises + handler = AsyncMock(spec=RunIngester) + handler.handle = AsyncMock(side_effect=Exception("something broke")) + handler.on_exhausted = AsyncMock(side_effect=RuntimeError("DB down during on_exhausted")) + + outbox = AsyncMock() + + # Create a delivery that has exhausted retries + event = NextBatchRequested( + id=EventId(uuid4()), + ingest_run_id=IngestRunId("run-1"), + convention_srn="urn:osa:localhost:conv:test@1.0.0", + batch_size=100, + ) + delivery = MagicMock() + delivery.id = "delivery-1" + delivery.event = event + delivery.retry_count = 100 # exceeds __max_retries__ + + claim_result = MagicMock() + claim_result.deliveries = [delivery] + claim_result.events = [event] + claim_result.claimed_at = datetime.now(UTC) + outbox.claim.return_value = claim_result + + # Wire up the DI scope mock + async def mock_get(t: type) -> Any: + if t is RunIngester: + return handler + return outbox + + scope = AsyncMock() + scope.get = mock_get + + container = MagicMock() + container.return_value.__aenter__ = AsyncMock(return_value=scope) + container.return_value.__aexit__ = AsyncMock(return_value=False) + + worker.set_container(container) + await worker._poll_once() + + # on_exhausted was called (and raised) + handler.on_exhausted.assert_called_once() + # mark_failed was STILL called despite the exception + outbox.mark_failed.assert_called_once_with("delivery-1", "something broke") diff --git a/server/tests/unit/infrastructure/k8s/test_classify_api_error.py b/server/tests/unit/infrastructure/k8s/test_classify_api_error.py index c9d3096..64beba5 100644 --- a/server/tests/unit/infrastructure/k8s/test_classify_api_error.py +++ b/server/tests/unit/infrastructure/k8s/test_classify_api_error.py @@ -1,6 +1,6 @@ """Tests for K8s API error classification.""" -from osa.domain.shared.error import ConfigurationError, InfrastructureError +from osa.domain.shared.error import PermanentError, TransientError from osa.infrastructure.k8s.errors import classify_api_error @@ -14,33 +14,33 @@ def __init__(self, status: int, reason: str = ""): class TestClassifyApiError: - def test_403_returns_configuration_error(self): + def test_403_returns_permanent_runner_error(self): exc = _FakeApiException(403, "Forbidden") result = classify_api_error(exc) - assert isinstance(result, ConfigurationError) + assert isinstance(result, PermanentError) assert "RBAC" in result.message or "permission" in result.message.lower() - def test_404_returns_configuration_error(self): + def test_404_returns_permanent_runner_error(self): exc = _FakeApiException(404, "Not Found") result = classify_api_error(exc) - assert isinstance(result, ConfigurationError) + assert isinstance(result, PermanentError) - def test_500_returns_infrastructure_error(self): + def test_500_returns_transient_resource_error(self): exc = _FakeApiException(500, "Internal 
Server Error") result = classify_api_error(exc) - assert isinstance(result, InfrastructureError) + assert isinstance(result, TransientError) - def test_503_returns_infrastructure_error(self): + def test_503_returns_transient_resource_error(self): exc = _FakeApiException(503, "Service Unavailable") result = classify_api_error(exc) - assert isinstance(result, InfrastructureError) + assert isinstance(result, TransientError) - def test_409_returns_infrastructure_error(self): + def test_409_returns_transient_resource_error(self): exc = _FakeApiException(409, "Conflict") result = classify_api_error(exc) - assert isinstance(result, InfrastructureError) + assert isinstance(result, TransientError) - def test_unknown_status_returns_infrastructure_error(self): + def test_unknown_status_returns_transient_resource_error(self): exc = _FakeApiException(429, "Too Many Requests") result = classify_api_error(exc) - assert isinstance(result, InfrastructureError) + assert isinstance(result, TransientError) diff --git a/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py b/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py index 9e63c87..ead5ad1 100644 --- a/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py +++ b/server/tests/unit/infrastructure/k8s/test_k8s_hook_runner.py @@ -6,7 +6,11 @@ import pytest from osa.config import K8sConfig -from osa.domain.shared.error import InfrastructureError +from osa.domain.shared.error import ( + OOMError, + PermanentError, + TransientError, +) from osa.domain.shared.model.hook import ( ColumnDef, HookDefinition, @@ -70,7 +74,10 @@ def _make_s3_mock() -> AsyncMock: def _make_runner(config: K8sConfig | None = None) -> K8sHookRunner: api_client = MagicMock() s3 = _make_s3_mock() - return K8sHookRunner(api_client=api_client, config=config or _make_config(), s3=s3) + runner = K8sHookRunner(api_client=api_client, config=config or _make_config(), s3=s3) + runner._batch_api = AsyncMock() + runner._core_api = AsyncMock() + return runner # --------------------------------------------------------------------------- @@ -208,7 +215,8 @@ def test_labels(self): labels = spec.spec.template.metadata.labels assert labels["osa.io/role"] == "hook" assert labels["osa.io/hook"] == "validate_dna" - assert labels["osa.io/run-id"] == "run-abc123" + assert labels["osa.io/ingest-run-id"] == "run-abc123" + assert labels["osa.io/ingest-run-batch"] == "0" def test_human_readable_job_name(self): runner = _make_runner() @@ -331,6 +339,7 @@ class TestSchedulingWatch: async def test_pod_leaves_pending_quickly(self): runner = _make_runner() core_api = AsyncMock() + runner._core_api = core_api # Pod transitions from Pending to Running pod = MagicMock() @@ -340,12 +349,13 @@ async def test_pod_leaves_pending_quickly(self): pod_list.items = [pod] core_api.list_namespaced_pod.return_value = pod_list - await runner._wait_for_scheduling(core_api, "test-job", "osa") + await runner._wait_for_scheduling("test-job", "osa") @pytest.mark.asyncio async def test_pod_stuck_scheduling_timeout(self): runner = _make_runner() core_api = AsyncMock() + runner._core_api = core_api # Pod stays in Pending pod = MagicMock() @@ -355,15 +365,16 @@ async def test_pod_stuck_scheduling_timeout(self): pod_list.items = [pod] core_api.list_namespaced_pod.return_value = pod_list - with pytest.raises(InfrastructureError, match="scheduling"): + with pytest.raises(TransientError, match="scheduling"): await runner._wait_for_scheduling( - core_api, "test-job", "osa", timeout_seconds=0.1, poll_interval=0.05 + 
"test-job", "osa", timeout_seconds=0.1, poll_interval=0.05 ) @pytest.mark.asyncio async def test_image_pull_backoff_fails_fast(self): runner = _make_runner() core_api = AsyncMock() + runner._core_api = core_api pod = MagicMock() pod.status.phase = "Pending" @@ -375,13 +386,14 @@ async def test_image_pull_backoff_fails_fast(self): pod_list.items = [pod] core_api.list_namespaced_pod.return_value = pod_list - with pytest.raises(InfrastructureError, match="[Ii]mage pull"): - await runner._wait_for_scheduling(core_api, "test-job", "osa") + with pytest.raises(PermanentError, match="[Ii]mage pull"): + await runner._wait_for_scheduling("test-job", "osa") @pytest.mark.asyncio async def test_err_image_pull_fails_fast(self): runner = _make_runner() core_api = AsyncMock() + runner._core_api = core_api pod = MagicMock() pod.status.phase = "Pending" @@ -393,13 +405,14 @@ async def test_err_image_pull_fails_fast(self): pod_list.items = [pod] core_api.list_namespaced_pod.return_value = pod_list - with pytest.raises(InfrastructureError, match="[Ii]mage pull"): - await runner._wait_for_scheduling(core_api, "test-job", "osa") + with pytest.raises(PermanentError, match="[Ii]mage pull"): + await runner._wait_for_scheduling("test-job", "osa") @pytest.mark.asyncio async def test_pod_evicted(self): runner = _make_runner() core_api = AsyncMock() + runner._core_api = core_api pod = MagicMock() pod.status.phase = "Failed" @@ -409,8 +422,8 @@ async def test_pod_evicted(self): pod_list.items = [pod] core_api.list_namespaced_pod.return_value = pod_list - with pytest.raises(InfrastructureError, match="[Ee]vict"): - await runner._wait_for_scheduling(core_api, "test-job", "osa") + with pytest.raises(TransientError, match="[Ee]vict"): + await runner._wait_for_scheduling("test-job", "osa") # --------------------------------------------------------------------------- @@ -427,6 +440,8 @@ async def test_successful_run(self, tmp_path: Path): batch_api = AsyncMock() core_api = AsyncMock() + runner._batch_api = batch_api + runner._core_api = core_api # No existing jobs (orphan check) job_list = MagicMock() @@ -462,13 +477,7 @@ async def test_successful_run(self, tmp_path: Path): ) inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) - result = await runner._run_job( - batch_api, - core_api, - hook, - inputs, - work_dir, - ) + result = await runner._run_job(hook, inputs, work_dir) assert result.status == HookStatus.PASSED assert len(result.progress) == 1 @@ -482,6 +491,8 @@ async def test_timeout_deadline_exceeded(self, tmp_path: Path): batch_api = AsyncMock() core_api = AsyncMock() + runner._batch_api = batch_api + runner._core_api = core_api job_list = MagicMock() job_list.items = [] @@ -512,19 +523,12 @@ async def test_timeout_deadline_exceeded(self, tmp_path: Path): work_dir.mkdir(parents=True) inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) - result = await runner._run_job( - batch_api, - core_api, - hook, - inputs, - work_dir, - ) - - assert result.status == HookStatus.FAILED - assert ( - "timed out" in result.error_message.lower() - or "deadline" in result.error_message.lower() - ) + with pytest.raises(TransientError, match="[Tt]imed out|[Dd]eadline"): + await runner._run_job( + hook, + inputs, + work_dir, + ) batch_api.delete_namespaced_job.assert_called_once() @pytest.mark.asyncio @@ -534,6 +538,8 @@ async def test_oom_exit_137(self, tmp_path: Path): batch_api = AsyncMock() core_api = AsyncMock() + runner._batch_api = batch_api + runner._core_api = core_api 
job_list = MagicMock() job_list.items = [] @@ -577,16 +583,12 @@ async def test_oom_exit_137(self, tmp_path: Path): work_dir.mkdir(parents=True) inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) - result = await runner._run_job( - batch_api, - core_api, - hook, - inputs, - work_dir, - ) - - assert result.status == HookStatus.OOM - assert "oom" in result.error_message.lower() + with pytest.raises(OOMError, match="[Oo][Oo][Mm]"): + await runner._run_job( + hook, + inputs, + work_dir, + ) @pytest.mark.asyncio async def test_nonzero_exit(self, tmp_path: Path): @@ -595,6 +597,8 @@ async def test_nonzero_exit(self, tmp_path: Path): batch_api = AsyncMock() core_api = AsyncMock() + runner._batch_api = batch_api + runner._core_api = core_api job_list = MagicMock() job_list.items = [] @@ -636,16 +640,12 @@ async def test_nonzero_exit(self, tmp_path: Path): work_dir.mkdir(parents=True) inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) - result = await runner._run_job( - batch_api, - core_api, - hook, - inputs, - work_dir, - ) - - assert result.status == HookStatus.FAILED - assert "exit" in result.error_message.lower() + with pytest.raises(PermanentError, match="[Ee]xit"): + await runner._run_job( + hook, + inputs, + work_dir, + ) @pytest.mark.asyncio async def test_orphan_running_job_attaches(self, tmp_path: Path): @@ -655,6 +655,8 @@ async def test_orphan_running_job_attaches(self, tmp_path: Path): batch_api = AsyncMock() core_api = AsyncMock() + runner._batch_api = batch_api + runner._core_api = core_api # Existing active job existing_job = MagicMock() @@ -691,8 +693,6 @@ async def test_orphan_running_job_attaches(self, tmp_path: Path): inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( - batch_api, - core_api, hook, inputs, work_dir, @@ -710,6 +710,8 @@ async def test_orphan_completed_job_reads_output(self, tmp_path: Path): batch_api = AsyncMock() core_api = AsyncMock() + runner._batch_api = batch_api + runner._core_api = core_api existing_job = MagicMock() existing_job.metadata.name = "osa-hook-existing" @@ -727,8 +729,6 @@ async def test_orphan_completed_job_reads_output(self, tmp_path: Path): inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( - batch_api, - core_api, hook, inputs, work_dir, @@ -745,6 +745,8 @@ async def test_orphan_failed_job_creates_new(self, tmp_path: Path): batch_api = AsyncMock() core_api = AsyncMock() + runner._batch_api = batch_api + runner._core_api = core_api existing_job = MagicMock() existing_job.metadata.name = "osa-hook-existing" @@ -782,8 +784,6 @@ async def test_orphan_failed_job_creates_new(self, tmp_path: Path): inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( - batch_api, - core_api, hook, inputs, work_dir, @@ -807,7 +807,7 @@ class FakeNotFound(Exception): batch_api.delete_namespaced_job.side_effect = FakeNotFound() # Should not raise - await runner._cleanup_job(batch_api, "test-job", "osa") + await runner._cleanup_job("test-job", "osa") @pytest.mark.asyncio async def test_rejection_via_progress(self, tmp_path: Path): @@ -817,6 +817,8 @@ async def test_rejection_via_progress(self, tmp_path: Path): batch_api = AsyncMock() core_api = AsyncMock() + runner._batch_api = batch_api + runner._core_api = core_api job_list = MagicMock() job_list.items = [] @@ -847,8 +849,6 @@ async def test_rejection_via_progress(self, 
tmp_path: Path): inputs = HookInputs(records=[HookRecord(id="test", metadata={})], run_id=_RUN_ID) result = await runner._run_job( - batch_api, - core_api, hook, inputs, work_dir, @@ -869,13 +869,13 @@ class TestRunIdFromInputs: @pytest.mark.asyncio async def test_run_uses_run_id_from_inputs(self, tmp_path: Path): """The run_id in Job labels comes from inputs, not the work_dir path.""" - from unittest.mock import patch - config = _make_config(data_mount_path=str(tmp_path)) runner = K8sHookRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) batch_api = AsyncMock() core_api = AsyncMock() + runner._batch_api = batch_api + runner._core_api = core_api # No existing jobs job_list = MagicMock() @@ -909,14 +909,11 @@ async def test_run_uses_run_id_from_inputs(self, tmp_path: Path): run_id="my-real-run-id", ) - with ( - patch("kubernetes_asyncio.client.BatchV1Api", return_value=batch_api), - patch("kubernetes_asyncio.client.CoreV1Api", return_value=core_api), - ): - await runner.run(hook, inputs, work_dir) + await runner.run(hook, inputs, work_dir) # Verify the Job was created with the run_id from inputs call_args = batch_api.create_namespaced_job.call_args spec = call_args[0][1] # positional arg: (namespace, spec) labels = spec.metadata.labels - assert labels["osa.io/run-id"] == "my-real-run-id" + assert labels["osa.io/ingest-run-id"] == "my-real-run-id" + assert labels["osa.io/ingest-run-batch"] == "0" diff --git a/server/tests/unit/infrastructure/k8s/test_k8s_ingester_runner.py b/server/tests/unit/infrastructure/k8s/test_k8s_ingester_runner.py index 398905e..0341814 100644 --- a/server/tests/unit/infrastructure/k8s/test_k8s_ingester_runner.py +++ b/server/tests/unit/infrastructure/k8s/test_k8s_ingester_runner.py @@ -7,7 +7,7 @@ import pytest from osa.config import K8sConfig -from osa.domain.shared.error import ExternalServiceError +from osa.domain.shared.error import OOMError, TransientError from osa.domain.shared.model.source import IngesterDefinition, IngesterLimits from osa.domain.shared.model.srn import ConventionSRN from osa.domain.shared.port.ingester_runner import IngesterInputs @@ -59,6 +59,96 @@ def _make_runner(config: K8sConfig | None = None) -> K8sIngesterRunner: return K8sIngesterRunner(api_client=api_client, config=config or _make_config(), s3=s3) +# --------------------------------------------------------------------------- +# Backpressure capacity check +# --------------------------------------------------------------------------- + + +def _make_pod(*, phase: str = "Pending", conditions: list | None = None) -> MagicMock: + pod = MagicMock() + pod.status.phase = phase + pod.status.conditions = conditions + return pod + + +def _make_condition(*, type_: str, reason: str, status: str = "False") -> MagicMock: + cond = MagicMock() + cond.type = type_ + cond.reason = reason + cond.status = status + return cond + + +class TestHasCapacity: + @pytest.mark.asyncio + async def test_no_pending_pods_returns_true(self): + runner = _make_runner() + runner._core_api = AsyncMock() + pod_list = MagicMock() + pod_list.items = [] + runner._core_api.list_namespaced_pod.return_value = pod_list + + assert await runner.has_capacity() is True + + @pytest.mark.asyncio + async def test_pending_but_schedulable_returns_true(self): + """Pods in Pending that are actively scheduling (no Unschedulable condition) + should NOT trigger backpressure — they'll be Running in seconds.""" + runner = _make_runner() + runner._core_api = AsyncMock() + + # Pod is Pending with PodScheduled=True (normal startup) 
+ pod = _make_pod( + phase="Pending", + conditions=[_make_condition(type_="PodScheduled", reason="", status="True")], + ) + pod_list = MagicMock() + pod_list.items = [pod] + runner._core_api.list_namespaced_pod.return_value = pod_list + + assert await runner.has_capacity() is True + + @pytest.mark.asyncio + async def test_pending_no_conditions_returns_true(self): + """Pods in Pending with no conditions yet (just created) should not block.""" + runner = _make_runner() + runner._core_api = AsyncMock() + + pod = _make_pod(phase="Pending", conditions=None) + pod_list = MagicMock() + pod_list.items = [pod] + runner._core_api.list_namespaced_pod.return_value = pod_list + + assert await runner.has_capacity() is True + + @pytest.mark.asyncio + async def test_unschedulable_pod_returns_false(self): + """A pod with PodScheduled=False reason=Unschedulable means the cluster is full.""" + runner = _make_runner() + runner._core_api = AsyncMock() + + pod = _make_pod( + phase="Pending", + conditions=[ + _make_condition(type_="PodScheduled", reason="Unschedulable", status="False"), + ], + ) + pod_list = MagicMock() + pod_list.items = [pod] + runner._core_api.list_namespaced_pod.return_value = pod_list + + assert await runner.has_capacity() is False + + @pytest.mark.asyncio + async def test_api_failure_assumes_capacity(self): + """If the K8s API fails, assume capacity (don't block on transient API errors).""" + runner = _make_runner() + runner._core_api = AsyncMock() + runner._core_api.list_namespaced_pod.side_effect = Exception("API timeout") + + assert await runner.has_capacity() is True + + # --------------------------------------------------------------------------- # Job spec differences (T021) # --------------------------------------------------------------------------- @@ -214,6 +304,8 @@ async def test_successful_run_with_records(self, tmp_path: Path): batch_api = AsyncMock() core_api = AsyncMock() + runner._batch_api = batch_api + runner._core_api = core_api # No existing jobs job_list = MagicMock() @@ -259,14 +351,7 @@ async def s3_get(key: str) -> bytes: runner._s3.get_object.side_effect = s3_get inputs = IngesterInputs(convention_srn=_CONV_SRN) - result = await runner._run_job( - batch_api, - core_api, - ingester, - inputs, - work_dir, - files_dir, - ) + result = await runner._run_job(ingester, inputs, work_dir, files_dir) assert len(result.records) == 2 assert result.session == {"cursor": "abc"} @@ -274,12 +359,14 @@ async def s3_get(key: str) -> bytes: batch_api.delete_namespaced_job.assert_called_once() @pytest.mark.asyncio - async def test_timeout_raises_external_service_error(self, tmp_path: Path): + async def test_timeout_raises_transient_error(self, tmp_path: Path): config = _make_config(data_mount_path=str(tmp_path)) runner = K8sIngesterRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) batch_api = AsyncMock() core_api = AsyncMock() + runner._batch_api = batch_api + runner._core_api = core_api job_list = MagicMock() job_list.items = [] @@ -311,23 +398,18 @@ async def test_timeout_raises_external_service_error(self, tmp_path: Path): files_dir.mkdir(parents=True) inputs = IngesterInputs(convention_srn=_CONV_SRN) - with pytest.raises(ExternalServiceError, match="[Tt]imed out|[Dd]eadline"): - await runner._run_job( - batch_api, - core_api, - ingester, - inputs, - work_dir, - files_dir, - ) + with pytest.raises(TransientError, match="[Tt]imed out|[Dd]eadline"): + await runner._run_job(ingester, inputs, work_dir, files_dir) @pytest.mark.asyncio - async def 
test_oom_raises_external_service_error(self, tmp_path: Path): + async def test_oom_raises_oom_error(self, tmp_path: Path): config = _make_config(data_mount_path=str(tmp_path)) runner = K8sIngesterRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) batch_api = AsyncMock() core_api = AsyncMock() + runner._batch_api = batch_api + runner._core_api = core_api job_list = MagicMock() job_list.items = [] @@ -371,15 +453,8 @@ async def test_oom_raises_external_service_error(self, tmp_path: Path): files_dir.mkdir(parents=True) inputs = IngesterInputs(convention_srn=_CONV_SRN) - with pytest.raises(ExternalServiceError, match="[Oo]OM"): - await runner._run_job( - batch_api, - core_api, - ingester, - inputs, - work_dir, - files_dir, - ) + with pytest.raises(OOMError, match="[Oo]OM"): + await runner._run_job(ingester, inputs, work_dir, files_dir) # --------------------------------------------------------------------------- @@ -392,13 +467,13 @@ class TestConventionSrnFromInputs: @pytest.mark.asyncio async def test_run_uses_convention_srn_from_inputs(self, tmp_path: Path): - from unittest.mock import patch - config = _make_config(data_mount_path=str(tmp_path)) runner = K8sIngesterRunner(api_client=MagicMock(), config=config, s3=_make_s3_mock()) batch_api = AsyncMock() core_api = AsyncMock() + runner._batch_api = batch_api + runner._core_api = core_api # No existing jobs job_list = MagicMock() @@ -432,11 +507,7 @@ async def test_run_uses_convention_srn_from_inputs(self, tmp_path: Path): convention_srn=ConventionSRN.parse("urn:osa:localhost:conv:my-conv@1.0.0") ) - with ( - patch("kubernetes_asyncio.client.BatchV1Api", return_value=batch_api), - patch("kubernetes_asyncio.client.CoreV1Api", return_value=core_api), - ): - await runner.run(ingester, inputs, files_dir, work_dir) + await runner.run(ingester, inputs, files_dir, work_dir) # Verify convention_srn from inputs ends up in the Job labels call_args = batch_api.create_namespaced_job.call_args diff --git a/server/tests/unit/infrastructure/test_oci_hook_runner.py b/server/tests/unit/infrastructure/test_oci_hook_runner.py index 35352bb..33e20a8 100644 --- a/server/tests/unit/infrastructure/test_oci_hook_runner.py +++ b/server/tests/unit/infrastructure/test_oci_hook_runner.py @@ -5,6 +5,7 @@ import pytest +from osa.domain.shared.error import OOMError, PermanentError, TransientError from osa.domain.shared.model.hook import ( ColumnDef, HookDefinition, @@ -233,13 +234,11 @@ async def test_nonzero_exit_returns_failed(self, tmp_path: Path): output_dir = tmp_path / "output" output_dir.mkdir() - result = await runner.run(hook, inputs, output_dir) - - assert result.status == HookStatus.FAILED - assert "exit" in (result.error_message or "").lower() + with pytest.raises(PermanentError, match="[Ee]xit"): + await runner.run(hook, inputs, output_dir) @pytest.mark.asyncio - async def test_oom_killed_returns_oom(self, tmp_path: Path): + async def test_oom_killed_raises_oom_error(self, tmp_path: Path): docker = AsyncMock() container = AsyncMock() docker.containers.create.return_value = container @@ -256,13 +255,11 @@ async def test_oom_killed_returns_oom(self, tmp_path: Path): output_dir = tmp_path / "output" output_dir.mkdir() - result = await runner.run(hook, inputs, output_dir) - - assert result.status == HookStatus.OOM - assert "oom" in (result.error_message or "").lower() + with pytest.raises(OOMError, match="[Oo][Oo][Mm]"): + await runner.run(hook, inputs, output_dir) @pytest.mark.asyncio - async def test_timeout_returns_failed(self, tmp_path: Path): + 
async def test_timeout_raises_transient_error(self, tmp_path: Path): import asyncio docker = AsyncMock() @@ -286,10 +283,8 @@ async def hang(): output_dir = tmp_path / "output" output_dir.mkdir() - result = await runner.run(hook, inputs, output_dir) - - assert result.status == HookStatus.FAILED - assert "timed out" in (result.error_message or "").lower() + with pytest.raises(TransientError, match="[Tt]imed out"): + await runner.run(hook, inputs, output_dir) @pytest.mark.asyncio async def test_rejection_via_progress(self, tmp_path: Path): @@ -474,7 +469,7 @@ async def test_container_deleted_on_failure(self, tmp_path: Path): output_dir = tmp_path / "output" output_dir.mkdir() - result = await runner.run(hook, inputs, output_dir) - - assert result.status == HookStatus.FAILED - container.delete.assert_called_once_with(force=True) + with pytest.raises(TransientError, match="Docker error"): + await runner.run(hook, inputs, output_dir) + + container.delete.assert_called_once_with(force=True)
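Note on the new deliver_after backoff: the worker test above only pins that mark_failed_with_retry now receives a non-None deliver_after, and the add_deliver_after migration supplies the matching column and partial index. The actual backoff policy is not part of this diff; below is a minimal sketch of how a worker could derive deliver_after from the delivery's retry count. The helper name compute_deliver_after and the base/cap values are illustrative assumptions, not the project's implementation.

from datetime import UTC, datetime, timedelta


def compute_deliver_after(
    retry_count: int, *, base_seconds: float = 2.0, cap_seconds: float = 300.0
) -> datetime:
    # Exponential backoff: 2s, 4s, 8s, ... capped at 5 minutes.
    delay = min(base_seconds * (2**retry_count), cap_seconds)
    return datetime.now(UTC) + timedelta(seconds=delay)


# Hypothetical wiring inside the worker's failure path:
# await outbox.mark_failed_with_retry(
#     delivery.id,
#     str(error),
#     max_retries=3,
#     deliver_after=compute_deliver_after(delivery.retry_count),
# )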