From 08f3fbd2275eed0f21fa798ba471861c7a0de7a2 Mon Sep 17 00:00:00 2001 From: Yasha Date: Sat, 28 Feb 2026 14:06:03 +0100 Subject: [PATCH 1/3] feat(storage): add CitusStorageBackend for distributed PostgreSQL Adds a new CitusStorageBackend that extends PostgresStorageBackend, reusing all query logic while overriding schema initialization to create Citus-compatible distributed tables (composite PKs, no cross-shard FKs, reference table for schema_versions). Distribution strategy: - workflow_runs, events, steps, hooks, cancellation_flags: sharded on run_id - schedules: sharded on schedule_id - schema_versions: reference table (replicated to all workers) Enables PYWORKFLOW_STORAGE_TYPE=citus using the existing PYWORKFLOW_POSTGRES_* env vars. Co-Authored-By: Claude Sonnet 4.6 --- pyworkflow/config.py | 6 +- pyworkflow/storage/__init__.py | 7 + pyworkflow/storage/citus.py | 374 ++++++++++++++++++++++ pyworkflow/storage/config.py | 25 +- tests/unit/backends/test_citus_storage.py | 286 +++++++++++++++++ 5 files changed, 693 insertions(+), 5 deletions(-) create mode 100644 pyworkflow/storage/citus.py create mode 100644 tests/unit/backends/test_citus_storage.py diff --git a/pyworkflow/config.py b/pyworkflow/config.py index cc953e7..3b06368 100644 --- a/pyworkflow/config.py +++ b/pyworkflow/config.py @@ -18,7 +18,7 @@ ... 
) Environment Variables: - PYWORKFLOW_STORAGE_TYPE: Storage backend type (file, memory, sqlite, postgres, mysql) + PYWORKFLOW_STORAGE_TYPE: Storage backend type (file, memory, sqlite, postgres, mysql, citus) PYWORKFLOW_STORAGE_PATH: Path for file/sqlite backends PYWORKFLOW_POSTGRES_HOST: PostgreSQL host PYWORKFLOW_POSTGRES_PORT: PostgreSQL port @@ -58,9 +58,9 @@ def _load_env_storage_config() -> dict[str, Any] | None: storage_type = storage_type.lower() - if storage_type == "postgres": + if storage_type in ("postgres", "citus"): return { - "type": "postgres", + "type": storage_type, "host": os.getenv("PYWORKFLOW_POSTGRES_HOST", "localhost"), "port": int(os.getenv("PYWORKFLOW_POSTGRES_PORT", "5432")), "user": os.getenv("PYWORKFLOW_POSTGRES_USER", "pyworkflow"), diff --git a/pyworkflow/storage/__init__.py b/pyworkflow/storage/__init__.py index 24e1fc5..11ff5df 100644 --- a/pyworkflow/storage/__init__.py +++ b/pyworkflow/storage/__init__.py @@ -47,6 +47,12 @@ except ImportError: MySQLStorageBackend = None # type: ignore +# Citus distributed PostgreSQL backend - optional import (requires asyncpg + Citus extension) +try: + from pyworkflow.storage.citus import CitusStorageBackend +except ImportError: + CitusStorageBackend = None # type: ignore + __all__ = [ "StorageBackend", "FileStorageBackend", @@ -56,6 +62,7 @@ "DynamoDBStorageBackend", "CassandraStorageBackend", "MySQLStorageBackend", + "CitusStorageBackend", "WorkflowRun", "StepExecution", "Hook", diff --git a/pyworkflow/storage/citus.py b/pyworkflow/storage/citus.py new file mode 100644 index 0000000..a2c84fe --- /dev/null +++ b/pyworkflow/storage/citus.py @@ -0,0 +1,374 @@ +""" +Citus distributed PostgreSQL storage backend. + +Citus is a PostgreSQL extension that shards tables across multiple nodes for +horizontal scalability. This backend extends PostgresStorageBackend, reusing all +query logic while overriding schema initialization to create Citus-compatible +distributed tables. 
+ +Distribution strategy: +- All workflow tables are co-located on `run_id` so that all data for a given + workflow run lands on the same Citus shard. +- Schedules are distributed on `schedule_id` (independent of workflow runs). +- `schema_versions` becomes a reference table (replicated to all workers). + +Citus constraints vs plain PostgreSQL: +- Primary keys and unique constraints MUST include the distribution column. +- Cross-shard foreign key constraints are unsupported. +- Global unique constraints on non-distribution columns cannot be enforced. + +See: https://docs.citusdata.com/en/stable/develop/api.html +""" + +import asyncpg + +from pyworkflow.storage.migrations import Migration +from pyworkflow.storage.postgres import PostgresMigrationRunner, PostgresStorageBackend + + +class CitusMigrationRunner(PostgresMigrationRunner): + """ + Citus-specific migration runner. + + Extends PostgresMigrationRunner to handle Citus schema constraints. + V2 migration: same step_id backfill as PostgreSQL, but skips creating + unique constraints that would conflict with Citus distribution requirements. 
+ """ + + async def apply_migration(self, migration: Migration) -> None: + """Apply a migration with Citus-specific handling.""" + from datetime import UTC, datetime + + async with self._pool.acquire() as conn, conn.transaction(): + if migration.version == 2: + # V2: Add step_id column to events table + # Check if events table exists (fresh databases won't have it yet) + table_exists = await conn.fetchrow(""" + SELECT EXISTS ( + SELECT FROM information_schema.tables + WHERE table_name = 'events' + ) as exists + """) + + if table_exists and table_exists["exists"]: + # Use IF NOT EXISTS for idempotency + await conn.execute(""" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 FROM information_schema.columns + WHERE table_name = 'events' AND column_name = 'step_id' + ) THEN + ALTER TABLE events ADD COLUMN step_id TEXT; + END IF; + END $$ + """) + + # Create index for optimized has_event() queries + await conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_events_run_id_step_id_type + ON events(run_id, step_id, type) + """) + + # Backfill step_id from JSON data + await conn.execute(""" + UPDATE events + SET step_id = (data::jsonb)->>'step_id' + WHERE step_id IS NULL + AND (data::jsonb)->>'step_id' IS NOT NULL + """) + # NOTE: No UNIQUE constraints added here — Citus requires distribution + # column in every unique constraint, which V2 doesn't add. + elif migration.up_func: + await migration.up_func(conn) + elif migration.up_sql and migration.up_sql != "SELECT 1": + await conn.execute(migration.up_sql) + + # Record the migration + await conn.execute( + """ + INSERT INTO schema_versions (version, applied_at, description) + VALUES ($1, $2, $3) + """, + migration.version, + datetime.now(UTC), + migration.description, + ) + + +class CitusStorageBackend(PostgresStorageBackend): + """ + Citus distributed PostgreSQL storage backend. + + Extends PostgresStorageBackend with Citus-specific schema initialization. 
+ All query logic is inherited unchanged — only DDL differs to satisfy Citus's + distribution column requirements. + + Tables are sharded as follows: + - workflow_runs, events, steps, hooks, cancellation_flags: distributed on run_id + - schedules: distributed on schedule_id + - schema_versions: reference table (replicated to all workers) + + Requirements: + - PostgreSQL with the Citus extension installed and loaded + - The calling database user must have permission to call Citus functions + + Usage: + backend = CitusStorageBackend( + host="citus-coordinator", + database="pyworkflow", + ) + await backend.connect() + """ + + async def _initialize_schema(self) -> None: + """ + Create Citus-compatible schema and distribute tables. + + Steps: + 1. Verify Citus extension is available. + 2. Create tables with Citus-compatible PKs/constraints (no cross-shard FKs, + no cross-shard UNIQUE constraints, composite PKs include distribution col). + 3. Create indexes. + 4. Distribute tables (idempotent — already-distributed tables are skipped). + 5. Run migrations via CitusMigrationRunner. + """ + pool = await self._get_pool() + + # Step 1: Verify Citus extension is present + async with pool.acquire() as conn: + try: + await conn.fetchval("SELECT citus_version()") + except asyncpg.UndefinedFunctionError: + raise RuntimeError( + "Citus extension is not available on this PostgreSQL server. " + "Install Citus and run `CREATE EXTENSION citus;` before using " + "CitusStorageBackend. 
See: https://docs.citusdata.com/en/stable/installation/" + ) + + # Step 2 & 3: Create tables and indexes with Citus-compatible DDL + async with pool.acquire() as conn: + # schema_versions — created first so CitusMigrationRunner can use it + await conn.execute(""" + CREATE TABLE IF NOT EXISTS schema_versions ( + version INTEGER PRIMARY KEY, + applied_at TIMESTAMPTZ NOT NULL, + description TEXT + ) + """) + + # workflow_runs — distributed on run_id + # Differences vs postgres: + # - DROP FK on parent_run_id (self-referential FK can't be enforced cross-shard) + # - UNIQUE INDEX on idempotency_key → plain INDEX (global uniqueness unenforceable) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS workflow_runs ( + run_id TEXT PRIMARY KEY, + workflow_name TEXT NOT NULL, + status TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL, + updated_at TIMESTAMPTZ NOT NULL, + started_at TIMESTAMPTZ, + completed_at TIMESTAMPTZ, + input_args TEXT NOT NULL DEFAULT '[]', + input_kwargs TEXT NOT NULL DEFAULT '{}', + result TEXT, + error TEXT, + idempotency_key TEXT, + max_duration TEXT, + metadata TEXT DEFAULT '{}', + recovery_attempts INTEGER DEFAULT 0, + max_recovery_attempts INTEGER DEFAULT 3, + recover_on_worker_loss BOOLEAN DEFAULT TRUE, + parent_run_id TEXT, + nesting_depth INTEGER DEFAULT 0, + continued_from_run_id TEXT, + continued_to_run_id TEXT + ) + """) + + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_runs_status ON workflow_runs(status)" + ) + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_runs_workflow_name ON workflow_runs(workflow_name)" + ) + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_runs_created_at ON workflow_runs(created_at DESC)" + ) + # Non-unique index: Citus cannot enforce global uniqueness on non-distribution columns + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_runs_idempotency_key ON workflow_runs(idempotency_key) WHERE idempotency_key IS NOT NULL" + ) + await conn.execute( + "CREATE INDEX IF NOT EXISTS 
idx_runs_parent_run_id ON workflow_runs(parent_run_id)" + ) + + # events — distributed on run_id, co-located with workflow_runs + # Differences vs postgres: + # - PK: event_id → (run_id, event_id) to include distribution column + # - No FK on run_id: the DDL below declares no REFERENCES clause + await conn.execute(""" + CREATE TABLE IF NOT EXISTS events ( + event_id TEXT NOT NULL, + run_id TEXT NOT NULL, + sequence INTEGER NOT NULL, + type TEXT NOT NULL, + timestamp TIMESTAMPTZ NOT NULL, + data TEXT NOT NULL DEFAULT '{}', + step_id TEXT, + PRIMARY KEY (run_id, event_id) + ) + """) + + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_events_run_id_sequence ON events(run_id, sequence)" + ) + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_events_run_id_type ON events(run_id, type)" + ) + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_events_run_id_step_id_type ON events(run_id, step_id, type)" + ) + + # steps — distributed on run_id, co-located with workflow_runs + # Differences vs postgres: + # - PK: step_id → (run_id, step_id) to include distribution column + # - No FK on run_id: the DDL below declares no REFERENCES clause + # - Extra index on step_id alone for get_step(step_id) scatter-gather queries + await conn.execute(""" + CREATE TABLE IF NOT EXISTS steps ( + step_id TEXT NOT NULL, + run_id TEXT NOT NULL, + step_name TEXT NOT NULL, + status TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL, + started_at TIMESTAMPTZ, + completed_at TIMESTAMPTZ, + input_args TEXT NOT NULL DEFAULT '[]', + input_kwargs TEXT NOT NULL DEFAULT '{}', + result TEXT, + error TEXT, + retry_count INTEGER DEFAULT 0, + PRIMARY KEY (run_id, step_id) + ) + """) + + await conn.execute("CREATE INDEX IF NOT EXISTS idx_steps_run_id ON steps(run_id)") + # Shard-local index to speed up get_step(step_id) scatter-gather queries + await conn.execute("CREATE INDEX IF NOT EXISTS idx_steps_step_id ON steps(step_id)") + + # hooks — distributed on run_id, co-located with workflow_runs + # Differences vs postgres:
+ # - UNIQUE on token → plain INDEX (token = run_id:hook_id, collisions impossible) + # - No FK on run_id: the DDL below declares no REFERENCES clause + await conn.execute(""" + CREATE TABLE IF NOT EXISTS hooks ( + run_id TEXT NOT NULL, + hook_id TEXT NOT NULL, + token TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL, + received_at TIMESTAMPTZ, + expires_at TIMESTAMPTZ, + status TEXT NOT NULL, + payload TEXT, + metadata TEXT DEFAULT '{}', + PRIMARY KEY (run_id, hook_id) + ) + """) + + await conn.execute("CREATE INDEX IF NOT EXISTS idx_hooks_token ON hooks(token)") + await conn.execute("CREATE INDEX IF NOT EXISTS idx_hooks_run_id ON hooks(run_id)") + await conn.execute("CREATE INDEX IF NOT EXISTS idx_hooks_status ON hooks(status)") + + # schedules — distributed on schedule_id (independent of runs) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS schedules ( + schedule_id TEXT PRIMARY KEY, + workflow_name TEXT NOT NULL, + spec TEXT NOT NULL, + spec_type TEXT NOT NULL, + timezone TEXT, + input_args TEXT NOT NULL DEFAULT '[]', + input_kwargs TEXT NOT NULL DEFAULT '{}', + status TEXT NOT NULL, + overlap_policy TEXT NOT NULL, + next_run_time TIMESTAMPTZ, + last_run_time TIMESTAMPTZ, + running_run_ids TEXT DEFAULT '[]', + metadata TEXT DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL, + updated_at TIMESTAMPTZ NOT NULL, + paused_at TIMESTAMPTZ, + deleted_at TIMESTAMPTZ + ) + """) + + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_schedules_status ON schedules(status)" + ) + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_schedules_next_run_time ON schedules(next_run_time)" + ) + await conn.execute( + "CREATE INDEX IF NOT EXISTS idx_schedules_workflow_name ON schedules(workflow_name)" + ) + + # cancellation_flags — distributed on run_id, co-located with workflow_runs + # Differences vs postgres: + # - FK on run_id dropped (cross-shard FK unsupported; co-location ensures correctness) + await conn.execute(""" + CREATE TABLE IF NOT EXISTS cancellation_flags ( + run_id TEXT PRIMARY
KEY, + created_at TIMESTAMPTZ NOT NULL + ) + """) + + # Step 4: Distribute tables (idempotent — already-distributed tables are skipped) + await self._distribute_tables(pool) + + # Step 5: Run migrations via Citus-aware runner + runner = CitusMigrationRunner(pool) + await runner.run_migrations() + + async def _distribute_tables(self, pool: asyncpg.Pool) -> None: + """ + Call Citus distribution functions for each table. + + Uses pg_dist_partition to detect already-distributed tables so this + method is idempotent and safe to call on every connect(). + """ + async with pool.acquire() as conn: + # Fetch already-distributed table names using relname to avoid + # schema-prefix ambiguity (logicalrelid::text is search_path-dependent) + rows = await conn.fetch(""" + SELECT c.relname AS tbl + FROM pg_dist_partition dp + JOIN pg_class c ON dp.logicalrelid = c.oid + """) + distributed = {row["tbl"] for row in rows} + + # workflow_runs: anchor table, distribute first + if "workflow_runs" not in distributed: + await conn.execute( + "SELECT create_distributed_table('workflow_runs', 'run_id')" + ) + + # Co-located tables: must reference the same distribution column + colocated_run_id = ["events", "steps", "hooks", "cancellation_flags"] + for table in colocated_run_id: + if table not in distributed: + await conn.execute( + f"SELECT create_distributed_table('{table}', 'run_id', " + f"colocate_with => 'workflow_runs')" + ) + + # schedules: independent distribution + if "schedules" not in distributed: + await conn.execute( + "SELECT create_distributed_table('schedules', 'schedule_id')" + ) + + # schema_versions: reference table (replicated to all workers) + if "schema_versions" not in distributed: + await conn.execute("SELECT create_reference_table('schema_versions')") diff --git a/pyworkflow/storage/config.py b/pyworkflow/storage/config.py index 83ef3b5..de0199c 100644 --- a/pyworkflow/storage/config.py +++ b/pyworkflow/storage/config.py @@ -84,8 +84,9 @@ def 
storage_to_config(storage: StorageBackend | None) -> dict[str, Any] | None: "port": getattr(storage, "port", 6379), "db": getattr(storage, "db", 0), } - elif class_name == "PostgresStorageBackend": - config: dict[str, Any] = {"type": "postgres"} + elif class_name in ("PostgresStorageBackend", "CitusStorageBackend"): + storage_type = "citus" if class_name == "CitusStorageBackend" else "postgres" + config: dict[str, Any] = {"type": storage_type} dsn = getattr(storage, "dsn", None) if dsn: config["dsn"] = dsn @@ -247,6 +248,26 @@ def _create_storage_backend(config: dict[str, Any] | None) -> StorageBackend: database=config.get("database", "pyworkflow"), ) + elif storage_type == "citus": + try: + from pyworkflow.storage.citus import CitusStorageBackend + except ImportError: + raise ValueError( + "Citus storage backend is not available. Install asyncpg: pip install asyncpg" + ) + + # Support both DSN and individual parameters + if "dsn" in config: + return CitusStorageBackend(dsn=config["dsn"]) + else: + return CitusStorageBackend( + host=config.get("host", "localhost"), + port=config.get("port", 5432), + user=config.get("user", "pyworkflow"), + password=config.get("password", ""), + database=config.get("database", "pyworkflow"), + ) + elif storage_type == "dynamodb": try: from pyworkflow.storage.dynamodb import DynamoDBStorageBackend diff --git a/tests/unit/backends/test_citus_storage.py b/tests/unit/backends/test_citus_storage.py new file mode 100644 index 0000000..cd9ce34 --- /dev/null +++ b/tests/unit/backends/test_citus_storage.py @@ -0,0 +1,286 @@ +""" +Unit tests for Citus distributed PostgreSQL storage backend. + +These tests verify CitusStorageBackend initialization, config round-trips, +and env-var loading. They do NOT require a live Citus/PostgreSQL instance. +For integration tests with a real Citus cluster, see tests/integration/. 
+""" + +import os +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Skip all tests if asyncpg is not installed +pytest.importorskip("asyncpg") + +from pyworkflow.storage.citus import CitusStorageBackend, CitusMigrationRunner +from pyworkflow.storage.config import _create_storage_backend, config_to_storage, storage_to_config + + +@pytest.fixture +def mock_citus_backend(): + """Create a CitusStorageBackend with a mocked pool for testing.""" + backend = CitusStorageBackend() + mock_pool = MagicMock() + mock_conn = AsyncMock() + + @asynccontextmanager + async def mock_acquire(): + yield mock_conn + + mock_pool.acquire = mock_acquire + backend._pool = mock_pool + return backend, mock_conn + + +class TestCitusStorageBackendInit: + """Test CitusStorageBackend initialization.""" + + def test_init_defaults(self): + """Test default initialization parameters.""" + backend = CitusStorageBackend() + assert backend.dsn is None + assert backend.host == "localhost" + assert backend.port == 5432 + assert backend.user == "pyworkflow" + assert backend.password == "" + assert backend.database == "pyworkflow" + assert backend._pool is None + assert backend._initialized is False + + def test_init_with_dsn(self): + """Test initialization with a DSN connection string.""" + dsn = "postgresql://user:pass@citus-coordinator:5432/pyworkflow" + backend = CitusStorageBackend(dsn=dsn) + assert backend.dsn == dsn + assert backend._pool is None + + def test_init_with_individual_params(self): + """Test initialization with individual connection parameters.""" + backend = CitusStorageBackend( + host="citus-coordinator", + port=5433, + user="citususer", + password="secret", + database="wf", + ) + assert backend.dsn is None + assert backend.host == "citus-coordinator" + assert backend.port == 5433 + assert backend.user == "citususer" + assert backend.password == "secret" + assert backend.database == "wf" + + def 
test_inherits_postgres_backend(self): + """CitusStorageBackend must inherit from PostgresStorageBackend.""" + from pyworkflow.storage.postgres import PostgresStorageBackend + + backend = CitusStorageBackend() + assert isinstance(backend, PostgresStorageBackend) + + def test_build_dsn(self): + """Test DSN construction from individual params.""" + backend = CitusStorageBackend( + host="coordinator", + port=5432, + user="admin", + password="pass", + database="db", + ) + dsn = backend._build_dsn() + assert "coordinator" in dsn + assert "admin" in dsn + assert "db" in dsn + + +class TestCitusMigrationRunner: + """Test CitusMigrationRunner.""" + + def test_inherits_postgres_runner(self): + """CitusMigrationRunner must inherit from PostgresMigrationRunner.""" + from pyworkflow.storage.postgres import PostgresMigrationRunner + + pool = MagicMock() + runner = CitusMigrationRunner(pool) + assert isinstance(runner, PostgresMigrationRunner) + + +class TestStorageToConfig: + """Test storage_to_config() serialization for CitusStorageBackend.""" + + def test_citus_backend_produces_citus_type(self): + """storage_to_config returns type='citus' for CitusStorageBackend.""" + backend = CitusStorageBackend( + host="citus-host", + port=5432, + user="user", + password="pass", + database="db", + ) + config = storage_to_config(backend) + assert config is not None + assert config["type"] == "citus" + + def test_citus_backend_with_dsn(self): + """storage_to_config preserves DSN for CitusStorageBackend.""" + dsn = "postgresql://user:pass@coordinator/db" + backend = CitusStorageBackend(dsn=dsn) + config = storage_to_config(backend) + assert config is not None + assert config["type"] == "citus" + assert config["dsn"] == dsn + + def test_citus_config_round_trip_individual_params(self): + """Round-trip: CitusStorageBackend → config dict → CitusStorageBackend.""" + original = CitusStorageBackend( + host="citus-host", + port=5432, + user="wf", + password="secret", + database="wfdb", + ) + config = 
storage_to_config(original) + assert config is not None + assert config["type"] == "citus" + assert config["host"] == "citus-host" + assert config["user"] == "wf" + assert config["database"] == "wfdb" + + +class TestConfigToStorage: + """Test config_to_storage() / _create_storage_backend() for citus type.""" + + def test_citus_type_returns_citus_backend(self): + """_create_storage_backend({'type': 'citus', ...}) returns CitusStorageBackend.""" + config = { + "type": "citus", + "host": "coordinator", + "port": 5432, + "user": "user", + "password": "", + "database": "pyworkflow", + } + backend = _create_storage_backend(config) + assert isinstance(backend, CitusStorageBackend) + + def test_citus_type_with_dsn(self): + """_create_storage_backend with DSN returns CitusStorageBackend.""" + config = { + "type": "citus", + "dsn": "postgresql://user@coordinator/db", + } + backend = _create_storage_backend(config) + assert isinstance(backend, CitusStorageBackend) + assert backend.dsn == "postgresql://user@coordinator/db" + + def test_config_to_storage_caches_instance(self): + """config_to_storage() returns the same cached instance for identical config.""" + from pyworkflow.storage.config import clear_storage_cache + + clear_storage_cache() + config = { + "type": "citus", + "host": "coordinator", + "port": 5432, + "user": "user", + "password": "", + "database": "pyworkflow", + } + backend1 = config_to_storage(config) + backend2 = config_to_storage(config) + assert backend1 is backend2 + clear_storage_cache() + + def test_citus_config_full_round_trip(self): + """Full round-trip: CitusStorageBackend → config → CitusStorageBackend.""" + from pyworkflow.storage.config import clear_storage_cache + + clear_storage_cache() + original = CitusStorageBackend( + host="citus-coordinator", + port=5432, + user="wf", + password="secret", + database="wfdb", + ) + config = storage_to_config(original) + assert config is not None + + clear_storage_cache() + restored = 
_create_storage_backend(config) + assert isinstance(restored, CitusStorageBackend) + assert restored.host == "citus-coordinator" + assert restored.user == "wf" + assert restored.database == "wfdb" + clear_storage_cache() + + +class TestEnvVarConfig: + """Test PYWORKFLOW_STORAGE_TYPE=citus env var loading.""" + + def test_citus_env_var_produces_citus_config(self): + """PYWORKFLOW_STORAGE_TYPE=citus reads PYWORKFLOW_POSTGRES_* vars.""" + from pyworkflow.config import _load_env_storage_config + + env = { + "PYWORKFLOW_STORAGE_TYPE": "citus", + "PYWORKFLOW_POSTGRES_HOST": "citus-host", + "PYWORKFLOW_POSTGRES_PORT": "5433", + "PYWORKFLOW_POSTGRES_USER": "citususer", + "PYWORKFLOW_POSTGRES_PASSWORD": "secret", + "PYWORKFLOW_POSTGRES_DATABASE": "citusdb", + } + with patch.dict(os.environ, env, clear=False): + config = _load_env_storage_config() + + assert config is not None + assert config["type"] == "citus" + assert config["host"] == "citus-host" + assert config["port"] == 5433 + assert config["user"] == "citususer" + assert config["password"] == "secret" + assert config["database"] == "citusdb" + + def test_citus_env_var_defaults(self): + """PYWORKFLOW_STORAGE_TYPE=citus uses default Postgres vars when not set.""" + from pyworkflow.config import _load_env_storage_config + + # Only set the storage type; rely on defaults for the rest + env_overrides = { + "PYWORKFLOW_STORAGE_TYPE": "citus", + } + # Ensure Postgres vars are not set + remove_keys = [ + "PYWORKFLOW_POSTGRES_HOST", + "PYWORKFLOW_POSTGRES_PORT", + "PYWORKFLOW_POSTGRES_USER", + "PYWORKFLOW_POSTGRES_PASSWORD", + "PYWORKFLOW_POSTGRES_DATABASE", + ] + env = {k: v for k, v in os.environ.items() if k not in remove_keys} + env.update(env_overrides) + + with patch.dict(os.environ, env, clear=True): + config = _load_env_storage_config() + + assert config is not None + assert config["type"] == "citus" + assert config["host"] == "localhost" + assert config["port"] == 5432 + assert config["user"] == "pyworkflow" + assert 
config["database"] == "pyworkflow" + + def test_citus_env_var_creates_citus_backend(self): + """End-to-end: PYWORKFLOW_STORAGE_TYPE=citus → CitusStorageBackend instance.""" + from pyworkflow.config import _load_env_storage_config + from pyworkflow.storage.config import _create_storage_backend + + env = {"PYWORKFLOW_STORAGE_TYPE": "citus"} + with patch.dict(os.environ, env, clear=False): + storage_config = _load_env_storage_config() + + assert storage_config is not None + backend = _create_storage_backend(storage_config) + assert isinstance(backend, CitusStorageBackend) From 78208da790d2a70ac5b22caf20b425707011b33f Mon Sep 17 00:00:00 2001 From: Yasha Date: Sun, 1 Mar 2026 12:38:21 +0100 Subject: [PATCH 2/3] fix(storage): sort imports in test_citus_storage to satisfy ruff I001 Co-Authored-By: Claude Sonnet 4.6 --- tests/unit/backends/test_citus_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/backends/test_citus_storage.py b/tests/unit/backends/test_citus_storage.py index cd9ce34..8dfceb4 100644 --- a/tests/unit/backends/test_citus_storage.py +++ b/tests/unit/backends/test_citus_storage.py @@ -15,7 +15,7 @@ # Skip all tests if asyncpg is not installed pytest.importorskip("asyncpg") -from pyworkflow.storage.citus import CitusStorageBackend, CitusMigrationRunner +from pyworkflow.storage.citus import CitusMigrationRunner, CitusStorageBackend from pyworkflow.storage.config import _create_storage_backend, config_to_storage, storage_to_config From ffcdf02d9b82e29576f6eb01197b333d77060a90 Mon Sep 17 00:00:00 2001 From: Yasha Date: Sun, 1 Mar 2026 12:41:12 +0100 Subject: [PATCH 3/3] style(storage): apply ruff format to citus.py Co-Authored-By: Claude Sonnet 4.6 --- pyworkflow/storage/citus.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/pyworkflow/storage/citus.py b/pyworkflow/storage/citus.py index a2c84fe..d1afb1c 100644 --- a/pyworkflow/storage/citus.py +++ b/pyworkflow/storage/citus.py @@ -350,9 +350,7 
@@ async def _distribute_tables(self, pool: asyncpg.Pool) -> None: # workflow_runs: anchor table, distribute first if "workflow_runs" not in distributed: - await conn.execute( - "SELECT create_distributed_table('workflow_runs', 'run_id')" - ) + await conn.execute("SELECT create_distributed_table('workflow_runs', 'run_id')") # Co-located tables: must reference the same distribution column colocated_run_id = ["events", "steps", "hooks", "cancellation_flags"] @@ -365,9 +363,7 @@ async def _distribute_tables(self, pool: asyncpg.Pool) -> None: # schedules: independent distribution if "schedules" not in distributed: - await conn.execute( - "SELECT create_distributed_table('schedules', 'schedule_id')" - ) + await conn.execute("SELECT create_distributed_table('schedules', 'schedule_id')") # schema_versions: reference table (replicated to all workers) if "schema_versions" not in distributed: