diff --git a/ingestify/infra/event_log/consumer.py b/ingestify/infra/event_log/consumer.py index 696c6b7..ef53818 100644 --- a/ingestify/infra/event_log/consumer.py +++ b/ingestify/infra/event_log/consumer.py @@ -1,12 +1,17 @@ import logging import time -from typing import Callable, Optional +from typing import Callable, List, Optional from sqlalchemy import create_engine, select +from ingestify.domain.models.event.domain_event import DomainEvent + from .event_log import EventLog from .tables import get_tables +OnEventHandler = Callable[[DomainEvent], None] +OnEventsHandler = Callable[[List[DomainEvent]], None] + logger = logging.getLogger(__name__) @@ -77,7 +82,7 @@ def _update_cursor(self, conn, event_id: int) -> None: ) conn.commit() - def _run_once(self, on_event: Callable, batch_size: int = 100) -> int: + def _run_once(self, on_event: OnEventHandler, batch_size: int = 100) -> int: """Returns number of events processed, or -1 if a processing error occurred.""" with self._engine.connect() as conn: self._ensure_reader_state(conn) @@ -102,7 +107,7 @@ def _run_once(self, on_event: Callable, batch_size: int = 100) -> int: def run( self, - on_event: Callable, + on_event: OnEventHandler, poll_interval: Optional[int] = None, batch_size: int = 100, ) -> int: @@ -114,3 +119,58 @@ def run( if poll_interval is None: return 0 time.sleep(poll_interval) + + def _run_batched_once(self, on_events: OnEventsHandler, batch_size: int) -> int: + """Returns number of events processed, or -1 if a processing error occurred. + + on_events receives the full list of DomainEvent instances for this + batch. The cursor advances to the last event's id only after + on_events returns without raising. + """ + with self._engine.connect() as conn: + self._ensure_reader_state(conn) + last_id = self._get_last_event_id(conn) + + rows = self._event_log.fetch_batch(last_id, batch_size) + if not rows: + return 0 + + events = [event for _, event in rows] + try: + on_events(events) + except Exception: + logger.exception( + "Failed to process batch of %d events — cursor NOT advanced", + len(rows), + ) + return -1 + + last_event_id = rows[-1][0] + with self._engine.connect() as conn: + self._update_cursor(conn, last_event_id) + + return len(rows) + + def run_batched( + self, + on_events: OnEventsHandler, + poll_interval: Optional[int] = None, + batch_size: int = 1000, + ) -> int: + """Consume events in batches. + + on_events is called with a List[DomainEvent] per batch. The cursor + advances once per batch, not per event — letting callers parallelize + I/O-bound work within a batch (threads, asyncio, etc.) without + hitting the DB on every event. + + Exit codes match run(): 0 success, 1 processing error. + """ + while True: + count = self._run_batched_once(on_events, batch_size) + if count < 0: + return 1 + if count == 0: + if poll_interval is None: + return 0 + time.sleep(poll_interval)