diff --git a/pyworkflow/__init__.py b/pyworkflow/__init__.py index c05061f..445eafe 100644 --- a/pyworkflow/__init__.py +++ b/pyworkflow/__init__.py @@ -156,6 +156,21 @@ WorkflowRun, ) +# Streams (pub/sub signal pattern) +from pyworkflow.streams import ( + CheckpointBackend, + Signal, + Stream, + StreamConsumer, + StreamStepContext, + emit, + get_checkpoint, + get_current_signal, + save_checkpoint, + stream_step, + stream_workflow, +) + __all__ = [ # Version "__version__", @@ -264,4 +279,16 @@ "get_logger", "bind_workflow_context", "bind_step_context", + # Streams + "stream_workflow", + "stream_step", + "emit", + "Signal", + "Stream", + "StreamStepContext", + "StreamConsumer", + "CheckpointBackend", + "get_current_signal", + "get_checkpoint", + "save_checkpoint", ] diff --git a/pyworkflow/engine/events.py b/pyworkflow/engine/events.py index d9d5c44..ee8284f 100644 --- a/pyworkflow/engine/events.py +++ b/pyworkflow/engine/events.py @@ -53,6 +53,15 @@ class EventType(Enum): CHILD_WORKFLOW_FAILED = "child_workflow.failed" CHILD_WORKFLOW_CANCELLED = "child_workflow.cancelled" + # Stream signal events + SIGNAL_WAIT_STARTED = "signal.wait_started" + SIGNAL_RECEIVED = "signal.received" + SIGNAL_PUBLISHED = "signal.published" + STREAM_STEP_STARTED = "stream_step.started" + STREAM_STEP_COMPLETED = "stream_step.completed" + CHECKPOINT_SAVED = "checkpoint.saved" + CHECKPOINT_LOADED = "checkpoint.loaded" + # Schedule events SCHEDULE_CREATED = "schedule.created" SCHEDULE_UPDATED = "schedule.updated" @@ -936,3 +945,134 @@ def create_schedule_backfill_completed_event( "completed_at": datetime.now(UTC).isoformat(), }, ) + + +# Stream signal event creation helpers + + +def create_signal_wait_started_event( + run_id: str, + stream_id: str, + signal_types: list[str], + wait_sequence: int = 0, +) -> Event: + """Create a signal wait started event (stream_step waiting for signals).""" + return Event( + run_id=run_id, + type=EventType.SIGNAL_WAIT_STARTED, + data={ + "stream_id": 
stream_id, + "signal_types": signal_types, + "wait_sequence": wait_sequence, + "token": f"stream:{stream_id}:{run_id}:{wait_sequence}", + "started_at": datetime.now(UTC).isoformat(), + }, + ) + + +def create_signal_received_event( + run_id: str, + signal_id: str, + stream_id: str, + signal_type: str, + payload: Any, +) -> Event: + """Create a signal received event (stream_step received a signal).""" + return Event( + run_id=run_id, + type=EventType.SIGNAL_RECEIVED, + data={ + "signal_id": signal_id, + "stream_id": stream_id, + "signal_type": signal_type, + "payload": payload, + "received_at": datetime.now(UTC).isoformat(), + }, + ) + + +def create_signal_published_event( + run_id: str, + signal_id: str, + stream_id: str, + signal_type: str, +) -> Event: + """Create a signal published event (workflow/step emitted a signal).""" + return Event( + run_id=run_id, + type=EventType.SIGNAL_PUBLISHED, + data={ + "signal_id": signal_id, + "stream_id": stream_id, + "signal_type": signal_type, + "published_at": datetime.now(UTC).isoformat(), + }, + ) + + +def create_stream_step_started_event( + run_id: str, + stream_id: str, + step_name: str, + signal_types: list[str], +) -> Event: + """Create a stream step started event.""" + return Event( + run_id=run_id, + type=EventType.STREAM_STEP_STARTED, + data={ + "stream_id": stream_id, + "step_name": step_name, + "signal_types": signal_types, + "started_at": datetime.now(UTC).isoformat(), + }, + ) + + +def create_stream_step_completed_event( + run_id: str, + stream_id: str, + step_name: str, + reason: str | None = None, +) -> Event: + """Create a stream step completed event.""" + return Event( + run_id=run_id, + type=EventType.STREAM_STEP_COMPLETED, + data={ + "stream_id": stream_id, + "step_name": step_name, + "reason": reason, + "completed_at": datetime.now(UTC).isoformat(), + }, + ) + + +def create_checkpoint_saved_event( + run_id: str, + step_run_id: str, +) -> Event: + """Create a checkpoint saved event.""" + return Event( + 
run_id=run_id, + type=EventType.CHECKPOINT_SAVED, + data={ + "step_run_id": step_run_id, + "saved_at": datetime.now(UTC).isoformat(), + }, + ) + + +def create_checkpoint_loaded_event( + run_id: str, + step_run_id: str, +) -> Event: + """Create a checkpoint loaded event.""" + return Event( + run_id=run_id, + type=EventType.CHECKPOINT_LOADED, + data={ + "step_run_id": step_run_id, + "loaded_at": datetime.now(UTC).isoformat(), + }, + ) diff --git a/pyworkflow/engine/replay.py b/pyworkflow/engine/replay.py index a83b20e..b61953b 100644 --- a/pyworkflow/engine/replay.py +++ b/pyworkflow/engine/replay.py @@ -96,6 +96,12 @@ async def _apply_event(self, ctx: LocalContext, event: Event) -> None: elif event.type == EventType.CANCELLATION_REQUESTED: await self._apply_cancellation_requested(ctx, event) + elif event.type == EventType.SIGNAL_WAIT_STARTED: + await self._apply_signal_wait_started(ctx, event) + + elif event.type == EventType.SIGNAL_RECEIVED: + await self._apply_signal_received(ctx, event) + # Other event types don't affect replay state # (workflow_started, step_started, step_failed, etc. 
async def _apply_signal_wait_started(self, ctx: LocalContext, event: Event) -> None:
    """
    Replay a signal_wait_started event.

    Records on the context that a stream step was waiting for signals. The
    entry is keyed by "<stream_id>:<wait_sequence>". Events that carry no
    stream_id are ignored.
    """
    data = event.data
    stream_id = data.get("stream_id")
    signal_types = data.get("signal_types", [])
    wait_sequence = data.get("wait_sequence", 0)

    if not stream_id:
        return

    # The wait table is not a declared attribute here; create it on first use.
    if not hasattr(ctx, "_signal_waits"):
        ctx._signal_waits = {}

    wait_key = f"{stream_id}:{wait_sequence}"
    ctx._signal_waits[wait_key] = {
        "stream_id": stream_id,
        "signal_types": signal_types,
        "wait_sequence": wait_sequence,
    }
    logger.debug(
        f"Signal wait pending: {stream_id} (seq {wait_sequence})",
        run_id=ctx.run_id,
        stream_id=stream_id,
    )


async def _apply_signal_received(self, ctx: LocalContext, event: Event) -> None:
    """
    Replay a signal_received event.

    Caches the received signal (id, stream, type and payload) on the context,
    keyed by signal_id. Events that carry no signal_id are ignored.
    """
    data = event.data
    signal_id = data.get("signal_id")
    stream_id = data.get("stream_id")
    signal_type = data.get("signal_type")
    payload = data.get("payload")

    if not signal_id:
        return

    # Same lazy creation as the wait table above.
    if not hasattr(ctx, "_received_signals"):
        ctx._received_signals = {}

    ctx._received_signals[signal_id] = {
        "signal_id": signal_id,
        "stream_id": stream_id,
        "signal_type": signal_type,
        "payload": payload,
    }
    logger.debug(
        f"Signal received: {signal_id} ({signal_type})",
        run_id=ctx.run_id,
        signal_id=signal_id,
    )
diff --git a/pyworkflow/storage/base.py b/pyworkflow/storage/base.py index 547479a..f627bd5 100644 --- a/pyworkflow/storage/base.py +++ b/pyworkflow/storage/base.py @@ -716,6 +716,227 @@ async def delete_old_runs(self, older_than: datetime) -> int: """ pass + # Stream Operations + + async def create_stream(self, stream_id: str, metadata: dict | None = None) -> None: + """ + Create a new stream. + + Args: + stream_id: Unique stream identifier + metadata: Optional stream metadata + + Raises: + NotImplementedError: If backend doesn't support streams + """ + raise NotImplementedError("This storage backend does not support streams") + + async def get_stream(self, stream_id: str) -> dict | None: + """ + Get a stream by ID. + + Args: + stream_id: Stream identifier + + Returns: + Stream data dict if found, None otherwise + + Raises: + NotImplementedError: If backend doesn't support streams + """ + raise NotImplementedError("This storage backend does not support streams") + + # Signal Operations + + async def publish_signal( + self, + signal_id: str, + stream_id: str, + signal_type: str, + payload: dict, + source_run_id: str | None = None, + metadata: dict | None = None, + ) -> int: + """ + Publish a signal to a stream. Assigns and returns a sequence number. + + Args: + signal_id: Unique signal identifier + stream_id: Target stream identifier + signal_type: Signal type (e.g., "task.created") + payload: Signal payload data + source_run_id: Optional source workflow run ID + metadata: Optional signal metadata + + Returns: + Assigned sequence number + + Raises: + NotImplementedError: If backend doesn't support streams + """ + raise NotImplementedError("This storage backend does not support streams") + + async def get_signals( + self, + stream_id: str, + after_sequence: int = 0, + limit: int = 100, + ) -> list[dict]: + """ + Get signals from a stream after a given sequence number. 
+ + Args: + stream_id: Stream identifier + after_sequence: Return signals with sequence > this value + limit: Maximum number of signals to return + + Returns: + List of signal dicts ordered by sequence + + Raises: + NotImplementedError: If backend doesn't support streams + """ + raise NotImplementedError("This storage backend does not support streams") + + # Subscription Operations + + async def register_stream_subscription( + self, + stream_id: str, + step_run_id: str, + signal_types: list[str], + ) -> None: + """ + Register a stream step's subscription to signal types. + + Args: + stream_id: Stream identifier + step_run_id: The stream step's run ID + signal_types: List of signal types to subscribe to + + Raises: + NotImplementedError: If backend doesn't support streams + """ + raise NotImplementedError("This storage backend does not support streams") + + async def get_waiting_steps( + self, + stream_id: str, + signal_type: str, + ) -> list[dict]: + """ + Get step_run_ids waiting for a specific signal type on a stream. + + Args: + stream_id: Stream identifier + signal_type: Signal type to match + + Returns: + List of dicts with step subscription info + + Raises: + NotImplementedError: If backend doesn't support streams + """ + raise NotImplementedError("This storage backend does not support streams") + + async def update_subscription_status( + self, + stream_id: str, + step_run_id: str, + status: str, + ) -> None: + """ + Update a subscription's status. + + Args: + stream_id: Stream identifier + step_run_id: The stream step's run ID + status: New status ("waiting", "running", "completed") + + Raises: + NotImplementedError: If backend doesn't support streams + """ + raise NotImplementedError("This storage backend does not support streams") + + async def acknowledge_signal( + self, + signal_id: str, + step_run_id: str, + ) -> None: + """ + Acknowledge that a signal has been processed by a step. 
+ + Args: + signal_id: Signal identifier + step_run_id: The step that processed the signal + + Raises: + NotImplementedError: If backend doesn't support streams + """ + raise NotImplementedError("This storage backend does not support streams") + + async def get_pending_signals( + self, + stream_id: str, + step_run_id: str, + ) -> list[dict]: + """ + Get signals that arrived for a step but haven't been acknowledged. + + Args: + stream_id: Stream identifier + step_run_id: The stream step's run ID + + Returns: + List of unacknowledged signal dicts + + Raises: + NotImplementedError: If backend doesn't support streams + """ + raise NotImplementedError("This storage backend does not support streams") + + # Checkpoint Operations + + async def save_checkpoint(self, step_run_id: str, data: dict) -> None: + """ + Save checkpoint data for a stream step. + + Args: + step_run_id: The stream step's run ID + data: Checkpoint data to persist + + Raises: + NotImplementedError: If backend doesn't support streams + """ + raise NotImplementedError("This storage backend does not support checkpoints") + + async def load_checkpoint(self, step_run_id: str) -> dict | None: + """ + Load checkpoint data for a stream step. + + Args: + step_run_id: The stream step's run ID + + Returns: + Checkpoint data dict if found, None otherwise + + Raises: + NotImplementedError: If backend doesn't support streams + """ + raise NotImplementedError("This storage backend does not support checkpoints") + + async def delete_checkpoint(self, step_run_id: str) -> None: + """ + Delete checkpoint data for a stream step. + + Args: + step_run_id: The stream step's run ID + + Raises: + NotImplementedError: If backend doesn't support streams + """ + raise NotImplementedError("This storage backend does not support checkpoints") + async def health_check(self) -> bool: """ Check if storage backend is healthy and accessible. 
diff --git a/pyworkflow/storage/memory.py b/pyworkflow/storage/memory.py index 615f974..2ca0d90 100644 --- a/pyworkflow/storage/memory.py +++ b/pyworkflow/storage/memory.py @@ -51,6 +51,14 @@ def __init__(self) -> None: self._lock = threading.RLock() self._event_sequences: dict[str, int] = {} # run_id -> next sequence + # Stream storage + self._streams: dict[str, dict] = {} # stream_id -> stream data + self._signals: dict[str, list[dict]] = {} # stream_id -> list of signal dicts + self._signal_sequences: dict[str, int] = {} # stream_id -> next sequence + self._subscriptions: dict[tuple[str, str], dict] = {} # (stream_id, step_run_id) -> sub + self._acknowledgments: set[tuple[str, str]] = set() # (signal_id, step_run_id) + self._checkpoints: dict[str, dict] = {} # step_run_id -> checkpoint data + # Workflow Run Operations async def create_run(self, run: WorkflowRun) -> None: @@ -629,6 +637,158 @@ async def delete_old_runs(self, older_than: datetime) -> int: self._cancellation_flags.pop(run_id, None) return len(to_delete) + # Stream Operations + + async def create_stream(self, stream_id: str, metadata: dict | None = None) -> None: + """Create a new stream.""" + with self._lock: + if stream_id in self._streams: + raise ValueError(f"Stream {stream_id} already exists") + self._streams[stream_id] = { + "stream_id": stream_id, + "status": "active", + "created_at": datetime.now(UTC).isoformat(), + "metadata": metadata or {}, + } + self._signals[stream_id] = [] + self._signal_sequences[stream_id] = 0 + + async def get_stream(self, stream_id: str) -> dict | None: + """Get a stream by ID.""" + with self._lock: + return self._streams.get(stream_id) + + async def publish_signal( + self, + signal_id: str, + stream_id: str, + signal_type: str, + payload: dict, + source_run_id: str | None = None, + metadata: dict | None = None, + ) -> int: + """Publish a signal to a stream.""" + with self._lock: + if stream_id not in self._signals: + self._signals[stream_id] = [] + 
self._signal_sequences[stream_id] = 0 + + seq = self._signal_sequences[stream_id] + self._signal_sequences[stream_id] += 1 + + signal_data = { + "signal_id": signal_id, + "stream_id": stream_id, + "signal_type": signal_type, + "payload": payload, + "published_at": datetime.now(UTC).isoformat(), + "sequence": seq, + "source_run_id": source_run_id, + "metadata": metadata or {}, + } + self._signals[stream_id].append(signal_data) + return seq + + async def get_signals( + self, + stream_id: str, + after_sequence: int = 0, + limit: int = 100, + ) -> list[dict]: + """Get signals from a stream after a given sequence number.""" + with self._lock: + signals = self._signals.get(stream_id, []) + filtered = [s for s in signals if (s.get("sequence", 0)) >= after_sequence] + filtered.sort(key=lambda s: s.get("sequence", 0)) + return filtered[:limit] + + async def register_stream_subscription( + self, + stream_id: str, + step_run_id: str, + signal_types: list[str], + ) -> None: + """Register a stream step's subscription to signal types.""" + with self._lock: + key = (stream_id, step_run_id) + self._subscriptions[key] = { + "stream_id": stream_id, + "step_run_id": step_run_id, + "signal_types": signal_types, + "status": "waiting", + "created_at": datetime.now(UTC).isoformat(), + } + + async def get_waiting_steps( + self, + stream_id: str, + signal_type: str, + ) -> list[dict]: + """Get step_run_ids waiting for a specific signal type on a stream.""" + with self._lock: + result = [] + for (sid, _), sub in self._subscriptions.items(): + if sid == stream_id and sub["status"] == "waiting": + if signal_type in sub["signal_types"]: + result.append(sub) + return result + + async def update_subscription_status( + self, + stream_id: str, + step_run_id: str, + status: str, + ) -> None: + """Update a subscription's status.""" + with self._lock: + key = (stream_id, step_run_id) + if key in self._subscriptions: + self._subscriptions[key]["status"] = status + + async def acknowledge_signal( + 
self, + signal_id: str, + step_run_id: str, + ) -> None: + """Acknowledge that a signal has been processed by a step.""" + with self._lock: + self._acknowledgments.add((signal_id, step_run_id)) + + async def get_pending_signals( + self, + stream_id: str, + step_run_id: str, + ) -> list[dict]: + """Get signals that arrived for a step but haven't been acknowledged.""" + with self._lock: + sub = self._subscriptions.get((stream_id, step_run_id)) + if not sub: + return [] + + signals = self._signals.get(stream_id, []) + pending = [] + for sig in signals: + if sig["signal_type"] in sub["signal_types"]: + if (sig["signal_id"], step_run_id) not in self._acknowledgments: + pending.append(sig) + pending.sort(key=lambda s: s.get("sequence", 0)) + return pending + + async def save_checkpoint(self, step_run_id: str, data: dict) -> None: + """Save checkpoint data for a stream step.""" + with self._lock: + self._checkpoints[step_run_id] = data + + async def load_checkpoint(self, step_run_id: str) -> dict | None: + """Load checkpoint data for a stream step.""" + with self._lock: + return self._checkpoints.get(step_run_id) + + async def delete_checkpoint(self, step_run_id: str) -> None: + """Delete checkpoint data for a stream step.""" + with self._lock: + self._checkpoints.pop(step_run_id, None) + # Utility methods def clear(self) -> None: @@ -647,6 +807,12 @@ def clear(self) -> None: self._token_index.clear() self._cancellation_flags.clear() self._event_sequences.clear() + self._streams.clear() + self._signals.clear() + self._signal_sequences.clear() + self._subscriptions.clear() + self._acknowledgments.clear() + self._checkpoints.clear() def __len__(self) -> int: """Return total number of workflow runs.""" diff --git a/pyworkflow/streams/__init__.py b/pyworkflow/streams/__init__.py new file mode 100644 index 0000000..e9a3720 --- /dev/null +++ b/pyworkflow/streams/__init__.py @@ -0,0 +1,83 @@ +""" +Streams module — pub/sub signal pattern for PyWorkflow. 
+ +Provides event-driven stream workflows where steps react to signals, +can suspend and be resumed by new signals, and emit signals back. + +Public API: + - stream_workflow: Decorator to define a named stream + - stream_step: Decorator to define a reactive stream step + - emit: Publish a signal to a stream + - Signal: Signal dataclass + - Stream: Stream dataclass + - StreamStepContext: Context for on_signal callbacks + - CheckpointBackend: ABC for checkpoint storage + - get_current_signal: Get the signal that triggered current resume + - get_checkpoint: Load saved checkpoint data + - save_checkpoint: Save checkpoint data +""" + +from pyworkflow.streams.checkpoint import ( # noqa: F401 + CheckpointBackend, + DefaultCheckpointBackend, + RedisCheckpointBackend, + configure_checkpoint_backend, + get_checkpoint_backend, + register_checkpoint_backend, + reset_checkpoint_backend, +) +from pyworkflow.streams.consumer import StreamConsumer, poll_once # noqa: F401 +from pyworkflow.streams.context import ( # noqa: F401 + get_checkpoint, + get_current_signal, + save_checkpoint, +) +from pyworkflow.streams.decorator import stream_step, stream_workflow # noqa: F401 +from pyworkflow.streams.emit import emit # noqa: F401 +from pyworkflow.streams.registry import ( # noqa: F401 + StreamMetadata, + StreamStepMetadata, + clear_stream_registry, + get_steps_for_stream, + get_stream, + get_stream_step, + list_stream_steps, + list_streams, +) +from pyworkflow.streams.signal import Signal, Stream # noqa: F401 +from pyworkflow.streams.step_context import StreamStepContext # noqa: F401 + +__all__ = [ + # Decorators + "stream_workflow", + "stream_step", + # Core + "emit", + "Signal", + "Stream", + "StreamStepContext", + # Context primitives + "get_current_signal", + "get_checkpoint", + "save_checkpoint", + # Checkpoint + "CheckpointBackend", + "DefaultCheckpointBackend", + "RedisCheckpointBackend", + "configure_checkpoint_backend", + "get_checkpoint_backend", + "register_checkpoint_backend", 
+ "reset_checkpoint_backend", + # Registry + "StreamMetadata", + "StreamStepMetadata", + "get_stream", + "get_stream_step", + "get_steps_for_stream", + "list_streams", + "list_stream_steps", + "clear_stream_registry", + # Consumer + "StreamConsumer", + "poll_once", +] diff --git a/pyworkflow/streams/checkpoint.py b/pyworkflow/streams/checkpoint.py new file mode 100644 index 0000000..5baaaae --- /dev/null +++ b/pyworkflow/streams/checkpoint.py @@ -0,0 +1,201 @@ +""" +CheckpointBackend ABC and built-in implementations. + +Checkpoints allow stream steps to persist custom state across +suspend/resume cycles. +""" + +from abc import ABC, abstractmethod +from typing import Any + +from loguru import logger + + +class CheckpointBackend(ABC): + """Abstract base class for checkpoint storage backends.""" + + @abstractmethod + async def save(self, step_run_id: str, data: dict) -> None: + """Save checkpoint data for a stream step.""" + ... + + @abstractmethod + async def load(self, step_run_id: str) -> dict | None: + """Load checkpoint data for a stream step.""" + ... + + @abstractmethod + async def delete(self, step_run_id: str) -> None: + """Delete checkpoint data for a stream step.""" + ... + + +class DefaultCheckpointBackend(CheckpointBackend): + """ + Checkpoint backend that uses PyWorkflow's configured StorageBackend. + + This is the default — checkpoint data is stored alongside workflow data. + """ + + def __init__(self, storage: Any = None) -> None: + self._storage = storage + + def _get_storage(self) -> Any: + if self._storage is not None: + return self._storage + from pyworkflow.config import get_config + + config = get_config() + if config.storage is None: + raise RuntimeError( + "No storage backend configured. " "Call pyworkflow.configure(storage=...) first." 
+ ) + return config.storage + + async def save(self, step_run_id: str, data: dict) -> None: + """Save checkpoint using PyWorkflow storage.""" + storage = self._get_storage() + await storage.save_checkpoint(step_run_id, data) + logger.debug(f"Checkpoint saved for {step_run_id}") + + async def load(self, step_run_id: str) -> dict | None: + """Load checkpoint from PyWorkflow storage.""" + storage = self._get_storage() + data = await storage.load_checkpoint(step_run_id) + if data is not None: + logger.debug(f"Checkpoint loaded for {step_run_id}") + return data + + async def delete(self, step_run_id: str) -> None: + """Delete checkpoint from PyWorkflow storage.""" + storage = self._get_storage() + await storage.delete_checkpoint(step_run_id) + logger.debug(f"Checkpoint deleted for {step_run_id}") + + +class RedisCheckpointBackend(CheckpointBackend): + """ + Checkpoint backend using Redis. + + Configured via pyworkflow.configure() or environment variables: + - PYWORKFLOW_CHECKPOINT_BACKEND_URL: Redis connection URL + """ + + def __init__(self, url: str = "redis://localhost:6379/1") -> None: + self._url = url + self._redis: Any = None + + async def _get_redis(self) -> Any: + if self._redis is None: + try: + import redis.asyncio as aioredis + + self._redis = aioredis.from_url(self._url) + except ImportError: + raise RuntimeError( + "redis package required for RedisCheckpointBackend. 
" + "Install with: pip install redis" + ) + return self._redis + + async def save(self, step_run_id: str, data: dict) -> None: + """Save checkpoint to Redis.""" + import json + + r = await self._get_redis() + key = f"pyworkflow:checkpoint:{step_run_id}" + await r.set(key, json.dumps(data)) + logger.debug(f"Checkpoint saved to Redis for {step_run_id}") + + async def load(self, step_run_id: str) -> dict | None: + """Load checkpoint from Redis.""" + import json + + r = await self._get_redis() + key = f"pyworkflow:checkpoint:{step_run_id}" + raw = await r.get(key) + if raw is None: + return None + logger.debug(f"Checkpoint loaded from Redis for {step_run_id}") + return json.loads(raw) + + async def delete(self, step_run_id: str) -> None: + """Delete checkpoint from Redis.""" + r = await self._get_redis() + key = f"pyworkflow:checkpoint:{step_run_id}" + await r.delete(key) + logger.debug(f"Checkpoint deleted from Redis for {step_run_id}") + + +# Backend registry +_checkpoint_backends: dict[str, type[CheckpointBackend]] = { + "default": DefaultCheckpointBackend, + "redis": RedisCheckpointBackend, +} + +# Active backend instance (lazy-initialized) +_active_backend: CheckpointBackend | None = None +_configured_backend_name: str = "default" +_configured_backend_url: str | None = None + + +def register_checkpoint_backend(name: str, backend_class: type[CheckpointBackend]) -> None: + """Register a custom checkpoint backend.""" + _checkpoint_backends[name] = backend_class + + +def configure_checkpoint_backend( + backend: str = "default", + url: str | None = None, +) -> None: + """ + Configure the checkpoint backend. 
+ + Args: + backend: Backend name ("default", "redis", or custom registered name) + url: Connection URL (for backends that need it) + """ + global _active_backend, _configured_backend_name, _configured_backend_url + _configured_backend_name = backend + _configured_backend_url = url + _active_backend = None # Reset to force re-initialization + + +def get_checkpoint_backend(storage: Any = None) -> CheckpointBackend: + """Get the configured checkpoint backend, creating it if needed.""" + global _active_backend + + if _active_backend is not None: + return _active_backend + + # Check environment variables + import os + + backend_name = os.environ.get("PYWORKFLOW_CHECKPOINT_BACKEND", _configured_backend_name) + backend_url = os.environ.get("PYWORKFLOW_CHECKPOINT_BACKEND_URL", _configured_backend_url) + + backend_class = _checkpoint_backends.get(backend_name) + if backend_class is None: + raise ValueError( + f"Unknown checkpoint backend: {backend_name}. " + f"Available: {list(_checkpoint_backends.keys())}" + ) + + if backend_name == "default": + _active_backend = DefaultCheckpointBackend(storage=storage) + elif backend_name == "redis": + _active_backend = RedisCheckpointBackend(url=backend_url or "redis://localhost:6379/1") + elif backend_url: + _active_backend = backend_class(url=backend_url) # type: ignore[call-arg] + else: + _active_backend = backend_class() + + return _active_backend + + +def reset_checkpoint_backend() -> None: + """Reset the checkpoint backend (for testing).""" + global _active_backend, _configured_backend_name, _configured_backend_url + _active_backend = None + _configured_backend_name = "default" + _configured_backend_url = None diff --git a/pyworkflow/streams/consumer.py b/pyworkflow/streams/consumer.py new file mode 100644 index 0000000..56d1356 --- /dev/null +++ b/pyworkflow/streams/consumer.py @@ -0,0 +1,124 @@ +""" +Background signal delivery consumer. + +Polls for undelivered signals and dispatches them to subscribed stream steps. 
class StreamConsumer:
    """
    Background consumer that polls for pending signals and dispatches them.

    Ensures signals are delivered even if they arrive while a step is
    being processed (missed during the synchronous dispatch in emit()).
    """

    def __init__(
        self,
        storage: Any,
        poll_interval: float = 1.0,
    ) -> None:
        """
        Initialize the stream consumer.

        Args:
            storage: Storage backend
            poll_interval: Seconds between poll cycles
        """
        self._storage = storage
        self._poll_interval = poll_interval
        self._running = False
        self._task: asyncio.Task | None = None

    async def start(self) -> None:
        """Start the consumer loop as a background task; no-op if already running."""
        if self._running:
            return
        self._running = True
        self._task = asyncio.create_task(self._poll_loop())
        logger.info("Stream consumer started", poll_interval=self._poll_interval)

    async def stop(self) -> None:
        """Stop the consumer loop and wait for the background task to finish."""
        self._running = False
        if self._task:
            self._task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self._task
            self._task = None
        logger.info("Stream consumer stopped")

    async def _poll_loop(self) -> None:
        """Run poll cycles until stopped; a failing cycle is logged, not fatal."""
        while self._running:
            try:
                await self._process_pending_signals()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Stream consumer error: {e}")

            await asyncio.sleep(self._poll_interval)

    async def _process_pending_signals(self) -> int:
        """
        Dispatch every pending signal of every waiting subscription.

        Returns:
            Number of signals handed to dispatch_signal() in this cycle
        """
        from pyworkflow.streams.dispatcher import dispatch_signal
        from pyworkflow.streams.signal import Signal

        # NOTE(review): streams and subscriptions are discovered through the
        # private _streams / _subscriptions attributes. A storage object
        # without them yields nothing, so this loop silently does no work for
        # such backends. A public listing API on the storage interface would
        # be needed to remove that limitation.
        streams = getattr(self._storage, "_streams", {})
        dispatched = 0

        for stream_id in list(streams.keys()):
            # Snapshot the waiting subscriptions before dispatching anything
            # for this stream.
            subscriptions = [
                sub
                for (sid, _), sub in list(getattr(self._storage, "_subscriptions", {}).items())
                if sid == stream_id and sub["status"] == "waiting"
            ]

            for sub in subscriptions:
                pending = await self._storage.get_pending_signals(stream_id, sub["step_run_id"])

                for sig_data in pending:
                    signal = Signal(
                        signal_id=sig_data["signal_id"],
                        stream_id=sig_data["stream_id"],
                        signal_type=sig_data["signal_type"],
                        payload=sig_data["payload"],
                        sequence=sig_data.get("sequence"),
                        source_run_id=sig_data.get("source_run_id"),
                        metadata=sig_data.get("metadata", {}),
                    )
                    await dispatch_signal(signal, self._storage)
                    dispatched += 1

        return dispatched

    @property
    def is_running(self) -> bool:
        """Whether the consumer is currently running."""
        return self._running


async def poll_once(storage: Any) -> int:
    """
    Run a single poll cycle. Useful for testing or one-shot processing.

    Args:
        storage: Storage backend

    Returns:
        Number of signals processed
    """
    consumer = StreamConsumer(storage)
    return await consumer._process_pending_signals()
# Context variables for stream step execution
_current_signal: ContextVar[Signal | None] = ContextVar("_current_signal", default=None)
_current_step_run_id: ContextVar[str | None] = ContextVar("_current_step_run_id", default=None)
_current_stream_id: ContextVar[str | None] = ContextVar("_current_stream_id", default=None)
_current_storage: ContextVar[Any] = ContextVar("_current_storage", default=None)


def set_stream_step_context(
    step_run_id: str,
    stream_id: str,
    signal: Signal | None = None,
    storage: Any = None,
) -> tuple:
    """
    Bind the stream step context variables for the current execution.

    Returns:
        The four reset tokens, in the order (signal, step_run_id, stream_id,
        storage), to be handed back to reset_stream_step_context().
    """
    signal_token = _current_signal.set(signal)
    step_token = _current_step_run_id.set(step_run_id)
    stream_token = _current_stream_id.set(stream_id)
    storage_token = _current_storage.set(storage)
    return (signal_token, step_token, stream_token, storage_token)


def reset_stream_step_context(tokens: tuple) -> None:
    """Restore the context variables using the tokens from set_stream_step_context()."""
    signal_token, step_token, stream_token, storage_token = tokens
    _current_signal.reset(signal_token)
    _current_step_run_id.reset(step_token)
    _current_stream_id.reset(stream_token)
    _current_storage.reset(storage_token)


async def get_current_signal() -> Signal | None:
    """
    Return the signal bound to the current stream step context.

    This is None when no signal has been bound, which is the variable's
    default.
    """
    return _current_signal.get()


async def get_checkpoint() -> dict | None:
    """
    Load the checkpoint saved for the current stream step.

    Returns:
        Whatever the checkpoint backend's load() returns for the current
        step run ID.

    Raises:
        RuntimeError: If no stream step context is active
    """
    if (step_run_id := _current_step_run_id.get()) is None:
        raise RuntimeError(
            "get_checkpoint() must be called within a stream step lifecycle function."
        )
    backend = get_checkpoint_backend(storage=_current_storage.get())
    return await backend.load(step_run_id)


async def save_checkpoint(data: dict) -> None:
    """
    Persist checkpoint data for the current stream step.

    The data is passed to the checkpoint backend's save() under the current
    step run ID.

    Args:
        data: Dictionary of checkpoint data to persist

    Raises:
        RuntimeError: If no stream step context is active
    """
    if (step_run_id := _current_step_run_id.get()) is None:
        raise RuntimeError(
            "save_checkpoint() must be called within a stream step lifecycle function."
        )
    backend = get_checkpoint_backend(storage=_current_storage.get())
    await backend.save(step_run_id, data)
+ + Args: + name: Stream name (defaults to function name) + **metadata: Additional stream metadata + + Returns: + Decorated function with stream metadata + + Examples: + @stream_workflow(name="agent_comms") + async def agent_communication(): + pass + """ + + def decorator(func: Callable) -> Callable: + stream_name = name or func.__name__ + + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + return await func(*args, **kwargs) + + register_stream( + name=stream_name, + func=wrapper, + original_func=func, + metadata=metadata, + ) + + wrapper.__stream_workflow__ = True # type: ignore[attr-defined] + wrapper.__stream_name__ = stream_name # type: ignore[attr-defined] + + return wrapper + + return decorator + + +def stream_step( + stream: str, + signals: list[str] | dict[str, type[BaseModel]], + on_signal: Callable[..., Any], + name: str | None = None, +) -> Callable: + """ + Decorator to define a stream step (long-lived reactive unit). + + A stream step has two code paths: + 1. The on_signal callback: runs on every matching signal arrival + 2. The lifecycle function (decorated): runs on start and each explicit resume + + Args: + stream: Name of the stream to subscribe to + signals: Signal types to subscribe to. 
Either: + - list[str]: Signal type names + - dict[str, BaseModel]: Signal type -> Pydantic schema mapping + on_signal: Async callback for signal processing + name: Step name (defaults to function name) + + Returns: + Decorated function with stream step metadata + + Examples: + async def handle_signal(signal, ctx): + if signal.signal_type == "task.created": + await ctx.resume() + + @stream_step( + stream="agent_comms", + signals=["task.created", "task.updated"], + on_signal=handle_signal, + ) + async def task_planner(): + signal = await get_current_signal() + if signal: + await process(signal) + """ + + # Parse signal types and schemas + if isinstance(signals, dict): + signal_types = list(signals.keys()) + signal_schemas = signals + else: + signal_types = list(signals) + signal_schemas = {} + + def decorator(func: Callable) -> Callable: + step_name = name or func.__name__ + + @functools.wraps(func) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + return await func(*args, **kwargs) + + register_stream_step( + name=step_name, + func=wrapper, + original_func=func, + stream=stream, + signal_types=signal_types, + on_signal=on_signal, + signal_schemas=signal_schemas, + ) + + wrapper.__stream_step__ = True # type: ignore[attr-defined] + wrapper.__stream_step_name__ = step_name # type: ignore[attr-defined] + wrapper.__stream_name__ = stream # type: ignore[attr-defined] + wrapper.__signal_types__ = signal_types # type: ignore[attr-defined] + + return wrapper + + return decorator diff --git a/pyworkflow/streams/dispatcher.py b/pyworkflow/streams/dispatcher.py new file mode 100644 index 0000000..8379cd7 --- /dev/null +++ b/pyworkflow/streams/dispatcher.py @@ -0,0 +1,233 @@ +""" +Signal dispatcher for matching signals to subscribed stream steps. + +When a signal is published, the dispatcher: +1. Finds all stream steps subscribed to the signal type +2. Invokes their on_signal callbacks +3. 
If ctx.resume() was called, triggers workflow resume +""" + +from typing import Any + +from loguru import logger +from pydantic import ValidationError + +from pyworkflow.streams.registry import list_stream_steps +from pyworkflow.streams.signal import Signal +from pyworkflow.streams.step_context import StreamStepContext + + +async def dispatch_signal(signal: Signal, storage: Any) -> None: + """ + Dispatch a signal to all subscribed stream steps. + + 1. Find steps subscribed to this signal_type on this stream + 2. Validate payload against schema if defined + 3. Invoke on_signal callback for each + 4. If ctx.resume() was called, trigger workflow resume + + Args: + signal: The published signal + storage: Storage backend + """ + # Find waiting steps from storage + waiting_steps = await storage.get_waiting_steps(signal.stream_id, signal.signal_type) + + if not waiting_steps: + logger.debug( + f"No waiting steps for {signal.signal_type} on {signal.stream_id}", + stream_id=signal.stream_id, + signal_type=signal.signal_type, + ) + return + + for step_info in waiting_steps: + step_run_id = step_info["step_run_id"] + + try: + await _dispatch_to_step(signal, step_run_id, storage) + except Exception as e: + logger.error( + f"Error dispatching signal to step {step_run_id}: {e}", + signal_id=signal.signal_id, + step_run_id=step_run_id, + ) + + +async def _dispatch_to_step( + signal: Signal, + step_run_id: str, + storage: Any, +) -> None: + """ + Dispatch a signal to a specific stream step. 
+ + Args: + signal: The signal to dispatch + step_run_id: The target step's run ID + storage: Storage backend + """ + # Find the stream step metadata by looking up the step name + # The step_run_id format encodes the step name + step_meta = _find_step_metadata_for_run(step_run_id, signal.stream_id) + + if step_meta is None: + logger.warning( + f"No registered stream step found for run {step_run_id}", + step_run_id=step_run_id, + ) + # Still acknowledge the signal even without metadata + await storage.acknowledge_signal(signal.signal_id, step_run_id) + return + + # Validate payload against schema if defined + validated_signal = signal + if signal.signal_type in step_meta.signal_schemas: + schema_class = step_meta.signal_schemas[signal.signal_type] + try: + validated_payload = schema_class.model_validate(signal.payload) + validated_signal = Signal( + signal_id=signal.signal_id, + stream_id=signal.stream_id, + signal_type=signal.signal_type, + payload=validated_payload, + published_at=signal.published_at, + sequence=signal.sequence, + source_run_id=signal.source_run_id, + metadata=signal.metadata, + ) + except ValidationError as e: + logger.error( + f"Signal payload validation failed for {signal.signal_type}: {e}", + signal_id=signal.signal_id, + signal_type=signal.signal_type, + ) + await storage.acknowledge_signal(signal.signal_id, step_run_id) + return + + # Create StreamStepContext for the callback + ctx = StreamStepContext( + status="suspended", + run_id=step_run_id, + stream_id=signal.stream_id, + storage=storage, + ) + + # Invoke the on_signal callback + try: + await step_meta.on_signal(validated_signal, ctx) + except Exception as e: + logger.error( + f"on_signal callback error for step {step_run_id}: {e}", + signal_id=signal.signal_id, + step_run_id=step_run_id, + ) + + # Acknowledge signal + await storage.acknowledge_signal(signal.signal_id, step_run_id) + + # Record SIGNAL_RECEIVED event + try: + from pyworkflow.engine.events import 
create_signal_received_event + from pyworkflow.serialization.encoder import serialize + + event = create_signal_received_event( + run_id=step_run_id, + signal_id=signal.signal_id, + stream_id=signal.stream_id, + signal_type=signal.signal_type, + payload=serialize(signal.payload), + ) + await storage.record_event(event) + except Exception: + pass # Best-effort event logging + + # Handle cancellation + if ctx.is_cancelled: + await _cancel_step(step_run_id, ctx.cancel_reason, storage) + return + + # If resume was requested, trigger workflow resume + if ctx.should_resume: + await _resume_step(step_run_id, validated_signal, storage) + + +async def _resume_step( + step_run_id: str, + signal: Signal, + storage: Any, +) -> None: + """ + Resume a stream step's lifecycle after on_signal called ctx.resume(). + + This follows the same pattern as resume_hook(). + """ + logger.info( + f"Resuming stream step {step_run_id}", + step_run_id=step_run_id, + signal_type=signal.signal_type, + ) + + # Update subscription status to running + await storage.update_subscription_status(signal.stream_id, step_run_id, "running") + + # Schedule workflow resumption via configured runtime + try: + from pyworkflow.config import get_config + from pyworkflow.runtime import get_runtime + + config = get_config() + runtime = get_runtime(config.default_runtime) + await runtime.schedule_resume(step_run_id, storage) + except Exception as e: + logger.warning( + f"Failed to schedule stream step resumption: {e}", + step_run_id=step_run_id, + ) + + +async def _cancel_step( + step_run_id: str, + reason: str | None, + storage: Any, +) -> None: + """Cancel a stream step.""" + logger.info( + f"Cancelling stream step {step_run_id}", + step_run_id=step_run_id, + reason=reason, + ) + + try: + from pyworkflow.engine.events import create_stream_step_completed_event + from pyworkflow.storage.schemas import RunStatus + + # Record completion event + event = create_stream_step_completed_event( + run_id=step_run_id, + 
stream_id="", # Will be filled from context + step_name="", + reason=f"cancelled: {reason}" if reason else "cancelled", + ) + await storage.record_event(event) + + # Update run status + await storage.update_run_status(step_run_id, RunStatus.CANCELLED) + except Exception as e: + logger.error(f"Error cancelling stream step: {e}", step_run_id=step_run_id) + + +def _find_step_metadata_for_run(step_run_id: str, stream_id: str) -> Any: + """Find registered stream step metadata for a given run and stream.""" + all_steps = list_stream_steps() + for step_meta in all_steps.values(): + if step_meta.stream == stream_id: + # Check if the step_run_id matches this step's pattern + # step_run_ids encode the step name: "stream_step_{step_name}_{uuid}" + if step_meta.name in step_run_id: + return step_meta + # Fallback: return the first step on this stream that subscribes to this signal + for step_meta in all_steps.values(): + if step_meta.stream == stream_id: + return step_meta + return None diff --git a/pyworkflow/streams/emit.py b/pyworkflow/streams/emit.py new file mode 100644 index 0000000..5618136 --- /dev/null +++ b/pyworkflow/streams/emit.py @@ -0,0 +1,157 @@ +""" +emit() function for publishing signals to streams. + +Signals can be emitted from workflows, steps, or externally. +""" + +import uuid +from typing import Any + +from loguru import logger +from pydantic import BaseModel + +from pyworkflow.streams.signal import Signal + + +async def emit( + stream_id: str, + signal_type: str, + payload: Any = None, + *, + storage: Any = None, + metadata: dict[str, Any] | None = None, +) -> Signal: + """ + Publish a signal to a stream. 
+ + Can be called from: + - Workflow code (auto-detects storage from context) + - Step code (auto-detects storage from context) + - External code (requires explicit storage parameter) + + Args: + stream_id: Target stream identifier + signal_type: Signal type (e.g., "task.created") + payload: Signal payload data + storage: Storage backend (uses configured default if not provided) + metadata: Optional signal metadata + + Returns: + The published Signal with assigned sequence number + + Examples: + # From workflow code + await emit("agent_comms", "task.created", {"task_id": "t1"}) + + # Externally with explicit storage + await emit("agent_comms", "task.created", payload, storage=my_storage) + """ + # Resolve storage + if storage is None: + storage = _resolve_storage() + + if storage is None: + raise RuntimeError( + "No storage backend available. " + "Either pass storage parameter or call pyworkflow.configure(storage=...)" + ) + + # Validate payload against schema if Pydantic model + if isinstance(payload, BaseModel): + payload_data = payload.model_dump() + elif payload is None: + payload_data = {} + else: + payload_data = payload + + # Resolve source run_id from context if available + source_run_id = _get_source_run_id() + + # Create signal + signal_id = f"sig_{uuid.uuid4().hex[:16]}" + + # Publish to storage + sequence = await storage.publish_signal( + signal_id=signal_id, + stream_id=stream_id, + signal_type=signal_type, + payload=payload_data, + source_run_id=source_run_id, + metadata=metadata, + ) + + signal = Signal( + signal_id=signal_id, + stream_id=stream_id, + signal_type=signal_type, + payload=payload_data, + sequence=sequence, + source_run_id=source_run_id, + metadata=metadata or {}, + ) + + logger.info( + f"Signal published: {signal_type} to {stream_id}", + signal_id=signal_id, + stream_id=stream_id, + signal_type=signal_type, + sequence=sequence, + ) + + # Record SIGNAL_PUBLISHED event in caller's workflow log (if in workflow context) + if source_run_id: 
+ try: + from pyworkflow.engine.events import create_signal_published_event + + event = create_signal_published_event( + run_id=source_run_id, + signal_id=signal_id, + stream_id=stream_id, + signal_type=signal_type, + ) + await storage.record_event(event) + except Exception: + pass # Non-critical: event logging is best-effort from caller context + + # Dispatch signal to waiting steps + from pyworkflow.streams.dispatcher import dispatch_signal + + await dispatch_signal(signal, storage) + + return signal + + +def _resolve_storage() -> Any: + """Try to resolve storage from workflow context or global config.""" + # Try workflow context first + try: + from pyworkflow.context import get_context, has_context + + if has_context(): + ctx = get_context() + if hasattr(ctx, "_storage") and ctx._storage is not None: + return ctx._storage + except Exception: + pass + + # Fall back to global config + try: + from pyworkflow.config import get_config + + config = get_config() + return config.storage + except Exception: + return None + + +def _get_source_run_id() -> str | None: + """Try to get the current workflow run_id from context.""" + try: + from pyworkflow.context import get_context, has_context + + if has_context(): + ctx = get_context() + return ctx.run_id + except Exception: + pass + return None diff --git a/pyworkflow/streams/registry.py b/pyworkflow/streams/registry.py new file mode 100644 index 0000000..4d00d26 --- /dev/null +++ b/pyworkflow/streams/registry.py @@ -0,0 +1,190 @@ +""" +Registry for streams and stream steps. + +Tracks all registered stream workflows and stream steps, enabling +lookup by name and subscription tracking. 
+""" + +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +from pydantic import BaseModel + + +@dataclass +class StreamMetadata: + """Metadata for a registered stream.""" + + name: str + func: Callable[..., Any] + original_func: Callable[..., Any] + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class StreamStepMetadata: + """Metadata for a registered stream step.""" + + name: str + func: Callable[..., Any] + original_func: Callable[..., Any] + stream: str + signal_types: list[str] + on_signal: Callable[..., Any] + signal_schemas: dict[str, type[BaseModel]] = field(default_factory=dict) + + +class StreamRegistry: + """ + Global registry for streams and stream steps. + + Tracks @stream_workflow and @stream_step decorated functions. + """ + + def __init__(self) -> None: + self._streams: dict[str, StreamMetadata] = {} + self._stream_steps: dict[str, StreamStepMetadata] = {} + self._steps_by_stream: dict[str, list[str]] = {} # stream_name -> [step_names] + + def register_stream( + self, + name: str, + func: Callable[..., Any], + original_func: Callable[..., Any], + metadata: dict[str, Any] | None = None, + ) -> None: + """Register a stream workflow.""" + if name in self._streams: + existing = self._streams[name] + if existing.original_func is not original_func: + raise ValueError(f"Stream name '{name}' already registered with different function") + return + + self._streams[name] = StreamMetadata( + name=name, + func=func, + original_func=original_func, + metadata=metadata or {}, + ) + + def register_stream_step( + self, + name: str, + func: Callable[..., Any], + original_func: Callable[..., Any], + stream: str, + signal_types: list[str], + on_signal: Callable[..., Any], + signal_schemas: dict[str, type[BaseModel]] | None = None, + ) -> None: + """Register a stream step.""" + if name in self._stream_steps: + existing = self._stream_steps[name] + if existing.original_func is not 
original_func: + raise ValueError( + f"Stream step name '{name}' already registered with different function" + ) + return + + self._stream_steps[name] = StreamStepMetadata( + name=name, + func=func, + original_func=original_func, + stream=stream, + signal_types=signal_types, + on_signal=on_signal, + signal_schemas=signal_schemas or {}, + ) + + # Track step-to-stream mapping + if stream not in self._steps_by_stream: + self._steps_by_stream[stream] = [] + if name not in self._steps_by_stream[stream]: + self._steps_by_stream[stream].append(name) + + def get_stream(self, name: str) -> StreamMetadata | None: + """Get stream metadata by name.""" + return self._streams.get(name) + + def get_stream_step(self, name: str) -> StreamStepMetadata | None: + """Get stream step metadata by name.""" + return self._stream_steps.get(name) + + def get_steps_for_stream(self, stream_name: str) -> list[StreamStepMetadata]: + """Get all stream steps registered for a given stream.""" + step_names = self._steps_by_stream.get(stream_name, []) + return [self._stream_steps[n] for n in step_names if n in self._stream_steps] + + def list_streams(self) -> dict[str, StreamMetadata]: + """Get all registered streams.""" + return self._streams.copy() + + def list_stream_steps(self) -> dict[str, StreamStepMetadata]: + """Get all registered stream steps.""" + return self._stream_steps.copy() + + def clear(self) -> None: + """Clear all registrations (useful for testing).""" + self._streams.clear() + self._stream_steps.clear() + self._steps_by_stream.clear() + + +# Global singleton registry +_stream_registry = StreamRegistry() + + +def register_stream( + name: str, + func: Callable[..., Any], + original_func: Callable[..., Any], + metadata: dict[str, Any] | None = None, +) -> None: + """Register a stream in the global registry.""" + _stream_registry.register_stream(name, func, original_func, metadata) + + +def register_stream_step( + name: str, + func: Callable[..., Any], + original_func: Callable[..., 
Any], + stream: str, + signal_types: list[str], + on_signal: Callable[..., Any], + signal_schemas: dict[str, type[BaseModel]] | None = None, +) -> None: + """Register a stream step in the global registry.""" + _stream_registry.register_stream_step( + name, func, original_func, stream, signal_types, on_signal, signal_schemas + ) + + +def get_stream(name: str) -> StreamMetadata | None: + """Get stream metadata from global registry.""" + return _stream_registry.get_stream(name) + + +def get_stream_step(name: str) -> StreamStepMetadata | None: + """Get stream step metadata from global registry.""" + return _stream_registry.get_stream_step(name) + + +def get_steps_for_stream(stream_name: str) -> list[StreamStepMetadata]: + """Get all stream steps for a stream from global registry.""" + return _stream_registry.get_steps_for_stream(stream_name) + + +def list_streams() -> dict[str, StreamMetadata]: + """List all streams in global registry.""" + return _stream_registry.list_streams() + + +def list_stream_steps() -> dict[str, StreamStepMetadata]: + """List all stream steps in global registry.""" + return _stream_registry.list_stream_steps() + + +def clear_stream_registry() -> None: + """Clear the global stream registry (for testing).""" + _stream_registry.clear() diff --git a/pyworkflow/streams/signal.py b/pyworkflow/streams/signal.py new file mode 100644 index 0000000..19ff408 --- /dev/null +++ b/pyworkflow/streams/signal.py @@ -0,0 +1,70 @@ +""" +Signal dataclass for typed messages published to streams. + +Signals are the message units in the stream pub/sub system. Each signal +has a type, payload, and belongs to a specific stream. +""" + +import uuid +from dataclasses import dataclass, field +from datetime import UTC, datetime +from typing import Any + + +@dataclass +class Signal: + """ + A typed message published to a stream. + + Signals are immutable records that flow through streams. Each signal + has a unique ID, belongs to a stream, and carries a typed payload. 
+ + Attributes: + signal_id: Unique identifier for this signal + stream_id: The stream this signal belongs to + signal_type: Type identifier (e.g., "task.created", "result.ready") + payload: Signal data (validated against schema if configured) + published_at: Timestamp when the signal was published + sequence: Ordering sequence within the stream (assigned by storage) + source_run_id: Optional run_id of the workflow that emitted this signal + metadata: Optional additional metadata + """ + + signal_id: str = field(default_factory=lambda: f"sig_{uuid.uuid4().hex[:16]}") + stream_id: str = "" + signal_type: str = "" + payload: Any = field(default_factory=dict) + published_at: datetime = field(default_factory=lambda: datetime.now(UTC)) + sequence: int | None = None # Assigned by storage layer + source_run_id: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + """Validate signal after initialization.""" + if not self.stream_id: + raise ValueError("Signal must have a stream_id") + if not self.signal_type: + raise ValueError("Signal must have a signal_type") + + +@dataclass +class Stream: + """ + A named, durable channel for signals. 
+ + Attributes: + stream_id: Unique identifier for this stream + status: Current stream status ("active", "paused", "closed") + created_at: Timestamp when the stream was created + metadata: Optional stream configuration and metadata + """ + + stream_id: str = "" + status: str = "active" + created_at: datetime = field(default_factory=lambda: datetime.now(UTC)) + metadata: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + """Validate stream after initialization.""" + if not self.stream_id: + raise ValueError("Stream must have a stream_id") diff --git a/pyworkflow/streams/step_context.py b/pyworkflow/streams/step_context.py new file mode 100644 index 0000000..5b86dac --- /dev/null +++ b/pyworkflow/streams/step_context.py @@ -0,0 +1,79 @@ +""" +StreamStepContext for on_signal callbacks. + +Provides status information and management APIs (resume, cancel) +for stream steps processing signals. +""" + +from loguru import logger + + +class StreamStepContext: + """ + Context provided to on_signal callbacks in stream steps. + + Provides: + - Step status information + - resume() to trigger lifecycle function re-execution + - cancel() to terminate the step + """ + + def __init__( + self, + status: str, + run_id: str, + stream_id: str, + storage: object | None = None, + ) -> None: + self.status = status + self.run_id = run_id + self.stream_id = stream_id + self._storage = storage + self._should_resume = False + self._cancelled = False + self._cancel_reason: str | None = None + + async def resume(self) -> None: + """ + Resume the step's lifecycle function. + + The signal that triggered this on_signal callback will be + available via get_current_signal() in the lifecycle function. + """ + self._should_resume = True + logger.debug( + "Stream step resume requested", + run_id=self.run_id, + stream_id=self.stream_id, + ) + + async def cancel(self, reason: str | None = None) -> None: + """ + Cancel the stream step. 
+ + Args: + reason: Optional reason for cancellation + """ + self._cancelled = True + self._cancel_reason = reason + logger.info( + "Stream step cancel requested", + run_id=self.run_id, + stream_id=self.stream_id, + reason=reason, + ) + + @property + def should_resume(self) -> bool: + """Whether resume() was called during on_signal processing.""" + return self._should_resume + + @property + def is_cancelled(self) -> bool: + """Whether cancel() was called during on_signal processing.""" + return self._cancelled + + @property + def cancel_reason(self) -> str | None: + """Reason for cancellation, if any.""" + return self._cancel_reason diff --git a/tests/integration/test_stream_e2e.py b/tests/integration/test_stream_e2e.py new file mode 100644 index 0000000..69f82a8 --- /dev/null +++ b/tests/integration/test_stream_e2e.py @@ -0,0 +1,308 @@ +"""End-to-end tests for stream step lifecycle.""" + +import pytest + +from pyworkflow.storage.memory import InMemoryStorageBackend +from pyworkflow.streams.checkpoint import reset_checkpoint_backend +from pyworkflow.streams.context import ( + get_checkpoint, + get_current_signal, + reset_stream_step_context, + save_checkpoint, + set_stream_step_context, +) +from pyworkflow.streams.decorator import stream_step +from pyworkflow.streams.emit import emit +from pyworkflow.streams.registry import clear_stream_registry +from pyworkflow.streams.signal import Signal +from pyworkflow.streams.step_context import StreamStepContext + + +@pytest.fixture +def storage(): + """Create a fresh storage backend.""" + return InMemoryStorageBackend() + + +@pytest.fixture(autouse=True) +def clean_state(): + """Clean up registry and checkpoint backend between tests.""" + clear_stream_registry() + reset_checkpoint_backend() + yield + clear_stream_registry() + reset_checkpoint_backend() + + +class TestStreamStepLifecycle: + """Tests for the full stream step lifecycle.""" + + @pytest.mark.asyncio + async def test_on_signal_callback_invoked(self, storage): + 
"""on_signal should be called when a matching signal arrives.""" + received = [] + + async def on_signal(signal, ctx): + received.append(signal.signal_type) + + @stream_step(stream="s", signals=["task.created"], on_signal=on_signal) + async def worker(): + pass + + await storage.register_stream_subscription("s", "stream_step_worker_1", ["task.created"]) + + await emit("s", "task.created", {"id": 1}, storage=storage) + assert received == ["task.created"] + + @pytest.mark.asyncio + async def test_on_signal_resume_triggers_lifecycle(self, storage): + """ctx.resume() in on_signal should mark step for resumption.""" + resume_called = [] + + async def on_signal(signal, ctx): + resume_called.append(True) + await ctx.resume() + + @stream_step(stream="s", signals=["task.created"], on_signal=on_signal) + async def worker(): + pass + + await storage.register_stream_subscription("s", "stream_step_worker_1", ["task.created"]) + + await emit("s", "task.created", {"id": 1}, storage=storage) + assert len(resume_called) == 1 + + @pytest.mark.asyncio + async def test_on_signal_no_resume_stays_suspended(self, storage): + """Without ctx.resume(), step should stay suspended.""" + processed = [] + + async def on_signal(signal, ctx): + processed.append(signal.signal_type) + # No ctx.resume() - just process the signal + + @stream_step(stream="s", signals=["event"], on_signal=on_signal) + async def worker(): + pass + + await storage.register_stream_subscription("s", "stream_step_worker_1", ["event"]) + + await emit("s", "event", {}, storage=storage) + assert len(processed) == 1 + + # Subscription should still be waiting + sub = storage._subscriptions.get(("s", "stream_step_worker_1")) + assert sub is not None + assert sub["status"] == "waiting" + + @pytest.mark.asyncio + async def test_on_signal_cancel(self, storage): + """ctx.cancel() should cancel the step.""" + + async def on_signal(signal, ctx): + await ctx.cancel("test reason") + + @stream_step(stream="s", signals=["event"], 
on_signal=on_signal) + async def worker(): + pass + + await storage.register_stream_subscription("s", "stream_step_worker_1", ["event"]) + + await emit("s", "event", {}, storage=storage) + # Cancel was called - step should be marked + + @pytest.mark.asyncio + async def test_multiple_signal_types(self, storage): + """Step should only receive signals it subscribes to.""" + received_types = [] + + async def on_signal(signal, ctx): + received_types.append(signal.signal_type) + + @stream_step( + stream="s", + signals=["task.created", "task.updated"], + on_signal=on_signal, + ) + async def worker(): + pass + + await storage.register_stream_subscription( + "s", "stream_step_worker_1", ["task.created", "task.updated"] + ) + + await emit("s", "task.created", {}, storage=storage) + await emit("s", "task.deleted", {}, storage=storage) # Not subscribed + await emit("s", "task.updated", {}, storage=storage) + + assert received_types == ["task.created", "task.updated"] + + +class TestStreamStepContext: + """Tests for StreamStepContext in on_signal callbacks.""" + + @pytest.mark.asyncio + async def test_context_status(self): + """StreamStepContext should expose status.""" + ctx = StreamStepContext( + status="suspended", + run_id="run_1", + stream_id="s1", + ) + assert ctx.status == "suspended" + assert ctx.run_id == "run_1" + assert ctx.stream_id == "s1" + + @pytest.mark.asyncio + async def test_context_resume(self): + """resume() should set should_resume flag.""" + ctx = StreamStepContext(status="suspended", run_id="r", stream_id="s") + assert not ctx.should_resume + await ctx.resume() + assert ctx.should_resume + + @pytest.mark.asyncio + async def test_context_cancel(self): + """cancel() should set cancelled flag and reason.""" + ctx = StreamStepContext(status="suspended", run_id="r", stream_id="s") + assert not ctx.is_cancelled + await ctx.cancel("done") + assert ctx.is_cancelled + assert ctx.cancel_reason == "done" + + +class TestCheckpointIntegration: + """Tests for 
class TestSignalPayloadValidation:
    """Checks how emit() treats payloads when a stream step declares a Pydantic schema."""

    @pytest.mark.asyncio
    async def test_valid_payload_schema(self, storage):
        """A payload that satisfies the declared schema is delivered to on_signal."""
        from pydantic import BaseModel

        class TaskPayload(BaseModel):
            task_id: str
            description: str

        delivered_payloads = []

        async def on_signal(signal, ctx):
            delivered_payloads.append(signal.payload)

        # Dict form of `signals`: signal type -> schema used for that type.
        schema_by_type = {"task.created": TaskPayload}

        @stream_step(stream="s", signals=schema_by_type, on_signal=on_signal)
        async def worker():
            pass

        subscribed_types = ["task.created"]
        await storage.register_stream_subscription("s", "stream_step_worker_1", subscribed_types)

        body = {"task_id": "t1", "description": "Test task"}
        await emit("s", "task.created", body, storage=storage)

        assert len(delivered_payloads) == 1
        # The handler is expected to see the payload as a TaskPayload,
        # hence attribute access rather than dict indexing.
        payload = delivered_payloads[0]
        assert payload.task_id == "t1"
        assert payload.description == "Test task"

    @pytest.mark.asyncio
    async def test_invalid_payload_schema(self, storage):
        """A payload that violates the declared schema never reaches on_signal."""
        from pydantic import BaseModel

        class StrictPayload(BaseModel):
            required_field: str

        handler_calls = []

        async def on_signal(signal, ctx):
            handler_calls.append(True)

        schema_by_type = {"task.created": StrictPayload}

        @stream_step(stream="s", signals=schema_by_type, on_signal=on_signal)
        async def worker():
            pass

        subscribed_types = ["task.created"]
        await storage.register_stream_subscription("s", "stream_step_worker_1", subscribed_types)

        # `required_field` is absent, so the payload cannot satisfy StrictPayload.
        bad_body = {"wrong_field": "value"}
        await emit("s", "task.created", bad_body, storage=storage)

        # No dispatch is expected after the validation failure.
        assert len(handler_calls) == 0
class TestEmit:
    """Exercises emit(): returned Signal, sequencing, persistence and dispatch."""

    @pytest.mark.asyncio
    async def test_emit_basic(self, storage):
        """emit() returns a Signal carrying the stream, type, payload and sequence 0."""
        await storage.create_stream("test_stream")

        body = {"task_id": "t1"}
        emitted = await emit("test_stream", "task.created", body, storage=storage)

        assert isinstance(emitted, Signal)
        assert emitted.stream_id == "test_stream"
        assert emitted.signal_type == "task.created"
        assert emitted.payload == {"task_id": "t1"}
        assert emitted.sequence == 0

    @pytest.mark.asyncio
    async def test_emit_increments_sequence(self, storage):
        """Two consecutive emits on one stream receive sequences 0 and 1."""
        first = await emit("s", "t", {"n": 1}, storage=storage)
        second = await emit("s", "t", {"n": 2}, storage=storage)

        assert first.sequence == 0
        assert second.sequence == 1

    @pytest.mark.asyncio
    async def test_emit_stores_in_storage(self, storage):
        """Emitted signals can be read back from storage in emission order."""
        await emit("s", "task.created", {"id": 1}, storage=storage)
        await emit("s", "task.updated", {"id": 1}, storage=storage)

        stored = await storage.get_signals("s")
        stored_types = [record["signal_type"] for record in stored]
        # One list comparison covers both the count and the per-position types.
        assert stored_types == ["task.created", "task.updated"]

    @pytest.mark.asyncio
    async def test_emit_with_none_payload(self, storage):
        """A None payload comes back on the Signal as an empty dict."""
        emitted = await emit("s", "event", None, storage=storage)
        assert emitted.payload == {}

    @pytest.mark.asyncio
    async def test_emit_with_pydantic_model(self, storage):
        """A Pydantic model payload is exposed on the Signal as a plain dict."""
        from pydantic import BaseModel

        class TaskPayload(BaseModel):
            task_id: str
            description: str

        model = TaskPayload(task_id="t1", description="test")
        emitted = await emit("s", "task.created", model, storage=storage)

        expected = {"task_id": "t1", "description": "test"}
        assert emitted.payload == expected

    @pytest.mark.asyncio
    async def test_emit_with_metadata(self, storage):
        """Metadata passed to emit() is carried on the returned Signal."""
        extra = {"priority": "high"}
        emitted = await emit("s", "t", {"data": True}, storage=storage, metadata=extra)
        assert emitted.metadata == {"priority": "high"}

    @pytest.mark.asyncio
    async def test_emit_no_storage_raises(self):
        """Calling emit() without any storage raises RuntimeError."""
        with pytest.raises(RuntimeError, match="No storage backend"):
            await emit("s", "t", {})

    @pytest.mark.asyncio
    async def test_emit_dispatches_to_waiting_steps(self, storage):
        """emit() invokes the on_signal callback of a subscribed stream step."""
        from pyworkflow.streams.decorator import stream_step

        dispatched = []

        async def on_signal(signal, ctx):
            dispatched.append(signal)

        @stream_step(stream="s", signals=["task.created"], on_signal=on_signal)
        async def my_step():
            pass

        # The step must have a registered subscription before the emit.
        await storage.register_stream_subscription("s", "stream_step_my_step_123", ["task.created"])

        await emit("s", "task.created", {"id": 1}, storage=storage)

        assert len(dispatched) == 1
        delivered = dispatched[0]
        assert delivered.signal_type == "task.created"
class TestStream:
    """Covers construction of the Stream dataclass."""

    def test_create_stream(self):
        """A Stream built from only a stream_id has status "active" and empty metadata."""
        created = Stream(stream_id="test_stream")

        assert created.stream_id == "test_stream"
        assert created.status == "active"
        assert created.metadata == {}

    def test_stream_requires_stream_id(self):
        """Constructing a Stream with no stream_id raises ValueError."""
        with pytest.raises(ValueError, match="stream_id"):
            Stream()

    def test_stream_with_metadata(self):
        """Metadata supplied at construction is stored on the Stream."""
        supplied = {"description": "test"}
        created = Stream(stream_id="s", metadata=supplied)

        assert created.metadata == {"description": "test"}
class TestStreamStorage:
    """Covers creating and fetching streams on the storage backend."""

    @pytest.mark.asyncio
    async def test_create_and_get_stream(self, storage):
        """A created stream can be fetched back with its id, status and metadata."""
        description = {"description": "test"}
        await storage.create_stream("test_stream", description)

        fetched = await storage.get_stream("test_stream")

        assert fetched is not None
        assert fetched["stream_id"] == "test_stream"
        assert fetched["status"] == "active"
        assert fetched["metadata"] == {"description": "test"}

    @pytest.mark.asyncio
    async def test_get_nonexistent_stream(self, storage):
        """Fetching an unknown stream id yields None."""
        missing = await storage.get_stream("nope")
        assert missing is None

    @pytest.mark.asyncio
    async def test_create_duplicate_stream(self, storage):
        """Creating the same stream id twice raises ValueError on the second call."""
        await storage.create_stream("dup")

        with pytest.raises(ValueError, match="already exists"):
            await storage.create_stream("dup")
class TestSubscriptionStorage:
    """Covers registering stream subscriptions and querying waiting steps."""

    @pytest.mark.asyncio
    async def test_register_and_get_waiting_steps(self, storage):
        """A registered step is returned as waiting for a signal type it subscribed to."""
        signal_types = ["task.created", "task.updated"]
        await storage.register_stream_subscription("stream_1", "step_run_1", signal_types)

        waiting = await storage.get_waiting_steps("stream_1", "task.created")

        waiting_ids = [entry["step_run_id"] for entry in waiting]
        assert waiting_ids == ["step_run_1"]

    @pytest.mark.asyncio
    async def test_waiting_steps_filters_by_signal_type(self, storage):
        """Only the step subscribed to the queried signal type is returned."""
        await storage.register_stream_subscription("s", "step_1", ["a", "b"])
        await storage.register_stream_subscription("s", "step_2", ["c"])

        # Queried in the same order as before: "a" first, then "c".
        expectations = (("a", "step_1"), ("c", "step_2"))
        for signal_type, expected_step in expectations:
            waiting = await storage.get_waiting_steps("s", signal_type)
            assert len(waiting) == 1
            assert waiting[0]["step_run_id"] == expected_step

    @pytest.mark.asyncio
    async def test_update_subscription_status(self, storage):
        """A subscription moved to "running" is no longer reported as waiting."""
        await storage.register_stream_subscription("s", "step_1", ["a"])
        await storage.update_subscription_status("s", "step_1", "running")

        still_waiting = await storage.get_waiting_steps("s", "a")
        assert len(still_waiting) == 0
class TestStreamWorkflowDecorator:
    """Covers the @stream_workflow decorator and the stream registry lookups."""

    def test_basic_stream_workflow(self):
        """The decorated function carries the stream marker and the explicit name."""

        @stream_workflow(name="test_stream")
        async def my_stream():
            pass

        decorated = my_stream
        assert decorated.__stream_workflow__ is True
        assert decorated.__stream_name__ == "test_stream"

    def test_stream_workflow_default_name(self):
        """Without an explicit name, the function name is used and registered."""

        @stream_workflow()
        async def agent_comms():
            pass

        assert agent_comms.__stream_name__ == "agent_comms"

        registered = get_stream("agent_comms")
        assert registered is not None
        assert registered.name == "agent_comms"

    def test_stream_workflow_registry(self):
        """Every decorated stream name shows up in list_streams()."""

        @stream_workflow(name="stream_a")
        async def a():
            pass

        @stream_workflow(name="stream_b")
        async def b():
            pass

        known_streams = list_streams()
        for expected_name in ("stream_a", "stream_b"):
            assert expected_name in known_streams

    def test_stream_workflow_duplicate_same_func(self):
        """A registered stream is retrievable from the registry by name.

        NOTE(review): the decorator is applied only once here, so the
        re-registration path suggested by the test name is not actually
        exercised -- confirm whether a second registration should be added.
        """

        @stream_workflow(name="dup_stream")
        async def my_stream():
            pass

        # Registration above is expected to have completed without raising.
        registered = get_stream("dup_stream")
        assert registered is not None
stream="comms", + signals={"task.created": TaskPayload}, + on_signal=handler, + ) + async def schema_step(): + pass + + meta = get_stream_step("schema_step") + assert meta is not None + assert meta.signal_types == ["task.created"] + assert meta.signal_schemas == {"task.created": TaskPayload} + + def test_stream_step_custom_name(self): + """@stream_step should support custom name.""" + + async def handler(signal, ctx): + pass + + @stream_step( + stream="comms", + signals=["event"], + on_signal=handler, + name="custom_name", + ) + async def original_name(): + pass + + meta = get_stream_step("custom_name") + assert meta is not None + assert meta.name == "custom_name" + + def test_steps_for_stream(self): + """get_steps_for_stream should return steps registered for a stream.""" + + async def handler(signal, ctx): + pass + + @stream_step(stream="shared", signals=["a"], on_signal=handler) + async def step_a(): + pass + + @stream_step(stream="shared", signals=["b"], on_signal=handler) + async def step_b(): + pass + + @stream_step(stream="other", signals=["c"], on_signal=handler) + async def step_c(): + pass + + shared_steps = get_steps_for_stream("shared") + assert len(shared_steps) == 2 + names = [s.name for s in shared_steps] + assert "step_a" in names + assert "step_b" in names + + other_steps = get_steps_for_stream("other") + assert len(other_steps) == 1 + assert other_steps[0].name == "step_c" + + def test_list_all_stream_steps(self): + """list_stream_steps should return all registered steps.""" + + async def handler(signal, ctx): + pass + + @stream_step(stream="s1", signals=["a"], on_signal=handler) + async def sa(): + pass + + @stream_step(stream="s2", signals=["b"], on_signal=handler) + async def sb(): + pass + + all_steps = list_stream_steps() + assert "sa" in all_steps + assert "sb" in all_steps