diff --git a/api/openapi.yaml b/api/openapi.yaml index fd3cd44..7902917 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -960,6 +960,63 @@ paths: '503': description: Forwarded from HF — upstream unavailable. + /source-proxy/subtask/{subtaskId}: + get: + tags: [executors] + summary: Multi-source reverse proxy — stream a subtask's file from its assigned source + description: >- + Controller-side reverse proxy. Authenticates the executor (mTLS + JWT), + verifies subtask ownership (assignment_token + epoch fence), resolves the + planner-assigned source for this subtask (defaults to huggingface when + unassigned), injects the appropriate credential, and streams the bytes + back. The executor never holds source credentials directly (INVARIANT 2). + operationId: sourceProxySubtask + parameters: + - name: subtaskId + in: path + required: true + schema: + type: string + format: uuid + - name: X-Assignment-Token + in: header + required: true + schema: + type: string + format: uuid + - name: Range + in: header + required: false + schema: + type: string + responses: + '200': + description: Full file stream. + content: + application/octet-stream: + schema: + type: string + format: binary + '206': + description: Partial content (Range request). + content: + application/octet-stream: + schema: + type: string + format: binary + '401': + description: Missing or invalid mTLS / executor JWT. + '403': + description: NOT_YOUR_SUBTASK — authenticated executor does not own this subtask. + '404': + description: Subtask not found. + '409': + description: STALE_ASSIGNMENT or EPOCH_MISMATCH (fence-token mismatch). + '502': + description: SOURCE_UNAVAILABLE — source returned an unrecoverable error. + '503': + description: Source unreachable or all sources exhausted. + /executors/{executorId}/subtasks/{subtaskId}/complete: parameters: - in: path diff --git a/config/resolver-rules.yaml b/config/resolver-rules.yaml new file mode 100644 index 0000000..65fb254 --- /dev/null +++ b/config/resolver-rules.yaml @@ -0,0 +1,12 @@ +identity_organizations: + - deepseek-ai + - Qwen + - 01-ai + - THUDM + - baichuan-inc + - mistralai +aliases: + - hf_org: meta-llama + modelscope_org: LLM-Research + transform: "Meta-{name}" +per_model_overrides: [] diff --git a/config/sources.yaml b/config/sources.yaml new file mode 100644 index 0000000..2b35551 --- /dev/null +++ b/config/sources.yaml @@ -0,0 +1,22 @@ +sources: + - id: huggingface + enabled: true + driver: huggingface + config: {base_url: "https://huggingface.co", timeout_seconds: 30} + cost_per_gb_egress: 0.09 + - id: hf_mirror + enabled: true + driver: hf_mirror + config: {base_url: "https://hf-mirror.com", timeout_seconds: 30} + cost_per_gb_egress: 0.0 + - id: modelscope + enabled: true + driver: modelscope + config: {base_url: "https://www.modelscope.cn", timeout_seconds: 30} + cost_per_gb_egress: 0.0 +balancing: + speed_ewma_alpha: 0.3 + chunk_level_min_file_mb: 100 +regional_defaults: + cn-north: ["hf_mirror", "modelscope", "huggingface"] + us-east: ["huggingface"] diff --git a/docs/operator/multi-source.md b/docs/operator/multi-source.md new file mode 100644 index 0000000..faf5bbc --- /dev/null +++ b/docs/operator/multi-source.md @@ -0,0 +1,196 @@ +# Multi-Source Download — Operator Guide (SP2) + +> **Cross-references**: `docs/v2.0/06-platform-and-ecosystem.md` §1 (design rationale, +> scheduling algorithm, name-resolver detail); `docs/v2.0/INVARIANTS.md` 11/12/13 +> (SHA256 authority, cross-source verification, HF-down policy). + +--- + +## 1. `config/sources.yaml` + +Controls which download sources the controller loads at startup. + +```yaml +sources: + - id: huggingface # unique identifier used in source_blacklist, logs, etc. + enabled: true + driver: huggingface # must be a supported driver (see §1.1) + config: + base_url: "https://huggingface.co" + timeout_seconds: 30 + cost_per_gb_egress: 0.09 # USD; used by cost-aware scheduling (future) + + - id: hf_mirror + enabled: true + driver: hf_mirror + config: + base_url: "https://hf-mirror.com" + timeout_seconds: 30 + cost_per_gb_egress: 0.0 + + - id: modelscope + enabled: true + driver: modelscope + config: + base_url: "https://www.modelscope.cn" + timeout_seconds: 30 + cost_per_gb_egress: 0.0 + +balancing: + speed_ewma_alpha: 0.3 # EWMA smoothing factor for speed samples + chunk_level_min_file_mb: 100 # files smaller than this are not chunk-routed + +regional_defaults: + cn-north: ["hf_mirror", "modelscope", "huggingface"] + us-east: ["huggingface"] +``` + +### 1.1 Supported drivers (v2.0) + +| Driver ID | Endpoint | Notes | +|---------------|-----------------------|------------------------------| +| `huggingface` | huggingface.co | Authoritative SHA256 source | +| `hf_mirror` | hf-mirror.com | HF-compatible reverse proxy | +| `modelscope` | modelscope.cn | Requires name-resolver rules | + +Entries with an unrecognised driver are logged as a warning and skipped at startup. + +Drivers deferred to v2.1: `wisemodel`, `opencsg`, `s3_mirror`, `plugin` +(see §6 Deferred items). + +--- + +## 2. `config/resolver-rules.yaml` + +Maps HuggingFace repo IDs to their equivalents on other sources. + +```yaml +identity_organizations: + # Organizations whose repo IDs are identical across all sources. + - deepseek-ai + - Qwen + - 01-ai + - THUDM + - baichuan-inc + - mistralai + +aliases: + # Transform rules for organizations with different naming on other sources. + - hf_org: meta-llama + modelscope_org: LLM-Research + transform: "Meta-{name}" # e.g. Llama-3.1-8B → Meta-Llama-3.1-8B + +per_model_overrides: + # Exact per-model overrides; takes precedence over aliases. + # - hf: "specific-org/specific-model" + # modelscope: "different-org/different-name" +``` + +The resolver applies three-tier lookup: (1) identity match, (2) alias rule, +(3) source search API fallback (cached 24 h). If no mapping is found the +source is skipped for that task. + +--- + +## 3. SP2 Environment Settings (`DLW_*`) + +All settings live under `SourceSettings` in `src/dlw/config.py` and are read +from environment variables with the `DLW_` prefix (or from the Helm configmap). + +| Setting | Default | Description | +|-----------------------------------|---------|--------------------------------------------------------------------------| +| `DLW_PROBE_SIZE_MB` | 32 | Bytes downloaded per source during the scheduling-phase speed probe | +| `DLW_PROBE_TIMEOUT_S` | 8.0 | Soft deadline (seconds) for each probe; partial bytes still recorded | +| `DLW_PROBE_HISTORY_WEIGHT` | 0.3 | EWMA history weight; live probe weight = `1 - probe_history_weight` | +| `DLW_CHUNK_LEVEL_MIN_FILE_MB` | 100 | Files smaller than this are not split across sources | +| `DLW_SHA_MISMATCH_BLACKLIST_HOURS`| 24 | Duration to blacklist a `(source, repo, filename)` after SHA256 mismatch | +| `DLW_REBALANCE_INTERVAL_SECONDS` | 60.0 | How often the background rebalancer re-evaluates in-flight task routing | + +Tuning guidance: increase `PROBE_SIZE_MB` to 64 for large-model repos where +speed variance is high; reduce `PROBE_TIMEOUT_S` below 8 only on low-latency +networks where probes consistently finish within 3-4 s. + +--- + +## 4. `source_strategy` Task Field + +Set on `POST /api/v1/tasks` in the `source_strategy` field. + +| Value | Behaviour | +|------------------|--------------------------------------------------------------------------| +| `auto_balance` | Default. Probe all enabled sources, allocate files/chunks by speed. | +| `fastest_only` | Probe all sources, use only the single fastest. | +| `pin_huggingface`| Skip probe; download everything from HuggingFace only. | +| `pin_modelscope` | Skip probe; use ModelScope only (resolver rules applied automatically). | +| `list:a,b` | Use only the listed source IDs (comma-separated); probe between them. | + +Sources listed in `source_blacklist` (array of source IDs) are always excluded +regardless of strategy. + +--- + +## 5. SHA256 Authority Rules (INVARIANTS 11/12/13) + +### INVARIANT 11 — HF is the authoritative SHA256 source + +All files downloaded from any source must be verified against the SHA256 value +that HuggingFace provides in its LFS manifest. No other source's self-reported +SHA256 is accepted as truth. + +### INVARIANT 12 — Cross-source verification is mandatory + +After completing a download (single-source or chunk-level multi-source), the +controller compares the actual file SHA256 against the HF-supplied value. +A mismatch triggers a `(source_id, repo_id, filename)` blacklist for +`sha_mismatch_blacklist_hours` (default 24 h). Subsequent subtasks for that +combination fall back to HuggingFace. + +### INVARIANT 13 — HF unavailable → task paused unless `trust_non_hf_sha256` + +When HuggingFace is unreachable and the task was created with the default +`trust_non_hf_sha256: false`: + +- The task transitions to `paused_external` with error code `no_sha256_authority`. +- No bytes are downloaded from alternative sources because integrity cannot be + guaranteed. + +Set `trust_non_hf_sha256: true` on the task to opt out of this guarantee and +allow downloads to proceed using other sources' self-reported checksums. + +**Special case**: if a file has no SHA256 pinned in HF's manifest at all (rare, +typically raw text files), it is always routed exclusively through HuggingFace +regardless of strategy. + +--- + +## 6. Scheduling and Rebalancing — Leader-Gated + +The scheduling phase (task `pending` → `scheduling` → `downloading`) and the +background rebalancer both run exclusively on the **active controller** (the +current Raft/leader-election winner). Standby controllers do not run probes or +mutate source assignments. + +Task state `scheduling` is transient: it covers the probe window +(`probe_timeout_s`) plus assignment computation. If the controller loses +leadership during scheduling the task reverts to `pending` and is picked up by +the new leader. + +The rebalancer (interval: `rebalance_interval_seconds`) re-probes sources for +tasks whose in-flight speed deviates significantly from the initial probe, and +may reassign future chunks. It does not interrupt chunks already in flight. + +Note: per-executor probing (where each executor independently probes its local +sources) is deferred to v2.1. + +--- + +## 7. Deferred to v2.1 + +The following capabilities are scoped out of v2.0 and will ship in v2.1: + +- `wisemodel` and `opencsg` source drivers +- `s3_mirror` driver and per-task `s3_direct_source` (schema reserved) +- Plugin-based source driver API +- Per-executor probing (executors report their own speed matrix independently) +- Automatic 5xx / health-check triggered source blacklist transitions +- Source cost accounting UI and budget enforcement diff --git a/docs/superpowers/plans/2026-05-19-phase-3-sp2-multi-source.md b/docs/superpowers/plans/2026-05-19-phase-3-sp2-multi-source.md new file mode 100644 index 0000000..e1a500a --- /dev/null +++ b/docs/superpowers/plans/2026-05-19-phase-3-sp2-multi-source.md @@ -0,0 +1,2510 @@ +# Phase 3 SP2 — Multi-Source Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax. + +**Goal:** Download a task from multiple mirror sources (HF / hf-mirror / ModelScope) in parallel, picking the fastest combination per fleet, with HuggingFace as the cryptographic source of truth. + +**Architecture:** A `SourceDriver` Protocol (`dlw/sources/`) with 3 drivers behind a `sources.yaml` registry + a `NameResolver`; a leader-gated scheduling loop runs `plan_task_sources` (resolve → speed-probe → LPT/chunk plan → persist `file_subtasks.source_id`/`subtask_chunks`); a generalized `/api/v1/source-proxy` streams each subtask/chunk from its assigned source's driver with controller-side creds; HF sha256 authority + 24h blacklist on mismatch; a minimal leader-gated rebalance loop reassigns a degraded source's pending chunks. + +**Tech Stack:** FastAPI + SQLAlchemy 2 async + asyncpg, `huggingface_hub` (HF + hf_mirror), raw `httpx` (ModelScope/proxy/probe), `pyyaml` (new), SP1's `Principal`/`require_perm`/`tenant_filtered`/casbin, leader-gated-loop pattern from SP1 `_quota_loop`. + +**Spec:** `docs/superpowers/specs/2026-05-19-phase-3-sp2-multi-source-design.md`. **Branch:** `feat/phase-3-sp2-multi-source` (off `main` @ `fa08e6d`, spec committed `ccdb9e8`). + +**Conventions (verified against the codebase — follow exactly):** +- Bash tool errors on `cd ` (Windows); working dir is already `D:\download_weights`. Use `uv run pytest ...` / `uv run alembic ...`. +- ORM models register in `src/dlw/db/models/__init__.py` (imports + sorted `__all__`). **Never** add model imports to `src/dlw/db/base.py` (circular import; tables won't create). +- `tools/lint_invariants.py` `VALID_TASK_STATUS` ALREADY contains `"scheduling"` (line 95) — no edit needed. It AST-scans only `src/dlw/api/tasks.py`, `services/task_service.py`, `services/scheduler.py` for task/subtask status literals. **Chunk-status literals (`pending|downloading|done|failed`) MUST live only in `source_scheduler.py`/`source_proxy.py`/`source_blacklist.py` (NOT scanned). Never put a chunk-status string literal into the 3 scanned files.** +- DB tests: `@pytest.mark.slow`, use the session `engine` fixture; **every DB fixture does `drop_all` → `create_all`** (session-DB collision avoidance — SP1 lesson) and is function-scoped unless module sharing is safe. `import dlw.db.models # noqa: F401` before any `Base.metadata.create_all` so all tables register. +- New test dirs need an empty `__init__.py` (`tests/sources/` is new). +- API tests build the app via `tests.conftest.make_app_with_state(ephemeral_ca, enrollment_token="e")` (seeds `app.state.settings`/`casbin`; SP2 extends it to also seed `source_registry`/`name_resolver`). System-JWT auth via `tests.conftest.principal_headers(secret="unit-secret", role="tenant_admin")` with an autouse fixture setting `DLW_SYSTEM_JWT_SECRET="unit-secret"` + `get_settings.cache_clear()`. +- Real CI gates (no ruff/mypy/code-vs-yaml CI): `pytest` (`uv sync --all-groups`, uv 0.11.9), `invariant_lint` (`uv run python -m pytest tools/test_lint_invariants.py` + `python tools/lint_invariants.py` + `python tools/lint_no_direct_status_write.py`), `openapi` (`spectral lint api/openapi.yaml --fail-severity=error` + `swagger-cli validate api/openapi.yaml`), `yamllint` (`deploy/ api/` — `config/*.yaml` NOT scanned but keep valid). New deps → edit `pyproject.toml`, `uv lock`, commit `uv.lock`. +- Service layer does NOT commit; caller commits (matches scheduler/quota). Run only each task's named tests; the full suite goes red between tasks until M4 wiring lands (expected) — controller runs the milestone E2E. + +--- + +## File Structure + +| File | Responsibility | +|---|---| +| `src/dlw/sources/__init__.py` (new) | package marker | +| `src/dlw/sources/base.py` (new) | `SourceDriver` Protocol + `SourceManifest`/`SourceFile`/`SourceHealth`/`SourceToken` | +| `src/dlw/sources/huggingface.py` (new) | HF driver (wraps `hf_metadata`) | +| `src/dlw/sources/hf_mirror.py` (new) | hf-mirror driver (HF-compat, no token, gated→skip) | +| `src/dlw/sources/modelscope.py` (new) | ModelScope driver (own API, no sha256) | +| `src/dlw/sources/registry.py` (new) | parse `sources.yaml`, build enabled `{id: driver}` | +| `src/dlw/sources/name_resolver.py` (new) | 3-tier name mapping from `resolver-rules.yaml` | +| `config/sources.yaml`, `config/resolver-rules.yaml` (new) | source + resolver config | +| `src/dlw/db/models/source.py` (new) | `SubtaskChunk`, `SourceSpeedSample`, `SourceBlacklist` | +| `src/dlw/services/source_speed.py` (new) | probe matrix + EWMA fusion | +| `src/dlw/services/source_combo.py` (new) | LPT assign + optimal-combo | +| `src/dlw/services/source_scheduler.py` (new) | `plan_task_sources` (resolve→probe→assign→persist) | +| `src/dlw/services/source_blacklist.py` (new) | blacklist transitions/queries | +| `src/dlw/api/source_proxy.py` (new) | `/api/v1/source-proxy/subtask/{id}` | +| `src/dlw/alembic/versions/_p3sp2_multi_source.py` (new) | cols + 3 tables | +| `src/dlw/executor/client.py` (modify) | add `stream_source` | +| `src/dlw/executor/chunk_downloader.py` (modify) | use `stream_source` | +| `src/dlw/services/scheduler.py` (modify) | sha256-authority gate on report | +| `src/dlw/services/task_service.py` (modify) | accept new TaskCreate fields | +| `src/dlw/schemas/task.py` (modify) | `source_strategy`/`source_blacklist`/`trust_non_hf_sha256` | +| `src/dlw/main.py` (modify) | lifespan registry/resolver + `_scheduling_loop` + `_rebalance_loop` | +| `src/dlw/config.py` (modify) | SP2 settings | +| `tests/conftest.py` (modify) | extend `make_app_with_state` (registry/resolver) | +| `docs/operator/multi-source.md` (new) | operator guide | + +--- + +# Milestone M1 — Source Layer + +### Task 1: Config additions + `pyyaml` dep + +**Files:** Modify `src/dlw/config.py`, `pyproject.toml`; Test (extend) `tests/test_config.py` + +- [ ] **Step 1: Write the failing test** — append to `tests/test_config.py`: +```python +def test_sp2_source_settings_defaults(): + from dlw.config import get_settings + get_settings.cache_clear() + s = get_settings() + assert s.sources_yaml_path == "config/sources.yaml" + assert s.resolver_rules_path == "config/resolver-rules.yaml" + assert s.probe_size_mb == 32 + assert s.probe_timeout_s == 8.0 + assert s.chunk_level_min_file_mb == 100 + assert s.speed_ewma_alpha == 0.3 + assert s.sha_mismatch_blacklist_hours == 24 + assert s.rebalance_interval_seconds == 60.0 + get_settings.cache_clear() +``` + +- [ ] **Step 2: Run** `uv run pytest tests/test_config.py::test_sp2_source_settings_defaults -v` → FAIL (no attrs). + +- [ ] **Step 3: Implement** — in `src/dlw/config.py`, after the Phase 3 SP1 block (after `auth_tenant_rules_json`), add: +```python + # Phase 3 SP2 — multi-source + sources_yaml_path: str = Field(default="config/sources.yaml") + resolver_rules_path: str = Field(default="config/resolver-rules.yaml") + probe_size_mb: int = Field(default=32, ge=1, le=256) + probe_timeout_s: float = Field(default=8.0, ge=1.0, le=60.0) + probe_history_weight: float = Field(default=0.3, ge=0.0, le=1.0) + combo_overhead_per_source_pct: float = Field(default=2.0, ge=0.0, le=50.0) + chunk_level_min_file_mb: int = Field(default=100, ge=1) + speed_ewma_alpha: float = Field(default=0.3, ge=0.0, le=1.0) + blacklist_5xx_count: int = Field(default=3, ge=1) + blacklist_minutes: int = Field(default=5, ge=1) + blacklist_max_minutes: int = Field(default=30, ge=1) + sha_mismatch_blacklist_hours: int = Field(default=24, ge=1) + rebalance_interval_seconds: float = Field(default=60.0, ge=5.0, le=600.0) + degradation_trigger_threshold: float = Field(default=0.3, ge=0.0, le=1.0) +``` +Add `"pyyaml>=6,<7"` to `pyproject.toml` `[project] dependencies` (after `pydantic-settings`). Run `uv lock` then `uv sync --all-groups`. NOTE: `pyyaml` is already transitively locked (via `huggingface_hub`/`uvicorn[standard]`), so `import yaml` works today and the `uv.lock` diff will be near-empty — this step promotes it to a *direct* dependency for correctness; a tiny/empty `uv.lock` diff is expected, not an error. + +- [ ] **Step 4: Run** `uv run pytest tests/test_config.py -v` → all PASS. + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/config.py pyproject.toml uv.lock tests/test_config.py +git commit -m "feat(sp2): multi-source config + pyyaml dep" +``` + +--- + +### Task 2: `SourceDriver` Protocol + dataclasses + +**Files:** Create `src/dlw/sources/__init__.py`, `src/dlw/sources/base.py`; Test `tests/sources/__init__.py`, `tests/sources/test_base.py` + +- [ ] **Step 1: Write the failing test** — create empty `tests/sources/__init__.py`, then `tests/sources/test_base.py`: +```python +"""SourceDriver Protocol + dataclasses (Phase 3 SP2).""" +from __future__ import annotations + +from dlw.sources.base import ( + SourceFile, + SourceHealth, + SourceManifest, + SourceToken, +) + + +def test_sourcefile_defaults(): + f = SourceFile(filename="model.safetensors", size=10, sha256=None, + download_ref="r") + assert f.filename == "model.safetensors" and f.sha256 is None + + +def test_manifest_holds_files(): + m = SourceManifest(source_id="huggingface", repo_id_in_source="o/r", + revision_in_source="abc", files=[ + SourceFile("a", 1, "x" * 64, "ref")], + has_lfs_sha256=True) + assert m.source_id == "huggingface" and len(m.files) == 1 + + +def test_health_and_token(): + assert SourceHealth(ok=True, latency_ms=12.0).ok is True + t = SourceToken(scheme="bearer", value="secret") + assert t.value == "secret" and "secret" not in repr(t) +``` + +- [ ] **Step 2: Run** `uv run pytest tests/sources/test_base.py -v` → FAIL (module missing). + +- [ ] **Step 3: Implement** — create empty `src/dlw/sources/__init__.py`, then `src/dlw/sources/base.py`: +```python +"""SourceDriver abstraction (Phase 3 SP2; design doc 06 §1.3).""" +from __future__ import annotations + +from collections.abc import AsyncIterator +from dataclasses import dataclass, field +from decimal import Decimal +from typing import Protocol, runtime_checkable + + +@dataclass(frozen=True) +class SourceFile: + filename: str # normalized HF-style path (cross-source key) + size: int | None + sha256: str | None # only HF / hf_mirror populate this + download_ref: str # source-specific URL or object key + + +@dataclass(frozen=True) +class SourceManifest: + source_id: str + repo_id_in_source: str + revision_in_source: str + files: list[SourceFile] + has_lfs_sha256: bool + + +@dataclass(frozen=True) +class SourceHealth: + ok: bool + latency_ms: float + + +@dataclass(frozen=True) +class SourceToken: + scheme: str # "bearer" | "none" + value: str = field(default="", repr=False) # never in repr/logs (INV 2) + + +@runtime_checkable +class SourceDriver(Protocol): + id: str + domain: str + provides_sha256: bool + + async def resolve( + self, repo_id: str, revision: str + ) -> SourceManifest | None: ... + + def download_url(self, file: SourceFile) -> str: ... + + def auth_token(self, tenant_hf_token: str | None) -> SourceToken: ... + + async def health_check(self) -> SourceHealth: ... + + def estimate_cost(self, n_bytes: int, region: str) -> Decimal: ... +``` +(Design note: `download_range` from the spec is realized as `download_url(file)` + the proxy issuing the ranged GET — keeps drivers pure/sync-URL and centralizes streaming/retry in `source_proxy.py`. `auth_token` returns the controller-side cred the proxy injects; the executor never sees it.) + +- [ ] **Step 4: Run** `uv run pytest tests/sources/test_base.py -v` → 3 PASS. + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/sources/__init__.py src/dlw/sources/base.py tests/sources/ +git commit -m "feat(sp2): SourceDriver Protocol + manifest/file/token dataclasses" +``` + +--- + +### Task 3: HuggingFace + hf_mirror drivers + +**Files:** Create `src/dlw/sources/huggingface.py`, `src/dlw/sources/hf_mirror.py`; Test `tests/sources/test_hf_drivers.py` + +- [ ] **Step 1: Write the failing test** — `tests/sources/test_hf_drivers.py`: +```python +"""HF + hf_mirror drivers (Phase 3 SP2).""" +from __future__ import annotations + +import pytest + +from dlw.services.hf_metadata import RepoFile +from dlw.sources.hf_mirror import HfMirrorDriver +from dlw.sources.huggingface import HuggingFaceDriver + + +@pytest.fixture +def _patch_list(monkeypatch): + async def fake(repo_id, revision, *, hf_endpoint, hf_token): + assert revision == "abc" + return [RepoFile(path="model.safetensors", size=64, sha256="a" * 64), + RepoFile(path="config.json", size=4, sha256=None)] + monkeypatch.setattr("dlw.sources.huggingface.list_repo_tree", fake) + monkeypatch.setattr("dlw.sources.hf_mirror.list_repo_tree", fake) + + +async def test_hf_resolve(_patch_list): + d = HuggingFaceDriver(base_url="https://huggingface.co", hf_token="tok") + m = await d.resolve("o/r", "abc") + assert m is not None + assert m.source_id == "huggingface" and m.has_lfs_sha256 is True + assert {f.filename for f in m.files} == {"model.safetensors", "config.json"} + assert d.provides_sha256 is True + assert d.download_url(m.files[0]).endswith( + "/o/r/resolve/abc/model.safetensors") + assert d.auth_token("tok").value == "tok" + + +async def test_hf_mirror_no_token_and_base(_patch_list): + d = HfMirrorDriver(base_url="https://hf-mirror.com") + m = await d.resolve("o/r", "abc") + assert m.source_id == "hf_mirror" + assert d.download_url(m.files[0]).startswith("https://hf-mirror.com/") + assert d.auth_token("tok").scheme == "none" + + +async def test_hf_mirror_gated_returns_none(monkeypatch): + from dlw.services.hf_metadata import HfPrivateOrAuthRequired + + async def gated(*a, **k): + raise HfPrivateOrAuthRequired("gated") + monkeypatch.setattr("dlw.sources.hf_mirror.list_repo_tree", gated) + d = HfMirrorDriver(base_url="https://hf-mirror.com") + assert await d.resolve("o/gated", "abc") is None +``` + +- [ ] **Step 2: Run** `uv run pytest tests/sources/test_hf_drivers.py -v` → FAIL. + +- [ ] **Step 3: Implement** — `src/dlw/sources/huggingface.py`: +```python +"""HuggingFace SourceDriver — wraps the existing hf_metadata path (SP2).""" +from __future__ import annotations + +from decimal import Decimal + +from dlw.services.hf_metadata import ( + HfNetworkError, + HfPrivateOrAuthRequired, + RepoNotFound, + list_repo_tree, +) +from dlw.sources.base import ( + SourceFile, + SourceHealth, + SourceManifest, + SourceToken, +) + + +class HuggingFaceDriver: + id = "huggingface" + domain = "huggingface.co" + provides_sha256 = True + + def __init__(self, *, base_url: str, hf_token: str | None) -> None: + self._base = base_url.rstrip("/") + self._token = hf_token + + async def resolve( + self, repo_id: str, revision: str + ) -> SourceManifest | None: + try: + files = await list_repo_tree( + repo_id, revision, + hf_endpoint=self._base, hf_token=self._token) + except (RepoNotFound,): + return None + except (HfPrivateOrAuthRequired, HfNetworkError): + raise + sf = [SourceFile(filename=f.path, size=f.size, sha256=f.sha256, + download_ref=f"{repo_id}/resolve/{revision}/{f.path}") + for f in files] + return SourceManifest( + source_id=self.id, repo_id_in_source=repo_id, + revision_in_source=revision, files=sf, + has_lfs_sha256=any(f.sha256 for f in sf)) + + def download_url(self, file: SourceFile) -> str: + return f"{self._base}/{file.download_ref}" + + def auth_token(self, tenant_hf_token: str | None) -> SourceToken: + tok = tenant_hf_token or self._token + return (SourceToken(scheme="bearer", value=tok) if tok + else SourceToken(scheme="none")) + + async def health_check(self) -> SourceHealth: + return SourceHealth(ok=True, latency_ms=0.0) + + def estimate_cost(self, n_bytes: int, region: str) -> Decimal: + return Decimal("0.09") * Decimal(n_bytes) / Decimal(1_000_000_000) +``` +`src/dlw/sources/hf_mirror.py`: +```python +"""hf-mirror.com SourceDriver — HF-compatible, no token, gated→skip (SP2).""" +from __future__ import annotations + +from decimal import Decimal + +from dlw.services.hf_metadata import ( + HfNetworkError, + HfPrivateOrAuthRequired, + RepoNotFound, + list_repo_tree, +) +from dlw.sources.base import ( + SourceFile, + SourceHealth, + SourceManifest, + SourceToken, +) + + +class HfMirrorDriver: + id = "hf_mirror" + domain = "hf-mirror.com" + provides_sha256 = True + + def __init__(self, *, base_url: str) -> None: + self._base = base_url.rstrip("/") + + async def resolve( + self, repo_id: str, revision: str + ) -> SourceManifest | None: + try: + files = await list_repo_tree( + repo_id, revision, hf_endpoint=self._base, hf_token=None) + except RepoNotFound: + return None + except HfPrivateOrAuthRequired: + return None # gated: public mirror can't serve it — skip + except HfNetworkError: + raise + sf = [SourceFile(filename=f.path, size=f.size, sha256=f.sha256, + download_ref=f"{repo_id}/resolve/{revision}/{f.path}") + for f in files] + return SourceManifest( + source_id=self.id, repo_id_in_source=repo_id, + revision_in_source=revision, files=sf, + has_lfs_sha256=any(f.sha256 for f in sf)) + + def download_url(self, file: SourceFile) -> str: + return f"{self._base}/{file.download_ref}" + + def auth_token(self, tenant_hf_token: str | None) -> SourceToken: + return SourceToken(scheme="none") + + async def health_check(self) -> SourceHealth: + return SourceHealth(ok=True, latency_ms=0.0) + + def estimate_cost(self, n_bytes: int, region: str) -> Decimal: + return Decimal(0) +``` + +- [ ] **Step 4: Run** `uv run pytest tests/sources/test_hf_drivers.py -v` → 3 PASS. + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/sources/huggingface.py src/dlw/sources/hf_mirror.py tests/sources/test_hf_drivers.py +git commit -m "feat(sp2): HuggingFace + hf_mirror SourceDrivers" +``` + +--- + +### Task 4: ModelScope driver + +**Files:** Create `src/dlw/sources/modelscope.py`; Test `tests/sources/test_modelscope_driver.py` + +- [ ] **Step 1: Write the failing test** — `tests/sources/test_modelscope_driver.py`: +```python +"""ModelScope driver (Phase 3 SP2).""" +from __future__ import annotations + +import httpx +import pytest + +from dlw.sources.modelscope import ModelScopeDriver + + +def _handler(request: httpx.Request) -> httpx.Response: + assert "modelscope.cn" in str(request.url) + if "/repo?Revision=" in str(request.url) and "FilePath" not in str(request.url): + return httpx.Response(200, json={"Data": {"Files": [ + {"Path": "model.safetensors", "Size": 64}, + {"Path": "config.json", "Size": 4}]}}) + return httpx.Response(404) + + +@pytest.fixture +def _drv(): + return ModelScopeDriver( + base_url="https://www.modelscope.cn", + transport=httpx.MockTransport(_handler)) + + +async def test_modelscope_resolve_no_sha(_drv): + m = await _drv.resolve("qwen/Qwen3-7B", "v1") + assert m is not None + assert m.source_id == "modelscope" and m.has_lfs_sha256 is False + assert all(f.sha256 is None for f in m.files) + assert {f.filename for f in m.files} == {"model.safetensors", "config.json"} + assert _drv.provides_sha256 is False + + +async def test_modelscope_download_url(_drv): + m = await _drv.resolve("qwen/Qwen3-7B", "v1") + url = _drv.download_url(m.files[0]) + assert "FilePath=model.safetensors" in url and "Revision=v1" in url + + +async def test_modelscope_missing_repo_returns_none(): + d = ModelScopeDriver( + base_url="https://www.modelscope.cn", + transport=httpx.MockTransport(lambda r: httpx.Response(404))) + assert await d.resolve("no/such", "v1") is None +``` + +- [ ] **Step 2: Run** `uv run pytest tests/sources/test_modelscope_driver.py -v` → FAIL. + +- [ ] **Step 3: Implement** — `src/dlw/sources/modelscope.py`: +```python +"""ModelScope SourceDriver — raw httpx, no official sha256 (SP2; doc §1.9.3).""" +from __future__ import annotations + +from decimal import Decimal +from urllib.parse import quote + +import httpx + +from dlw.sources.base import ( + SourceFile, + SourceHealth, + SourceManifest, + SourceToken, +) + + +class ModelScopeDriver: + id = "modelscope" + domain = "modelscope.cn" + provides_sha256 = False + + def __init__(self, *, base_url: str, + transport: httpx.AsyncBaseTransport | None = None) -> None: + self._base = base_url.rstrip("/") + self._transport = transport + + def _client(self) -> httpx.AsyncClient: + return httpx.AsyncClient(timeout=30, transport=self._transport) + + async def resolve( + self, repo_id: str, revision: str + ) -> SourceManifest | None: + url = f"{self._base}/api/v1/models/{repo_id}/repo?Revision={revision}" + async with self._client() as c: + r = await c.get(url) + if r.status_code == 404: + return None + r.raise_for_status() + data = r.json().get("Data", {}).get("Files", []) + sf = [SourceFile(filename=d["Path"], size=d.get("Size"), + sha256=None, + download_ref=f"{repo_id}|{revision}|{d['Path']}") + for d in data] + return SourceManifest( + source_id=self.id, repo_id_in_source=repo_id, + revision_in_source=revision, files=sf, has_lfs_sha256=False) + + def download_url(self, file: SourceFile) -> str: + repo, rev, path = file.download_ref.split("|", 2) + return (f"{self._base}/api/v1/models/{repo}/repo" + f"?Revision={rev}&FilePath={quote(path)}") + + def auth_token(self, tenant_hf_token: str | None) -> SourceToken: + return SourceToken(scheme="none") + + async def health_check(self) -> SourceHealth: + return SourceHealth(ok=True, latency_ms=0.0) + + def estimate_cost(self, n_bytes: int, region: str) -> Decimal: + return Decimal(0) +``` + +- [ ] **Step 4: Run** `uv run pytest tests/sources/test_modelscope_driver.py -v` → 3 PASS. + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/sources/modelscope.py tests/sources/test_modelscope_driver.py +git commit -m "feat(sp2): ModelScope SourceDriver" +``` + +--- + +### Task 5: Registry + `sources.yaml` + +**Files:** Create `src/dlw/sources/registry.py`, `config/sources.yaml`; Test `tests/sources/test_registry.py` + +- [ ] **Step 1: Write the failing test** — `tests/sources/test_registry.py`: +```python +"""Source registry from sources.yaml (Phase 3 SP2).""" +from __future__ import annotations + +from dlw.sources.registry import load_registry + +_YAML = """ +sources: + - id: huggingface + enabled: true + driver: huggingface + config: {base_url: https://huggingface.co} + - id: hf_mirror + enabled: true + driver: hf_mirror + config: {base_url: https://hf-mirror.com} + - id: modelscope + enabled: false + driver: modelscope + config: {base_url: https://www.modelscope.cn} + - id: corp + enabled: true + driver: s3_mirror + config: {} +regional_defaults: + cn-north: [hf_mirror, modelscope, huggingface] +""" + + +def test_only_enabled_supported(tmp_path): + p = tmp_path / "s.yaml" + p.write_text(_YAML, encoding="utf-8") + reg = load_registry(str(p), hf_token="tk") + assert set(reg.enabled_ids()) == {"huggingface", "hf_mirror"} # ms off, s3 unsupported + assert reg.get("huggingface").id == "huggingface" + assert reg.get("missing") is None + assert reg.regional_defaults["cn-north"][0] == "hf_mirror" + + +def test_modelscope_enabled(tmp_path): + p = tmp_path / "s.yaml" + p.write_text(_YAML.replace("id: modelscope\n enabled: false", + "id: modelscope\n enabled: true"), + encoding="utf-8") + reg = load_registry(str(p), hf_token=None) + assert "modelscope" in reg.enabled_ids() +``` + +- [ ] **Step 2: Run** `uv run pytest tests/sources/test_registry.py -v` → FAIL. + +- [ ] **Step 3: Implement** — `src/dlw/sources/registry.py`: +```python +"""sources.yaml → enabled SourceDriver registry (Phase 3 SP2).""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import yaml + +from dlw.sources.base import SourceDriver +from dlw.sources.hf_mirror import HfMirrorDriver +from dlw.sources.huggingface import HuggingFaceDriver +from dlw.sources.modelscope import ModelScopeDriver + +_SUPPORTED = {"huggingface", "hf_mirror", "modelscope"} + + +@dataclass +class SourceRegistry: + _drivers: dict[str, SourceDriver] + regional_defaults: dict[str, list[str]] = field(default_factory=dict) + + def enabled_ids(self) -> list[str]: + return list(self._drivers.keys()) + + def get(self, source_id: str) -> SourceDriver | None: + return self._drivers.get(source_id) + + +def _build(driver: str, cfg: dict[str, Any], + hf_token: str | None) -> SourceDriver | None: + if driver == "huggingface": + return HuggingFaceDriver( + base_url=cfg.get("base_url", "https://huggingface.co"), + hf_token=hf_token) + if driver == "hf_mirror": + return HfMirrorDriver( + base_url=cfg.get("base_url", "https://hf-mirror.com")) + if driver == "modelscope": + return ModelScopeDriver( + base_url=cfg.get("base_url", "https://www.modelscope.cn")) + return None # unsupported driver (s3_mirror/wisemodel/...) — skip + + +def load_registry(path: str, *, hf_token: str | None) -> SourceRegistry: + with open(path, encoding="utf-8") as fh: + doc = yaml.safe_load(fh) or {} + drivers: dict[str, SourceDriver] = {} + for entry in doc.get("sources", []): + if not entry.get("enabled"): + continue + if entry.get("driver") not in _SUPPORTED: + continue + d = _build(entry["driver"], entry.get("config") or {}, hf_token) + if d is not None: + drivers[entry["id"]] = d + return SourceRegistry(_drivers=drivers, + regional_defaults=doc.get("regional_defaults", {})) +``` +Create `config/sources.yaml`: +```yaml +sources: + - id: huggingface + enabled: true + driver: huggingface + config: {base_url: "https://huggingface.co", timeout_seconds: 30} + cost_per_gb_egress: 0.09 + - id: hf_mirror + enabled: true + driver: hf_mirror + config: {base_url: "https://hf-mirror.com", timeout_seconds: 30} + cost_per_gb_egress: 0.0 + - id: modelscope + enabled: true + driver: modelscope + config: {base_url: "https://www.modelscope.cn", timeout_seconds: 30} + cost_per_gb_egress: 0.0 +balancing: + speed_ewma_alpha: 0.3 + chunk_level_min_file_mb: 100 +regional_defaults: + cn-north: ["hf_mirror", "modelscope", "huggingface"] + us-east: ["huggingface"] +``` + +- [ ] **Step 4: Run** `uv run pytest tests/sources/test_registry.py -v` → 2 PASS. + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/sources/registry.py config/sources.yaml tests/sources/test_registry.py +git commit -m "feat(sp2): source registry + sources.yaml" +``` + +--- + +### Task 6: NameResolver + `resolver-rules.yaml` + +**Files:** Create `src/dlw/sources/name_resolver.py`, `config/resolver-rules.yaml`; Test `tests/sources/test_name_resolver.py` + +- [ ] **Step 1: Write the failing test** — `tests/sources/test_name_resolver.py`: +```python +"""NameResolver 3-tier (Phase 3 SP2; doc §1.5).""" +from __future__ import annotations + +from dlw.sources.name_resolver import NameResolver + +_RULES = """ +identity_organizations: [deepseek-ai, Qwen, THUDM] +aliases: + - hf_org: meta-llama + modelscope_org: LLM-Research + transform: "Meta-{name}" +per_model_overrides: + - hf: "weird-org/weird-model" + modelscope: "diff-org/diff-name" +""" + + +def _r(tmp_path): + p = tmp_path / "rr.yaml" + p.write_text(_RULES, encoding="utf-8") + return NameResolver.from_file(str(p)) + + +def test_huggingface_is_always_identity(tmp_path): + r = _r(tmp_path) + assert r.resolve("huggingface", "any-org/any-model") == "any-org/any-model" + + +def test_identity_org(tmp_path): + r = _r(tmp_path) + assert r.resolve("modelscope", "deepseek-ai/DeepSeek-V3") == "deepseek-ai/DeepSeek-V3" + + +def test_alias_transform(tmp_path): + r = _r(tmp_path) + assert r.resolve("modelscope", "meta-llama/Llama-3.1-8B") == "LLM-Research/Meta-Llama-3.1-8B" + + +def test_per_model_override(tmp_path): + r = _r(tmp_path) + assert r.resolve("modelscope", "weird-org/weird-model") == "diff-org/diff-name" + + +def test_unknown_returns_none(tmp_path): + r = _r(tmp_path) + assert r.resolve("modelscope", "rando-org/rando-model") is None +``` + +- [ ] **Step 2: Run** `uv run pytest tests/sources/test_name_resolver.py -v` → FAIL. + +- [ ] **Step 3: Implement** — `src/dlw/sources/name_resolver.py`: +```python +"""3-tier source name resolution (Phase 3 SP2; doc §1.5). + +Tier 1 identity (HF, or org in identity_organizations); tier 2 alias / +per-model rules from resolver-rules.yaml; tier 3 source search-API (deferred +to a stub that returns None — wiring point for v2.1; cache scaffold present).""" +from __future__ import annotations + +from dataclasses import dataclass + +import yaml + + +@dataclass +class _Alias: + hf_org: str + ms_org: str + transform: str # e.g. "Meta-{name}" + + +class NameResolver: + def __init__(self, *, identity_orgs: set[str], aliases: list[_Alias], + overrides: dict[str, str]) -> None: + self._identity = identity_orgs + self._aliases = {a.hf_org: a for a in aliases} + self._overrides = overrides # "hf_repo" -> "src_repo" + self._search_cache: dict[tuple[str, str], str] = {} + + @classmethod + def from_file(cls, path: str) -> NameResolver: + with open(path, encoding="utf-8") as fh: + doc = yaml.safe_load(fh) or {} + aliases = [_Alias(a["hf_org"], a["modelscope_org"], a["transform"]) + for a in doc.get("aliases", [])] + overrides = {o["hf"]: o["modelscope"] + for o in doc.get("per_model_overrides", [])} + return cls(identity_orgs=set(doc.get("identity_organizations", [])), + aliases=aliases, overrides=overrides) + + def resolve(self, source_id: str, hf_repo_id: str) -> str | None: + if source_id == "huggingface" or source_id == "hf_mirror": + return hf_repo_id + if hf_repo_id in self._overrides: + return self._overrides[hf_repo_id] + org, _, name = hf_repo_id.partition("/") + if org in self._identity: + return hf_repo_id + a = self._aliases.get(org) + if a is not None: + return f"{a.ms_org}/{a.transform.format(name=name)}" + return self._search_cache.get((source_id, hf_repo_id)) # tier 3 stub +``` +Create `config/resolver-rules.yaml`: +```yaml +identity_organizations: + - deepseek-ai + - Qwen + - 01-ai + - THUDM + - baichuan-inc + - mistralai +aliases: + - hf_org: meta-llama + modelscope_org: LLM-Research + transform: "Meta-{name}" +per_model_overrides: [] +``` + +- [ ] **Step 4: Run** `uv run pytest tests/sources/test_name_resolver.py -v` → 5 PASS. + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/sources/name_resolver.py config/resolver-rules.yaml tests/sources/test_name_resolver.py +git commit -m "feat(sp2): NameResolver 3-tier + resolver-rules.yaml" +``` + +--- + +# Milestone M2 — Schema + Models + +### Task 7: Models + migration + +**Files:** Create `src/dlw/db/models/source.py`, migration; Modify `src/dlw/db/models/__init__.py`, `src/dlw/schemas/task.py`; Test `tests/db/test_p3sp2_migration.py` + +- [ ] **Step 1: Write the failing test** — `tests/db/test_p3sp2_migration.py`: +```python +"""SP2 migration: 3 tables + task/subtask source columns.""" +from __future__ import annotations + +import pytest +from sqlalchemy import text + +import dlw.db.models # noqa: F401 + +pytestmark = pytest.mark.slow + + +async def test_tables_and_columns(engine): + from dlw.db.base import Base + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + await conn.run_sync(Base.metadata.create_all) + names = {r[0] for r in await conn.execute(text( + "SELECT table_name FROM information_schema.tables " + "WHERE table_schema='public'"))} + assert {"subtask_chunks", "source_speed_samples", + "source_blacklist"} <= names + cols = {r[0] for r in await conn.execute(text( + "SELECT column_name FROM information_schema.columns " + "WHERE table_name='download_tasks'"))} + assert {"source_strategy", "source_blacklist", + "trust_non_hf_sha256"} <= cols + scols = {r[0] for r in await conn.execute(text( + "SELECT column_name FROM information_schema.columns " + "WHERE table_name='file_subtasks'"))} + assert {"source_id", "is_chunked"} <= scols + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) +``` + +- [ ] **Step 2: Run** `uv run pytest tests/db/test_p3sp2_migration.py -v` → FAIL. + +- [ ] **Step 3: Implement** — `src/dlw/db/models/source.py`: +```python +"""Multi-source models (Phase 3 SP2; doc 06 §1.4/§1.7).""" +from __future__ import annotations + +import uuid +from datetime import datetime + +from sqlalchemy import ( + BigInteger, + DateTime, + Float, + ForeignKey, + Integer, + String, + UniqueConstraint, + func, +) +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from dlw.db.base import Base + + +class SubtaskChunk(Base): + __tablename__ = "subtask_chunks" + __table_args__ = (UniqueConstraint("subtask_id", "chunk_index"),) + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + subtask_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("file_subtasks.id", ondelete="CASCADE"), nullable=False) + chunk_index: Mapped[int] = mapped_column(Integer, nullable=False) + byte_start: Mapped[int] = mapped_column(BigInteger, nullable=False) + byte_end: Mapped[int] = mapped_column(BigInteger, nullable=False) + source_id: Mapped[str] = mapped_column(String(32), nullable=False) + status: Mapped[str] = mapped_column(String(16), nullable=False) + sha256_partial: Mapped[str | None] = mapped_column(String(64), nullable=True) + bytes_done: Mapped[int] = mapped_column(BigInteger, default=0, + nullable=False) + + +class SourceSpeedSample(Base): + __tablename__ = "source_speed_samples" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + executor_id: Mapped[str] = mapped_column(String(64), nullable=False) + source_id: Mapped[str] = mapped_column(String(32), nullable=False) + measured_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), nullable=False) + bytes_per_sec: Mapped[float] = mapped_column(Float, nullable=False) + sample_size: Mapped[int] = mapped_column(BigInteger, nullable=False) + is_active_probe: Mapped[bool] = mapped_column(default=False, + nullable=False) + + +class SourceBlacklist(Base): + __tablename__ = "source_blacklist" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + source_id: Mapped[str] = mapped_column(String(32), nullable=False) + repo_id: Mapped[str | None] = mapped_column(String(256), nullable=True) + filename: Mapped[str | None] = mapped_column(String(512), nullable=True) + until: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False) + reason: Mapped[str] = mapped_column(String(64), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), nullable=False) +``` +In `src/dlw/db/models/__init__.py` add `from dlw.db.models.source import SourceBlacklist, SourceSpeedSample, SubtaskChunk` and add `"SourceBlacklist", "SourceSpeedSample", "SubtaskChunk"` to `__all__` (keep sorted). In `src/dlw/schemas/task.py` `TaskCreate` add fields: +```python + source_strategy: str = Field(default="auto_balance", max_length=32) + source_blacklist: list[str] = Field(default_factory=list) + trust_non_hf_sha256: bool = Field(default=False) +``` +Generate migration: `uv run alembic revision -m "p3sp2 multi source"`. Set `down_revision = "a4bed702cdb3"`, imports `import sqlalchemy as sa`, `from alembic import op`, `from sqlalchemy.dialects import postgresql`. Body: +```python +def upgrade() -> None: + op.add_column("download_tasks", sa.Column( + "source_strategy", sa.String(32), nullable=False, + server_default="auto_balance")) + op.add_column("download_tasks", sa.Column( + "source_blacklist", postgresql.JSONB(), nullable=False, + server_default="[]")) + op.add_column("download_tasks", sa.Column( + "trust_non_hf_sha256", sa.Boolean(), nullable=False, + server_default=sa.false())) + op.add_column("file_subtasks", sa.Column( + "source_id", sa.String(32), nullable=True)) + op.add_column("file_subtasks", sa.Column( + "is_chunked", sa.Boolean(), nullable=False, + server_default=sa.false())) + op.create_table( + "subtask_chunks", + sa.Column("id", sa.BigInteger(), primary_key=True), + sa.Column("subtask_id", postgresql.UUID(as_uuid=True), + sa.ForeignKey("file_subtasks.id", ondelete="CASCADE"), + nullable=False), + sa.Column("chunk_index", sa.Integer(), nullable=False), + sa.Column("byte_start", sa.BigInteger(), nullable=False), + sa.Column("byte_end", sa.BigInteger(), nullable=False), + sa.Column("source_id", sa.String(32), nullable=False), + sa.Column("status", sa.String(16), nullable=False), + sa.Column("sha256_partial", sa.String(64), nullable=True), + sa.Column("bytes_done", sa.BigInteger(), nullable=False, + server_default="0"), + sa.UniqueConstraint("subtask_id", "chunk_index"), + ) + op.create_index("idx_chunk_sub_status", "subtask_chunks", + ["subtask_id", "status"]) + op.create_table( + "source_speed_samples", + sa.Column("id", sa.BigInteger(), primary_key=True), + sa.Column("executor_id", sa.String(64), nullable=False), + sa.Column("source_id", sa.String(32), nullable=False), + sa.Column("measured_at", sa.DateTime(timezone=True), + server_default=sa.func.now(), nullable=False), + sa.Column("bytes_per_sec", sa.Float(), nullable=False), + sa.Column("sample_size", sa.BigInteger(), nullable=False), + sa.Column("is_active_probe", sa.Boolean(), nullable=False, + server_default=sa.false()), + ) + op.create_index("idx_speed_recent", "source_speed_samples", + ["executor_id", "source_id", "measured_at"]) + op.create_table( + "source_blacklist", + sa.Column("id", sa.BigInteger(), primary_key=True), + sa.Column("source_id", sa.String(32), nullable=False), + sa.Column("repo_id", sa.String(256), nullable=True), + sa.Column("filename", sa.String(512), nullable=True), + sa.Column("until", sa.DateTime(timezone=True), nullable=False), + sa.Column("reason", sa.String(64), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), + server_default=sa.func.now(), nullable=False), + ) + op.create_index("idx_blacklist_lookup", "source_blacklist", + ["source_id", "repo_id", "until"]) + + +def downgrade() -> None: + op.drop_index("idx_blacklist_lookup", "source_blacklist") + op.drop_table("source_blacklist") + op.drop_index("idx_speed_recent", "source_speed_samples") + op.drop_table("source_speed_samples") + op.drop_index("idx_chunk_sub_status", "subtask_chunks") + op.drop_table("subtask_chunks") + op.drop_column("file_subtasks", "is_chunked") + op.drop_column("file_subtasks", "source_id") + op.drop_column("download_tasks", "trust_non_hf_sha256") + op.drop_column("download_tasks", "source_blacklist") + op.drop_column("download_tasks", "source_strategy") +``` +Also add the matching SQLAlchemy columns to the existing models so `Base.metadata.create_all` (used by tests) builds them: in `src/dlw/db/models/task.py` `DownloadTask` add `source_strategy: Mapped[str] = mapped_column(String(32), default="auto_balance", nullable=False)`, `source_blacklist: Mapped[list] = mapped_column(JSONB, default=list, nullable=False)` (import `from sqlalchemy.dialects.postgresql import JSONB`), `trust_non_hf_sha256: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)`; in `FileSubTask` add `source_id: Mapped[str | None] = mapped_column(String(32), nullable=True)`, `is_chunked: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)`. (Use the imports already at the top of `task.py`; add `JSONB` if absent.) + +- [ ] **Step 4: Run** `uv run pytest tests/db/test_p3sp2_migration.py -v` (PASS), then `uv run alembic upgrade head && uv run alembic downgrade -1 && uv run alembic upgrade head` (clean). + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/db/models/source.py src/dlw/db/models/__init__.py src/dlw/db/models/task.py src/dlw/schemas/task.py src/dlw/alembic/versions/ tests/db/test_p3sp2_migration.py +git commit -m "feat(sp2): SubtaskChunk/SourceSpeedSample/SourceBlacklist models + migration" +``` + +--- + +# Milestone M3 — Planner + +### Task 8: Speed service (EWMA fusion) + +**Files:** Create `src/dlw/services/source_speed.py`; Test `tests/services/test_source_speed.py` + +Controller-side probe (per spec banner ruling 6d): the controller itself does one small ranged GET per source via the driver's `download_url`, times it, returns bytes/sec. Per-executor probe-through-proxy is deferred to v2.1. + +- [ ] **Step 1: Write the failing test** — `tests/services/test_source_speed.py`: +```python +"""Speed EWMA fusion + controller-side probe (Phase 3 SP2; doc §1.7/§1.8).""" +from __future__ import annotations + +import httpx +import pytest + +from dlw.services.source_speed import ( + fuse_ewma, + pick_probe_size_bytes, + probe_source_speed, +) +from dlw.sources.base import SourceFile + + +def test_fuse_no_history_uses_live(): + assert fuse_ewma(live=1000.0, hist=None, hist_weight=0.3) == 1000.0 + + +def test_fuse_blends(): + assert fuse_ewma(live=1000.0, hist=500.0, hist_weight=0.3) == 850.0 + + +def test_probe_size(): + assert pick_probe_size_bytes(probe_size_mb=32) == 32 * 1024 * 1024 + + +class _Drv: + def download_url(self, f): + return "https://src/x" + + def auth_token(self, t): + from dlw.sources.base import SourceToken + return SourceToken(scheme="none") + + +async def test_probe_returns_positive_speed(): + transport = httpx.MockTransport( + lambda r: httpx.Response(206, content=b"x" * 4096)) + bps = await probe_source_speed( + _Drv(), SourceFile("m", 4096, None, "ref"), + probe_bytes=4096, timeout_s=5.0, hf_token=None, transport=transport) + assert bps > 0.0 + + +async def test_probe_failure_returns_zero(): + def boom(r): + raise httpx.ConnectError("down") + bps = await probe_source_speed( + _Drv(), SourceFile("m", 4096, None, "ref"), + probe_bytes=4096, timeout_s=5.0, hf_token=None, + transport=httpx.MockTransport(boom)) + assert bps == 0.0 +``` + +- [ ] **Step 2: Run** `uv run pytest tests/services/test_source_speed.py -v` → FAIL. + +- [ ] **Step 3: Implement** — `src/dlw/services/source_speed.py`: +```python +"""Source speed: controller-side probe + EWMA fusion (Phase 3 SP2).""" +from __future__ import annotations + +import time +from typing import Any + +import httpx + + +def fuse_ewma(*, live: float, hist: float | None, + hist_weight: float) -> float: + if hist is None: + return live + return (1.0 - hist_weight) * live + hist_weight * hist + + +def pick_probe_size_bytes(*, probe_size_mb: int) -> int: + return probe_size_mb * 1024 * 1024 + + +async def probe_source_speed( + driver: Any, file: Any, *, probe_bytes: int, timeout_s: float, + hf_token: str | None, + transport: httpx.AsyncBaseTransport | None = None, +) -> float: + """One ranged GET (controller→source) timing bytes/sec. 0.0 on any + failure (effect: that source is treated as unavailable for this task).""" + url = driver.download_url(file) + tok = driver.auth_token(hf_token) + headers = {"Range": f"bytes=0-{max(0, probe_bytes - 1)}"} + if tok.scheme == "bearer" and tok.value: + headers["Authorization"] = f"Bearer {tok.value}" + try: + async with httpx.AsyncClient(timeout=timeout_s, transport=transport, + follow_redirects=True) as c: + start = time.monotonic() + recv = 0 + async with c.stream("GET", url, headers=headers) as resp: + if resp.status_code >= 400: + return 0.0 + async for buf in resp.aiter_bytes(64 * 1024): + recv += len(buf) + elapsed = time.monotonic() - start + return recv / elapsed if elapsed > 0 and recv > 0 else 0.0 + except Exception: + return 0.0 +``` + +- [ ] **Step 4: Run** `uv run pytest tests/services/test_source_speed.py -v` → 5 PASS. + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/services/source_speed.py tests/services/test_source_speed.py +git commit -m "feat(sp2): source speed EWMA fusion" +``` + +--- + +### Task 9: LPT assignment + optimal-combo + +**Files:** Create `src/dlw/services/source_combo.py`; Test `tests/services/test_source_combo.py` + +- [ ] **Step 1: Write the failing test** — `tests/services/test_source_combo.py`: +```python +"""LPT greedy + optimal-combo (Phase 3 SP2; doc §1.6/§1.8, OR-V21-04).""" +from __future__ import annotations + +from dlw.services.source_combo import assign_files_lpt, solve_optimal_combo + + +def test_lpt_balances_by_completion_time(): + files = {"a": 100, "b": 100, "c": 50} + speeds = {"s1": 10.0, "s2": 5.0} + assign = assign_files_lpt(files, speeds) + assert set(assign.values()) <= {"s1", "s2"} + # largest files go to the faster source first + assert assign["a"] == "s1" + + +def test_lpt_single_source_degenerate(): + assign = assign_files_lpt({"a": 1, "b": 2}, {"only": 7.0}) + assert assign == {"a": "only", "b": "only"} + + +def test_combo_excludes_slow_source_by_overhead(): + # one fast source + one trivially-slow source: combo should drop the slow + files = {"f": 1_000_000_000} + speeds = {"fast": 1_000_000_000.0, "slow": 1.0} + combo = solve_optimal_combo(speeds, files, overhead_pct=2.0) + assert combo == ["fast"] + + +def test_combo_uses_both_when_comparable(): + files = {"a": 100, "b": 100} + speeds = {"s1": 10.0, "s2": 10.0} + combo = solve_optimal_combo(speeds, files, overhead_pct=2.0) + assert set(combo) == {"s1", "s2"} +``` + +- [ ] **Step 2: Run** `uv run pytest tests/services/test_source_combo.py -v` → FAIL. + +- [ ] **Step 3: Implement** — `src/dlw/services/source_combo.py`: +```python +"""File→source assignment: size-descending greedy heuristic (NOT bounded- +optimal LPT — doc OR-V21-04) + fastest-K combo with overhead penalty.""" +from __future__ import annotations + + +def assign_files_lpt( + files: dict[str, int], source_speeds: dict[str, float] +) -> dict[str, str]: + """files: {filename: size}; source_speeds: {source_id: bytes/sec}. + Returns {filename: source_id}. Largest-first; each file to the source + with the earliest projected completion (load+size)/speed.""" + load = {sid: 0.0 for sid in source_speeds} + out: dict[str, str] = {} + for fn in sorted(files, key=lambda k: -files[k]): + size = files[fn] + best = min(source_speeds, + key=lambda sid: (load[sid] + size) / source_speeds[sid]) + out[fn] = best + load[best] += size + return out + + +def _eta(files: dict[str, int], speeds: dict[str, float]) -> float: + assign = assign_files_lpt(files, speeds) + load = {sid: 0.0 for sid in speeds} + for fn, sid in assign.items(): + load[sid] += files[fn] + return max((load[sid] / speeds[sid] for sid in speeds), default=0.0) + + +def solve_optimal_combo( + source_speeds: dict[str, float], files: dict[str, int], + *, overhead_pct: float +) -> list[str]: + ranked = sorted(source_speeds, key=lambda s: -source_speeds[s]) + best_eta = float("inf") + best: list[str] = ranked[:1] + for k in range(1, len(ranked) + 1): + combo = ranked[:k] + sub = {s: source_speeds[s] for s in combo} + eta = _eta(files, sub) * (1 + 0.01 * overhead_pct * (k - 1)) + if eta < best_eta: + best_eta, best = eta, combo + elif k > 1 and eta > best_eta * 1.05: + break + return best +``` + +- [ ] **Step 4: Run** `uv run pytest tests/services/test_source_combo.py -v` → 4 PASS. + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/services/source_combo.py tests/services/test_source_combo.py +git commit -m "feat(sp2): LPT greedy assignment + optimal-combo selection" +``` + +--- + +### Task 10: Blacklist service + +**Files:** Create `src/dlw/services/source_blacklist.py`; Test `tests/services/test_source_blacklist.py` + +- [ ] **Step 1: Write the failing test** — `tests/services/test_source_blacklist.py`: +```python +"""Source blacklist transitions (Phase 3 SP2; doc §1.7).""" +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import async_sessionmaker + +import dlw.db.models # noqa: F401 +from dlw.db.base import Base +from dlw.db.models.source import SourceBlacklist +from dlw.services.source_blacklist import ( + blacklist_file, + is_blacklisted, +) + +pytestmark = pytest.mark.slow + + +@pytest.fixture +async def factory(engine): + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + await c.run_sync(Base.metadata.create_all) + yield async_sessionmaker(engine, expire_on_commit=False) + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + + +async def test_blacklist_and_check(factory): + async with factory() as s: + await blacklist_file(s, source_id="modelscope", repo_id="o/r", + filename="m.safetensors", hours=24, + reason="sha_mismatch") + await s.commit() + assert await is_blacklisted(s, "modelscope", "o/r", + "m.safetensors") is True + assert await is_blacklisted(s, "modelscope", "o/r", + "other.bin") is False + + +async def test_expired_not_blacklisted(factory): + async with factory() as s: + s.add(SourceBlacklist(source_id="modelscope", repo_id="o/r", + filename="m", reason="x", + until=datetime.now(UTC) - timedelta(hours=1))) + await s.commit() + assert await is_blacklisted(s, "modelscope", "o/r", "m") is False +``` + +- [ ] **Step 2: Run** `uv run pytest tests/services/test_source_blacklist.py -v` → FAIL. + +- [ ] **Step 3: Implement** — `src/dlw/services/source_blacklist.py`: +```python +"""Source/(source,repo,file) blacklist (Phase 3 SP2; doc §1.7). +Caller commits (service-layer convention).""" +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dlw.db.models.source import SourceBlacklist + + +async def blacklist_file( + session: AsyncSession, *, source_id: str, repo_id: str, + filename: str, hours: int, reason: str, +) -> None: + session.add(SourceBlacklist( + source_id=source_id, repo_id=repo_id, filename=filename, + until=datetime.now(UTC) + timedelta(hours=hours), reason=reason)) + + +async def is_blacklisted( + session: AsyncSession, source_id: str, repo_id: str, filename: str +) -> bool: + row = await session.scalar( + select(SourceBlacklist.id).where( + SourceBlacklist.source_id == source_id, + SourceBlacklist.repo_id == repo_id, + SourceBlacklist.filename == filename, + SourceBlacklist.until > datetime.now(UTC)).limit(1)) + return row is not None +``` + +- [ ] **Step 4: Run** `uv run pytest tests/services/test_source_blacklist.py -v` → 2 PASS. + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/services/source_blacklist.py tests/services/test_source_blacklist.py +git commit -m "feat(sp2): source blacklist service" +``` + +--- + +### Task 11: Planner `plan_task_sources` + +**Files:** Create `src/dlw/services/source_scheduler.py`; Test `tests/services/test_source_scheduler.py` + +- [ ] **Step 1: Write the failing test** — `tests/services/test_source_scheduler.py`: +```python +"""plan_task_sources: resolve→assign→persist + HF-authority gate (SP2).""" +from __future__ import annotations + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import async_sessionmaker + +import dlw.db.models # noqa: F401 +from dlw.db.base import Base +from dlw.db.models.source import SubtaskChunk +from dlw.db.models.task import DownloadTask, FileSubTask +from dlw.services.source_scheduler import plan_task_sources +from dlw.sources.base import SourceFile, SourceManifest + +pytestmark = pytest.mark.slow + + +class _FakeDriver: + def __init__(self, sid, files, sha): + self.id = sid + self.provides_sha256 = sha + self._files = files + + async def resolve(self, repo_id, revision): + return SourceManifest(self.id, repo_id, revision, self._files, + has_lfs_sha256=any( + f.sha256 for f in self._files)) + + +class _FakeReg: + def __init__(self, drivers): + self._d = drivers + + def enabled_ids(self): + return list(self._d) + + def get(self, sid): + return self._d.get(sid) + + +class _IdResolver: + def resolve(self, source_id, hf_repo_id): + return hf_repo_id + + +@pytest.fixture +async def factory(engine): + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + await c.run_sync(Base.metadata.create_all) + f = async_sessionmaker(engine, expire_on_commit=False) + from dlw.db.models.storage import StorageBackend + from dlw.db.models.tenant import Project, Tenant, User + async with f() as s: + s.add(Tenant(id=1, slug="t", display_name="T")) + await s.flush() + s.add_all([Project(id=1, tenant_id=1, name="d"), + User(id=1, tenant_id=1, oidc_subject="u", email="e", + role="tenant_operator"), + StorageBackend(id=1, tenant_id=1, name="s", + backend_type="s3", config_encrypted=b"")]) + await s.commit() + yield f + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + + +def _files(): + return [SourceFile("model.safetensors", 200 * 1024 * 1024, "a" * 64, + "ref"), + SourceFile("config.json", 10, None, "ref2")] + + +async def test_plan_assigns_and_persists(factory): + async with factory() as s: + task = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="scheduling") + s.add(task) + await s.flush() + for f in _files(): + s.add(FileSubTask(task_id=task.id, tenant_id=1, filename=f.filename, + file_size=f.size, expected_sha256=f.sha256, + status="pending")) + await s.commit() + reg = _FakeReg({"huggingface": _FakeDriver("huggingface", _files(), + True), + "modelscope": _FakeDriver("modelscope", _files(), + False)}) + await plan_task_sources( + s, task, registry=reg, resolver=_IdResolver(), + speeds={("huggingface"): 50.0, ("modelscope"): 900.0}, + chunk_min_mb=100) + await s.commit() + subs = (await s.execute(select(FileSubTask).where( + FileSubTask.task_id == task.id))).scalars().all() + assert all(x.source_id in {"huggingface", "modelscope"} for x in subs) + big = next(x for x in subs if x.filename == "model.safetensors") + assert big.is_chunked is True + chunks = (await s.execute(select(SubtaskChunk).where( + SubtaskChunk.subtask_id == big.id))).scalars().all() + assert len(chunks) >= 2 + + +async def test_hf_absent_pauses_when_not_trusted(factory): + async with factory() as s: + task = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="scheduling", + trust_non_hf_sha256=False) + s.add(task) + await s.flush() + s.add(FileSubTask(task_id=task.id, tenant_id=1, filename="c.json", + file_size=10, expected_sha256=None, + status="pending")) + await s.commit() + reg = _FakeReg({"modelscope": _FakeDriver("modelscope", _files(), + False)}) + await plan_task_sources(s, task, registry=reg, resolver=_IdResolver(), + speeds={"modelscope": 900.0}, chunk_min_mb=100) + await s.commit() + assert task.status == "paused_external" + assert task.error_message == "no_sha256_authority" + + +async def test_no_sha_file_pinned_to_huggingface(factory): + """INVARIANT 12 (spec ruling 6a): a file with expected_sha256=None must + stay on huggingface even when a faster non-HF source covers it.""" + async with factory() as s: + task = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="scheduling") + s.add(task) + await s.flush() + s.add(FileSubTask(task_id=task.id, tenant_id=1, + filename="config.json", file_size=10, + expected_sha256=None, status="pending")) + await s.commit() + reg = _FakeReg({"huggingface": _FakeDriver("huggingface", _files(), + True), + "modelscope": _FakeDriver("modelscope", _files(), + False)}) + await plan_task_sources(s, task, registry=reg, resolver=_IdResolver(), + speeds={"huggingface": 1.0, + "modelscope": 9000.0}, + chunk_min_mb=100) + await s.commit() + sub = (await s.execute(select(FileSubTask).where( + FileSubTask.task_id == task.id))).scalar_one() + assert sub.source_id == "huggingface" and sub.is_chunked is False + + +async def test_pin_modelscope_unreachable_pauses(factory): + async with factory() as s: + task = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="scheduling", + source_strategy="pin_modelscope") + s.add(task) + await s.flush() + s.add(FileSubTask(task_id=task.id, tenant_id=1, filename="m", + file_size=10, expected_sha256="a" * 64, + status="pending")) + await s.commit() + reg = _FakeReg({"huggingface": _FakeDriver("huggingface", _files(), + True)}) # no modelscope + await plan_task_sources(s, task, registry=reg, resolver=_IdResolver(), + speeds={"huggingface": 50.0}, chunk_min_mb=100) + await s.commit() + assert task.status == "paused_external" + assert task.error_message == "pinned_source_unavailable" +``` + +- [ ] **Step 2: Run** `uv run pytest tests/services/test_source_scheduler.py -v` → FAIL. + +- [ ] **Step 3: Implement** — `src/dlw/services/source_scheduler.py`: +```python +"""Task scheduling-phase source planner (Phase 3 SP2; doc §1.6/§1.8). +Caller commits.""" +from __future__ import annotations + +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dlw.db.models.source import SubtaskChunk +from dlw.db.models.task import DownloadTask, FileSubTask +from dlw.services.source_combo import assign_files_lpt, solve_optimal_combo + +_CHUNK_BYTES = 64 * 1024 * 1024 # source-routing chunk granularity + + +def _strategy_filter(enabled: list[str], strategy: str, + blacklist: list[str]) -> tuple[list[str], str | None]: + """Apply task.source_strategy + task.source_blacklist (spec ruling 6e). + Returns (allowed_ids, pinned_or_None). pinned!=None means an explicit + single-source pin that must be honored (pause if unreachable).""" + allowed = [s for s in enabled if s not in blacklist] + if strategy == "auto_balance" or not strategy: + return allowed, None + if strategy == "fastest_only": + return allowed, None # combo will pick the single fastest + if strategy.startswith("pin_"): + pin = strategy.removeprefix("pin_") + return ([pin] if pin in allowed else []), pin + if strategy.startswith("list:"): + wanted = [x.strip() for x in strategy.removeprefix("list:").split(",")] + return [s for s in allowed if s in wanted], None + return allowed, None + + +async def plan_task_sources( + session: AsyncSession, task: DownloadTask, *, + registry: Any, resolver: Any, speeds: dict[str, float], + chunk_min_mb: int, overhead_pct: float = 2.0, +) -> None: + # 1. apply source_strategy / source_blacklist (spec ruling 6e) + allowed, pinned = _strategy_filter( + registry.enabled_ids(), task.source_strategy or "auto_balance", + list(task.source_blacklist or [])) + + # 2. resolve manifests across allowed sources + manifests: dict[str, Any] = {} + for sid in allowed: + drv = registry.get(sid) + src_repo = resolver.resolve(sid, task.repo_id) + if src_repo is None: + continue + m = await drv.resolve(src_repo, task.revision) + if m is not None: + manifests[sid] = (drv, m) + + if pinned is not None and pinned not in manifests: + task.status = "paused_external" + task.error_message = "pinned_source_unavailable" + return + + # 3. HF sha256 authority gate (INVARIANT 13) + hf_ok = "huggingface" in manifests + if not hf_ok and not task.trust_non_hf_sha256: + task.status = "paused_external" + task.error_message = "no_sha256_authority" + return + + # 4. candidates = covering sources with positive speed (spec ruling 6c) + candidates = {sid: speeds[sid] for sid in manifests + if sid in speeds and speeds[sid] > 0} + if not candidates: + task.status = "paused_external" + task.error_message = "no_source_speed" + return + subs = (await session.execute(select(FileSubTask).where( + FileSubTask.task_id == task.id))).scalars().all() + sizes = {x.filename: (x.file_size or 0) for x in subs} + combo = solve_optimal_combo(candidates, sizes, overhead_pct=overhead_pct) + combo_speeds = {s: candidates[s] for s in combo} + + # 5. assign; INVARIANT 12 — files with no HF sha authority stay HF-only + assign = assign_files_lpt(sizes, combo_speeds) + hf_files: set[str] = set() + if "huggingface" in manifests: + hf_files = {f.filename for f in manifests["huggingface"][1].files} + chunk_min = chunk_min_mb * 1024 * 1024 + for sub in subs: + no_hf_authority = (sub.expected_sha256 is None + or sub.filename not in hf_files) + if no_hf_authority and not task.trust_non_hf_sha256: + if "huggingface" not in manifests: + task.status = "paused_external" + task.error_message = "no_sha256_authority" + return + sub.source_id = "huggingface" # single-source, no chunk-split + continue + sid = assign[sub.filename] + sub.source_id = sid + covering = [s for s in combo + if any(f.filename == sub.filename + for f in manifests[s][1].files)] + if (sub.file_size or 0) >= chunk_min and len(covering) >= 2: + sub.is_chunked = True + await _split_chunks(session, sub, sub.file_size, covering, + combo_speeds) + + +async def _split_chunks( + session: AsyncSession, sub: FileSubTask, size: int, + sources: list[str], speeds: dict[str, float], +) -> None: + total = sum(speeds[s] for s in sources) or 1.0 + offset = 0 + idx = 0 + for i, sid in enumerate(sources): + if i == len(sources) - 1: + length = size - offset + else: + portion = int(size * speeds[sid] / total) + length = max(_CHUNK_BYTES, + (portion // _CHUNK_BYTES) * _CHUNK_BYTES) + length = min(length, size - offset) + if length <= 0: + continue + session.add(SubtaskChunk( + subtask_id=sub.id, chunk_index=idx, byte_start=offset, + byte_end=offset + length - 1, source_id=sid, status="pending")) + offset += length + idx += 1 +``` + +- [ ] **Step 4: Run** `uv run pytest tests/services/test_source_scheduler.py -v` → 4 PASS. + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/services/source_scheduler.py tests/services/test_source_scheduler.py +git commit -m "feat(sp2): plan_task_sources (resolve/assign/chunk-split + HF gate)" +``` + +--- + +# Milestone M4 — Proxy + Executor + Lifespan + +### Task 12: Generalized source-proxy + +**Files:** Create `src/dlw/api/source_proxy.py`; Modify `src/dlw/main.py` (mount router); Test `tests/api/test_source_proxy.py` + +- [ ] **Step 1: Write the failing test** — `tests/api/test_source_proxy.py`: +```python +"""source-proxy routes to the assigned driver, INVARIANT 2 (SP2).""" +from __future__ import annotations + +import uuid + +import httpx +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import async_sessionmaker + +import dlw.db.models # noqa: F401 +from dlw.db.base import Base +from tests.conftest import make_app_with_state, register_test_executor + +pytestmark = pytest.mark.slow + +SECRET = "unit-secret" + + +@pytest.fixture(autouse=True) +def _env(monkeypatch): + monkeypatch.setenv("DLW_SYSTEM_JWT_SECRET", SECRET) + from dlw.config import get_settings + get_settings.cache_clear() + yield + get_settings.cache_clear() + + +@pytest.fixture +async def app_client(ephemeral_ca, engine): + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + await c.run_sync(Base.metadata.create_all) + f = async_sessionmaker(engine, expire_on_commit=False) + from dlw.db.models.storage import StorageBackend + from dlw.db.models.task import DownloadTask, FileSubTask + from dlw.db.models.tenant import Project, Tenant, User + async with f() as s: + s.add(Tenant(id=1, slug="t", display_name="T")) + await s.flush() + s.add_all([Project(id=1, tenant_id=1, name="d"), + User(id=1, tenant_id=1, oidc_subject="u", email="e", + role="tenant_operator"), + StorageBackend(id=1, tenant_id=1, name="s", + backend_type="s3", config_encrypted=b"")]) + await s.commit() + app = make_app_with_state(ephemeral_ca, enrollment_token="e") + + # fake registry on app.state: a driver that streams "HELLO" + class _D: + id = "modelscope" + + def download_url(self, file): + return "https://www.modelscope.cn/x" + + def auth_token(self, t): + from dlw.sources.base import SourceToken + return SourceToken(scheme="none") + + class _Reg: + def get(self, sid): + return _D() if sid == "modelscope" else None + + app.state.source_registry = _Reg() + # patch the proxy's outbound client to a MockTransport + import dlw.api.source_proxy as sp + + def _mk(_t): + return httpx.AsyncClient(transport=httpx.MockTransport( + lambda r: httpx.Response(200, content=b"HELLO", + headers={"Content-Length": "5"}))) + monkeypatch_target = sp + sp._make_source_client = _mk # type: ignore[attr-defined] + + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + yield app, c, f + + +async def test_proxy_streams_from_assigned_source(app_client): + app, client, f = app_client + from dlw.db.models.task import DownloadTask, FileSubTask + reg = await register_test_executor(client, enrollment_token="e") + async with f() as s: + t = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="downloading") + s.add(t) + await s.flush() + tok = uuid.uuid4() + sub = FileSubTask(task_id=t.id, tenant_id=1, filename="m", + file_size=5, status="assigned", + executor_id=reg["executor_id"], + executor_epoch=reg["epoch"], assignment_token=tok, + source_id="modelscope") + s.add(sub) + await s.commit() + sub_id = sub.id + from tests.conftest import executor_request_headers + h = {**executor_request_headers(reg), "X-Assignment-Token": str(tok)} + r = await client.get(f"/api/v1/source-proxy/subtask/{sub_id}", headers=h) + assert r.status_code == 200 + assert r.content == b"HELLO" +``` + +- [ ] **Step 2: Run** `uv run pytest tests/api/test_source_proxy.py -v` → FAIL. + +- [ ] **Step 3: Implement** — `src/dlw/api/source_proxy.py` (copy the W3b ownership chain from `src/dlw/api/hf_proxy.py`, swap URL building for driver dispatch): +```python +"""Generalized multi-source reverse-proxy (Phase 3 SP2). Mirrors the W3b +hf_proxy ownership chain; routes each subtask/chunk to its assigned +SourceDriver and injects that source's controller-side credential. The +source token NEVER leaves the controller (INVARIANT 2).""" +from __future__ import annotations + +import uuid + +import httpx +from fastapi import APIRouter, Depends, Header, HTTPException, Request +from fastapi.responses import StreamingResponse +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dlw.api.tasks import _session +from dlw.auth.executor_jwt_dep import require_executor_jwt +from dlw.config import get_settings +from dlw.db.models.executor import Executor +from dlw.db.models.source import SubtaskChunk +from dlw.db.models.task import DownloadTask, FileSubTask +from dlw.sources.base import SourceFile + +router = APIRouter(prefix="/api/v1/source-proxy", tags=["executors"]) + +_HDR_ALLOW = frozenset({ + "content-length", "content-range", "content-type", + "accept-ranges", "etag", +}) + + +def _make_source_client(timeout_seconds: int) -> httpx.AsyncClient: + """Test seam — monkeypatched to inject httpx.MockTransport.""" + return httpx.AsyncClient(follow_redirects=True, timeout=timeout_seconds) + + +@router.get("/subtask/{subtask_id}") +async def source_proxy_subtask( + subtask_id: uuid.UUID, + request: Request, + x_assignment_token: str = Header(..., alias="X-Assignment-Token"), + auth_ex: Executor = Depends(require_executor_jwt), + session: AsyncSession = Depends(_session), +) -> StreamingResponse: + sub = await session.get(FileSubTask, subtask_id) + if sub is None: + raise HTTPException(404, detail="subtask not found") + if sub.executor_id != auth_ex.id: + raise HTTPException(403, detail={"code": "NOT_YOUR_SUBTASK"}) + if sub.assignment_token is None or str(sub.assignment_token) != x_assignment_token: + raise HTTPException(409, detail={"code": "STALE_ASSIGNMENT"}) + if sub.executor_epoch != auth_ex.epoch: + raise HTTPException(409, detail={"code": "EPOCH_MISMATCH"}) + task = await session.get(DownloadTask, sub.task_id) + if task is None: + raise HTTPException(500, detail="parent task missing") + + settings = get_settings() + range_header = request.headers.get("Range") + + # which source? chunked subtask → resolve by Range start; else sub.source_id + source_id = sub.source_id + if sub.is_chunked and range_header and range_header.startswith("bytes="): + start = int(range_header.split("=", 1)[1].split("-", 1)[0]) + chunk = await session.scalar(select(SubtaskChunk).where( + SubtaskChunk.subtask_id == sub.id, + SubtaskChunk.byte_start <= start, + SubtaskChunk.byte_end >= start)) + if chunk is not None: + source_id = chunk.source_id + if source_id is None: + raise HTTPException(409, detail={"code": "SOURCE_UNASSIGNED"}) + + registry = request.app.state.source_registry + drv = registry.get(source_id) + if drv is None: + raise HTTPException(502, detail={"code": "SOURCE_UNAVAILABLE"}) + + src_file = SourceFile(filename=sub.filename, size=sub.file_size, + sha256=sub.expected_sha256, + download_ref=f"{task.repo_id}/resolve/" + f"{task.revision}/{sub.filename}") + url = drv.download_url(src_file) + tok = drv.auth_token(settings.hf_token) + headers: dict[str, str] = {} + if tok.scheme == "bearer" and tok.value: + headers["Authorization"] = f"Bearer {tok.value}" + if range_header: + headers["Range"] = range_header + + client = _make_source_client(settings.hf_proxy_timeout_seconds) + req = client.build_request("GET", url, headers=headers) + try: + resp = await client.send(req, stream=True) + except (httpx.TimeoutException, httpx.NetworkError) as e: + await client.aclose() + raise HTTPException(503, detail=f"source unreachable: {e}") from e + except BaseException: + await client.aclose() + raise + + fwd = {k: v for k, v in resp.headers.items() + if k.lower() in _HDR_ALLOW} + + async def _body(): + try: + async for chunk in resp.aiter_bytes(64 * 1024): + yield chunk + finally: + await resp.aclose() + await client.aclose() + + return StreamingResponse(_body(), status_code=resp.status_code, + headers=fwd) +``` +In `src/dlw/main.py` `create_app()` (after the existing router includes), add: +```python + from dlw.api.source_proxy import router as source_proxy_router + app.include_router(source_proxy_router) +``` + +- [ ] **Step 4: Run** `uv run pytest tests/api/test_source_proxy.py -v` → PASS. (If the `monkeypatch` seam in the fixture needs the `monkeypatch` fixture arg, add it to `app_client(ephemeral_ca, engine, monkeypatch)` and use `monkeypatch.setattr(sp, "_make_source_client", _mk)` — adjust during impl; the seam is `dlw.api.source_proxy._make_source_client`.) + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/api/source_proxy.py src/dlw/main.py tests/api/test_source_proxy.py +git commit -m "feat(sp2): generalized /source-proxy with per-source cred (INV 2)" +``` + +--- + +### Task 13: Executor `stream_source` + +**Files:** Modify `src/dlw/executor/client.py`, `src/dlw/executor/chunk_downloader.py`, `src/dlw/executor/downloader.py`, `tests/conftest.py`; Test `tests/executor/test_stream_source.py` + +- [ ] **Step 1: Write the failing test** — `tests/executor/test_stream_source.py`: +```python +"""ControllerClient.stream_source targets /source-proxy (Phase 3 SP2).""" +from __future__ import annotations + +import uuid + +import httpx +import pytest + +from dlw.executor.client import ControllerClient +from tests.conftest import make_fake_auth_state + + +async def test_stream_source_hits_source_proxy(tmp_path): + seen = {} + + def handler(request: httpx.Request) -> httpx.Response: + seen["path"] = request.url.path + seen["range"] = request.headers.get("Range") + return httpx.Response(200, content=b"DATA") + + c = ControllerClient( + "http://ctrl", + auth_state=make_fake_auth_state(tmp_path), + _transport=httpx.MockTransport(handler)) + sid = uuid.uuid4() + tok = uuid.uuid4() + async with c.stream_source(subtask_id=sid, assignment_token=tok, + range_header="bytes=0-3") as resp: + assert resp.status_code == 200 + body = b"" + async for b in resp.aiter_bytes(): + body += b + assert body == b"DATA" + assert seen["path"] == f"/api/v1/source-proxy/subtask/{sid}" + assert seen["range"] == "bytes=0-3" +``` + +- [ ] **Step 2: Run** `uv run pytest tests/executor/test_stream_source.py -v` → FAIL. + +- [ ] **Step 3: Implement** — in `src/dlw/executor/client.py`, add a method mirroring `stream_hf` exactly but with the source-proxy path (place directly after `stream_hf`): +```python + @asynccontextmanager + async def stream_source( + self, + *, + subtask_id: uuid.UUID, + assignment_token: uuid.UUID, + range_header: str | None = None, + ) -> AsyncIterator[httpx.Response]: + """SP2: stream a file/chunk from its assigned source via the + controller's generalized reverse-proxy. Same contract as stream_hf + (caller inspects resp.status_code; no raise_for_status).""" + headers = { + **self._auth_headers(), + "X-Assignment-Token": str(assignment_token), + } + if range_header: + headers["Range"] = range_header + async with self._make_client() as client: + async with client.stream( + "GET", + f"/api/v1/source-proxy/subtask/{subtask_id}", + headers=headers, + ) as resp: + yield resp +``` +In `src/dlw/executor/chunk_downloader.py`, replace the **two** `self._controller.stream_hf(` call sites (verified: lines ~74 in `_resolve_size`, ~128 in `_download_one_chunk`) with `self._controller.stream_source(` (identical signature). In `src/dlw/executor/downloader.py`, replace its single `self._controller.stream_hf(` call site (~line 71, `HfS3StreamDownloader` — the non-chunked single-file path) with `self._controller.stream_source(` too, so EVERY executor download goes through `/source-proxy` and honors the planner's `sub.source_id` (an HF-assigned subtask is just routed to the `huggingface` driver — equivalent to the old `/hf-proxy`). Nothing else in those files changes. + +**Chunk alignment (spec ruling 6b):** for an `is_chunked` subtask the executor must download one Range per `subtask_chunks` row, not its local `plan_chunks` split. Add to `DirectOffsetDownloader`: when the assignment indicates a chunked subtask, fetch the chunk rows from the controller and use their `byte_start/byte_end` as the chunk plan (each Range then maps to exactly one source in `source_proxy`). MINIMAL implementation for SP2: the controller exposes the chunk boundaries in the poll/assignment payload (a `chunks: [[start,end],...]` list when `is_chunked`); `DirectOffsetDownloader.download` uses those offsets instead of `plan_chunks(...)` when present. (If wiring the assignment payload is non-trivial, the implementer reports DONE_WITH_CONCERNS and the controller decides; the spec ruling 6b is the contract — every executor Range must align to one `subtask_chunks` row.) The existing sequential offset-order SHA256 in `_pass2_upload` is unchanged and remains the whole-file hash the W4 gate verifies. + +In `tests/conftest.py`, the shared test double `make_fake_controller_client._FakeControllerClient` currently defines only `stream_hf`. Add a `stream_source` method to it that mirrors `stream_hf` exactly but targets `/api/v1/source-proxy/subtask/{subtask_id}` (same `@asynccontextmanager`/MockTransport body, same params). This keeps `tests/executor/test_chunk_downloader.py` + `tests/executor/test_downloader.py` green after the swap. + +- [ ] **Step 4: Run** `uv run pytest tests/executor/test_stream_source.py tests/executor/test_chunk_downloader.py tests/executor/test_downloader.py -v` → PASS (the conftest `stream_source` addition keeps the existing downloader tests green; if any test asserts the literal `/hf-proxy` path, update that assertion to `/source-proxy`). + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/executor/client.py src/dlw/executor/chunk_downloader.py src/dlw/executor/downloader.py tests/conftest.py tests/executor/test_stream_source.py +git commit -m "feat(sp2): executor stream_source -> /source-proxy (all paths) + conftest fake" +``` + +--- + +### Task 14: Lifespan bootstrap + scheduling/rebalance loops + conftest + +**Files:** Modify `src/dlw/main.py`, `tests/conftest.py`, `tests/test_lifespan_state.py`; Test `tests/test_sp2_lifespan.py` + +- [ ] **Step 1: Write the failing test** — `tests/test_sp2_lifespan.py`: +```python +"""Real lifespan bootstraps source_registry + name_resolver (SP2; +SP1-regression-class: app.state used by routes MUST be set in lifespan).""" +from __future__ import annotations + +import pytest + +import dlw.db.models # noqa: F401 +from dlw.db.base import Base + +pytestmark = pytest.mark.slow + + +async def test_lifespan_sets_source_state(engine, tmp_path, monkeypatch): + async with engine.begin() as c: + await c.run_sync(Base.metadata.create_all) + monkeypatch.setenv("DLW_AUTH_DEV_MODE", "true") + monkeypatch.setenv("DLW_CA_DIR", str(tmp_path / "ca")) + from dlw.config import get_settings + get_settings.cache_clear() + from dlw.main import create_app, lifespan + from dlw.sources.registry import SourceRegistry + app = create_app() + async with lifespan(app): + assert isinstance(app.state.source_registry, SourceRegistry) + assert app.state.name_resolver is not None + assert "huggingface" in app.state.source_registry.enabled_ids() + get_settings.cache_clear() + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) +``` + +- [ ] **Step 2: Run** `uv run pytest tests/test_sp2_lifespan.py -v` → FAIL. + +- [ ] **Step 3: Implement** — in `src/dlw/main.py` lifespan, in the **unconditional** SP1 block (right after `app.state.casbin = build_enforcer(grants=_grants)`), add: +```python + from dlw.sources.name_resolver import NameResolver + from dlw.sources.registry import load_registry + app.state.source_registry = load_registry( + _settings.sources_yaml_path, hf_token=_settings.hf_token) + app.state.name_resolver = NameResolver.from_file( + _settings.resolver_rules_path) +``` +Add two leader-gated loops mirroring SP1's `_quota_loop`/`quota_task_holder` exactly. After the `quota_task_holder` definition add: +```python + sched_task_holder: dict[str, asyncio.Task | None] = {"t": None} + rebalance_task_holder: dict[str, asyncio.Task | None] = {"t": None} + + async def _scheduling_loop() -> None: + from dlw.services.source_scheduler import run_scheduling_tick + while True: + try: + await asyncio.sleep(5) + async with factory() as session: + await run_scheduling_tick( + session, app.state.source_registry, + app.state.name_resolver, _gs()) + await session.commit() + except asyncio.CancelledError: + raise + except Exception: + logger.exception("scheduling tick failed; retrying") + + async def _rebalance_loop() -> None: + from dlw.services.source_scheduler import run_rebalance_tick + while True: + try: + await asyncio.sleep(_gs().rebalance_interval_seconds) + async with factory() as session: + await run_rebalance_tick(session, _gs()) + await session.commit() + except asyncio.CancelledError: + raise + except Exception: + logger.exception("rebalance tick failed; retrying") +``` +(`_gs` is the lifespan-local alias for `get_settings` — `from dlw.config import get_settings as _gs` at `main.py:57`; do NOT call a bare `get_settings()` here, it is not in scope and would `NameError`.) + +In `_on_active` (next to the existing `quota_task_holder["t"] = ...`): +```python + sched_task_holder["t"] = asyncio.create_task(_scheduling_loop()) + rebalance_task_holder["t"] = asyncio.create_task(_rebalance_loop()) +``` +In `_on_step_down`, after the existing quota-task cancel block, add the symmetric cancel for both new holders (mirror the exact `qt = ...; if qt is not None: qt.cancel(); try: await asyncio.wait_for(qt, timeout=2) ...; holder["t"] = None` block for `sched_task_holder` and `rebalance_task_holder`). +Add `run_scheduling_tick`/`run_rebalance_tick` to `src/dlw/services/source_scheduler.py`: +```python +async def run_scheduling_tick(session, registry, resolver, settings) -> None: + """Pick `pending` tasks; controller-side probe each source; plan; move + to claimable. Probe = one small ranged GET controller→source (spec + ruling 6d); fused with latest SourceSpeedSample EWMA history.""" + from dlw.db.models.source import SourceSpeedSample + from dlw.services.source_speed import ( + fuse_ewma, + pick_probe_size_bytes, + probe_source_speed, + ) + pend = (await session.execute(select(DownloadTask).where( + DownloadTask.status == "pending").limit(20))).scalars().all() + probe_bytes = pick_probe_size_bytes(probe_size_mb=settings.probe_size_mb) + for task in pend: + task.status = "scheduling" + speeds: dict[str, float] = {} + for sid in registry.enabled_ids(): + drv = registry.get(sid) + src_repo = resolver.resolve(sid, task.repo_id) + live = 0.0 + if src_repo is not None: + try: + m = await drv.resolve(src_repo, task.revision) + except Exception: + m = None + if m is not None and m.files: + probe_f = min(m.files, key=lambda f: f.size or 1 << 62) + live = await probe_source_speed( + drv, probe_f, probe_bytes=probe_bytes, + timeout_s=settings.probe_timeout_s, + hf_token=settings.hf_token) + hist = await session.scalar( + select(SourceSpeedSample.bytes_per_sec) + .where(SourceSpeedSample.source_id == sid) + .order_by(SourceSpeedSample.measured_at.desc()).limit(1)) + fused = fuse_ewma(live=live, hist=float(hist) if hist else None, + hist_weight=settings.probe_history_weight) + if live > 0: + session.add(SourceSpeedSample( + executor_id="controller", source_id=sid, + bytes_per_sec=live, sample_size=probe_bytes, + is_active_probe=True)) + speeds[sid] = fused if fused > 0 else 0.0 + await plan_task_sources( + session, task, registry=registry, resolver=resolver, + speeds=speeds, chunk_min_mb=settings.chunk_level_min_file_mb, + overhead_pct=settings.combo_overhead_per_source_pct) + if task.status == "scheduling": + task.status = "downloading" + + +async def run_rebalance_tick(session, settings) -> None: + """Reassign a degraded source's PENDING chunks to a healthy sibling + source on the same subtask (in-flight chunks untouched).""" + from sqlalchemy import text + from dlw.services.source_blacklist import active_blacklisted_sources + bad = await active_blacklisted_sources(session) + if not bad: + return + for sub_src in bad: + await session.execute(text( + "UPDATE subtask_chunks c SET source_id = (" + " SELECT source_id FROM subtask_chunks d " + " WHERE d.subtask_id=c.subtask_id AND d.source_id!=:bad " + " LIMIT 1) " + "WHERE c.source_id=:bad AND c.status='pending' " + "AND EXISTS (SELECT 1 FROM subtask_chunks e " + " WHERE e.subtask_id=c.subtask_id AND e.source_id!=:bad)" + ), {"bad": sub_src}) +``` +Add to `src/dlw/services/source_blacklist.py`: +```python +async def active_blacklisted_sources(session: AsyncSession) -> list[str]: + rows = await session.execute(select(SourceBlacklist.source_id).where( + SourceBlacklist.until > datetime.now(UTC)).distinct()) + return [r[0] for r in rows] +``` +In `tests/conftest.py` `make_app_with_state`, after the `app.state.casbin = ...` line, add: +```python + from dlw.sources.name_resolver import NameResolver + from dlw.sources.registry import load_registry + _s = app.state.settings + app.state.source_registry = load_registry( + _s.sources_yaml_path, hf_token=_s.hf_token) + app.state.name_resolver = NameResolver.from_file(_s.resolver_rules_path) +``` +(The default `config/sources.yaml`/`config/resolver-rules.yaml` exist from M1; tests run from repo root so the relative paths resolve.) Extend `tests/test_lifespan_state.py`'s existing test to also `assert app.state.source_registry is not None`. + +- [ ] **Step 4: Run** `uv run pytest tests/test_sp2_lifespan.py tests/test_lifespan_state.py -v` → PASS. + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/main.py src/dlw/services/source_scheduler.py src/dlw/services/source_blacklist.py tests/conftest.py tests/test_lifespan_state.py tests/test_sp2_lifespan.py +git commit -m "feat(sp2): lifespan registry/resolver bootstrap + scheduling/rebalance loops" +``` + +--- + +### Task 15: SHA256 authority gate on report + +**Files:** Modify `src/dlw/services/scheduler.py`; Test `tests/services/test_sha_authority.py` + +- [ ] **Step 1: Write the failing test** — `tests/services/test_sha_authority.py`: +```python +"""Non-HF completion verified vs HF expected_sha256 → blacklist on mismatch.""" +from __future__ import annotations + +import uuid + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import async_sessionmaker + +import dlw.db.models # noqa: F401 +from dlw.db.base import Base +from dlw.db.models.source import SourceBlacklist +from dlw.db.models.task import DownloadTask, FileSubTask +from dlw.services.scheduler import complete_subtask + +pytestmark = pytest.mark.slow + + +@pytest.fixture +async def factory(engine): + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + await c.run_sync(Base.metadata.create_all) + f = async_sessionmaker(engine, expire_on_commit=False) + from dlw.db.models.storage import StorageBackend + from dlw.db.models.tenant import Project, Tenant, User + async with f() as s: + s.add(Tenant(id=1, slug="t", display_name="T")) + await s.flush() + s.add_all([Project(id=1, tenant_id=1, name="d"), + User(id=1, tenant_id=1, oidc_subject="u", email="e", + role="tenant_operator"), + StorageBackend(id=1, tenant_id=1, name="s", + backend_type="s3", config_encrypted=b"")]) + await s.commit() + yield f + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + + +async def test_non_hf_sha_mismatch_blacklists(factory): + async with factory() as s: + t = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="downloading") + s.add(t) + await s.flush() + tok = uuid.uuid4() + sub = FileSubTask(task_id=t.id, tenant_id=1, filename="m", + file_size=4, expected_sha256="c" * 64, + status="assigned", assignment_token=tok, + source_id="modelscope") + s.add(sub) + await s.flush() + sid = sub.id + done, _ = await complete_subtask( + s, sid, final_status="succeeded", actual_sha256="d" * 64, + bytes_downloaded=4, error=None, assignment_token=tok) + await s.commit() + assert done.status == "failed" # existing sha-gate already flips this + bl = (await s.execute(select(SourceBlacklist).where( + SourceBlacklist.source_id == "modelscope"))).scalars().all() + assert len(bl) == 1 and bl[0].filename == "m" +``` + +- [ ] **Step 2: Run** `uv run pytest tests/services/test_sha_authority.py -v` → FAIL (no blacklist row written yet). + +- [ ] **Step 3: Implement** — in `src/dlw/services/scheduler.py` `complete_subtask`: the existing W4 sha256 gate flips `final_status` to `"failed"` on mismatch (~lines 173-182), then `sub.status = final_status` (~185) and `parent = await session.get(DownloadTask, sub.task_id, with_for_update=True)` (~line 192-194). Insert the blacklist write **immediately AFTER that `parent = await session.get(...)` line** (so it reuses the already-locked `parent` — no duplicate fetch) and **before** the siblings query (~line 195). Add a module-top import `from dlw.services.source_blacklist import blacklist_file` (no circular import: `source_blacklist` imports only models; `scheduler` is not scanned for this). Insert exactly: +```python + if (final_status == "failed" and sub.source_id + and sub.source_id != "huggingface" + and sub.expected_sha256 is not None + and actual_sha256 != sub.expected_sha256): + await blacklist_file( + session, source_id=sub.source_id, repo_id=parent.repo_id, + filename=sub.filename, hours=24, reason="sha_mismatch") +``` +(`parent` here is the locked row already fetched on the preceding line — do NOT add a second `session.get`. `sub.status` is already `"failed"`. Re-queue/HF-repin of the file is handled by the next scheduling pass — this hook is intentionally minimal: just the 24h blacklist row, exactly what `tests/e2e/test_multi_source.py::test_sha256_mismatch_blacklists_source` and Task 15's test assert. Note: the only new string literals added to the scanned `scheduler.py` are `"failed"`/`"huggingface"` inside an `if` condition — NOT a `status=`/`.status =` assignment — so `tools/lint_invariants.py` does not flag them; confirm with `python tools/lint_invariants.py`.) + +- [ ] **Step 4: Run** `uv run pytest tests/services/test_sha_authority.py -v` → PASS. Then `python tools/lint_invariants.py` → exit 0 (no new status literals added to scanned files; `"failed"`/`"huggingface"` are not status-kwarg literals flagged by the AST check — confirm). + +- [ ] **Step 5: Commit** +```bash +git add src/dlw/services/scheduler.py tests/services/test_sha_authority.py +git commit -m "feat(sp2): HF sha256 authority — non-HF mismatch blacklists source 24h" +``` + +--- + +# Milestone M5 — E2E + Docs + PR + +### Task 16: E2E-002 multi-source + +**Files:** Create `tests/e2e/test_multi_source.py` + +- [ ] **Step 1: Write the test** — `tests/e2e/test_multi_source.py`: +```python +"""E2E-002: auto_balance planning + HF-authority pause (Phase 3 SP2). + +End-to-end at the planner+DB level (no live mirrors): a task with HF + a +faster ModelScope-style fake source gets files assigned to the faster +source, and an HF-absent task without trust pauses (INVARIANT 13).""" +from __future__ import annotations + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import async_sessionmaker + +import dlw.db.models # noqa: F401 +from dlw.db.base import Base +from dlw.db.models.task import DownloadTask, FileSubTask +from dlw.services.source_scheduler import plan_task_sources +from dlw.sources.base import SourceFile, SourceManifest + +pytestmark = pytest.mark.slow + + +class _Drv: + def __init__(self, sid, files): + self.id = sid + self.provides_sha256 = sid in ("huggingface", "hf_mirror") + self._f = files + + async def resolve(self, repo, rev): + return SourceManifest(self.id, repo, rev, self._f, + has_lfs_sha256=any(f.sha256 for f in self._f)) + + +class _Reg: + def __init__(self, d): + self._d = d + + def enabled_ids(self): + return list(self._d) + + def get(self, s): + return self._d.get(s) + + +class _Id: + def resolve(self, sid, repo): + return repo + + +@pytest.fixture +async def factory(engine): + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + await c.run_sync(Base.metadata.create_all) + f = async_sessionmaker(engine, expire_on_commit=False) + from dlw.db.models.storage import StorageBackend + from dlw.db.models.tenant import Project, Tenant, User + async with f() as s: + s.add(Tenant(id=1, slug="t", display_name="T")) + await s.flush() + s.add_all([Project(id=1, tenant_id=1, name="d"), + User(id=1, tenant_id=1, oidc_subject="u", email="e", + role="tenant_operator"), + StorageBackend(id=1, tenant_id=1, name="s", + backend_type="s3", config_encrypted=b"")]) + await s.commit() + yield f + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + + +async def test_auto_balance_prefers_fast_source(factory): + files = [SourceFile("a.safetensors", 50, "a" * 64, "r"), + SourceFile("b.safetensors", 50, "b" * 64, "r")] + reg = _Reg({"huggingface": _Drv("huggingface", files), + "modelscope": _Drv("modelscope", files)}) + async with factory() as s: + t = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="scheduling") + s.add(t) + await s.flush() + for f in files: + s.add(FileSubTask(task_id=t.id, tenant_id=1, filename=f.filename, + file_size=f.size, expected_sha256=f.sha256, + status="pending")) + await s.commit() + await plan_task_sources(s, t, registry=reg, resolver=_Id(), + speeds={"huggingface": 50.0, + "modelscope": 5000.0}, + chunk_min_mb=100) + await s.commit() + subs = (await s.execute(select(FileSubTask).where( + FileSubTask.task_id == t.id))).scalars().all() + assert all(x.source_id == "modelscope" for x in subs) # HF too slow + + +async def test_hf_unavailable_pauses(factory): + files = [SourceFile("a", 10, None, "r")] + reg = _Reg({"modelscope": _Drv("modelscope", files)}) + async with factory() as s: + t = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="scheduling", + trust_non_hf_sha256=False) + s.add(t) + await s.flush() + s.add(FileSubTask(task_id=t.id, tenant_id=1, filename="a", + file_size=10, status="pending")) + await s.commit() + await plan_task_sources(s, t, registry=reg, resolver=_Id(), + speeds={"modelscope": 900.0}, chunk_min_mb=100) + await s.commit() + assert t.status == "paused_external" + assert t.error_message == "no_sha256_authority" +``` + +- [ ] **Step 2: Run** `uv run pytest tests/e2e/test_multi_source.py -v` → 2 PASS (planner + gate already implemented; this is the integration acceptance gate — if it fails, fix the underlying code, not the test). + +- [ ] **Step 3: Commit** +```bash +git add tests/e2e/test_multi_source.py +git commit -m "test(sp2): E2E-002 auto_balance prefers fast source + HF-authority pause" +``` + +--- + +### Task 17: OpenAPI + operator doc + full CI gates + PR + +**Files:** Modify `api/openapi.yaml`; Create `docs/operator/multi-source.md` + +- [ ] **Step 1: Update `api/openapi.yaml`** — add the `GET /api/v1/source-proxy/subtask/{subtaskId}` operation (tag `executors`, `X-Assignment-Token` header + optional `Range`, 200 stream / 403 `NOT_YOUR_SUBTASK` / 409 `STALE_ASSIGNMENT`|`EPOCH_MISMATCH` / 502 `SOURCE_UNAVAILABLE` / 503), mirroring the existing `/api/v1/hf-proxy/subtask/{subtaskId}` operation's structure exactly. Add the `source_strategy`/`source_blacklist`/`trust_non_hf_sha256` properties to the `TaskCreate` request schema and `scheduling` to the task `status` enum if one is defined. Match existing indentation/style. + +- [ ] **Step 2: Run the exact OpenAPI CI commands** (no code-vs-yaml gate): +```bash +npx --yes @stoplight/spectral-cli@6 lint api/openapi.yaml --fail-severity=error +npx --yes @apidevtools/swagger-cli validate api/openapi.yaml +``` +Both must pass (spectral: 0 errors; warnings OK). Also `npx --yes yaml-lint api/openapi.yaml` style sanity — keep 2-space indent, no trailing whitespace (the `yamllint` CI job scans `api/`). + +- [ ] **Step 3: Create `docs/operator/multi-source.md`** (~100 lines): `config/sources.yaml` schema (id/enabled/driver/config/cost; supported drivers = huggingface/hf_mirror/modelscope; others ignored), `config/resolver-rules.yaml` (identity_organizations / aliases transform / per_model_overrides with examples), the SP2 `DLW_*` settings (probe/chunk/blacklist/rebalance), `source_strategy` task field values, the HF-sha256-authority rule (INVARIANT 13: HF-down → `paused_external` unless `trust_non_hf_sha256`), the 24h sha-mismatch blacklist, and that scheduling/rebalance run only on the active controller (leader-gated). Cross-ref `docs/v2.0/06-platform-and-ecosystem.md` §1 and `INVARIANTS` 11/12/13. + +- [ ] **Step 4: Full suite + all real CI gates locally**: +```bash +uv lock && uv sync --all-groups +uv run pytest tests/ --cov=src/dlw --cov-report=term-missing +uv run python -m pytest tools/test_lint_invariants.py -v +python tools/lint_invariants.py +python tools/lint_no_direct_status_write.py +``` +All green. (uv.lock already committed in Task 1; re-run `uv lock` only if drift.) Confirm `python tools/lint_invariants.py` exits 0 — SP2 added no status literals to the 3 scanned files (chunk statuses live in `source_scheduler.py`/`source_proxy.py`; `"scheduling"` is already in `VALID_TASK_STATUS`). + +- [ ] **Step 5: Commit + push + PR** +```bash +git add api/openapi.yaml docs/operator/multi-source.md +git commit -m "docs(sp2): OpenAPI source-proxy op + operator multi-source guide" +git push -u origin feat/phase-3-sp2-multi-source +gh pr create --title "Phase 3 SP2 — Multi-source (SourceDriver + NameResolver + LPT/chunk routing)" --body "$(cat <<'EOF' +## Summary +- SourceDriver Protocol + HF/hf_mirror/ModelScope drivers + sources.yaml registry + NameResolver (resolver-rules.yaml). +- Leader-gated scheduling loop: resolve → optimal-combo → LPT file→source + chunk-split (≥100MB, ≥2 sources) → persist source_id/subtask_chunks. HF sha256 authority (INV 11/12/13): HF-down→paused_external unless trust flag; non-HF mismatch → 24h source blacklist. +- Generalized /api/v1/source-proxy (per-source cred stays controller-side, INV 2); executor stream_source; minimal leader-gated rebalance of degraded sources' pending chunks. +- Additive migration (3 tables + task/subtask source columns). Phase 3 sub-project 2 of 4 (SP1 merged #15). + +## Test plan +- [ ] full `uv run pytest` green incl. E2E-002 `tests/e2e/test_multi_source.py` +- [ ] invariant_lint / openapi(spectral+swagger-cli) / yamllint CI gates green +- [ ] alembic up/down/up clean; uv.lock committed (pyyaml) + +Spec: docs/superpowers/specs/2026-05-19-phase-3-sp2-multi-source-design.md +Plan: docs/superpowers/plans/2026-05-19-phase-3-sp2-multi-source.md + +🤖 Generated with [Claude Code](https://claude.com/claude-code) +EOF +)" +``` + +--- + +## Self-Review (completed during planning + to be re-checked by 2 pre-execution reviewers) + +**Spec coverage:** SourceDriver/dataclasses→T2; HF/hf_mirror→T3; ModelScope→T4; registry/sources.yaml→T5; NameResolver/resolver-rules→T6; models+migration+TaskCreate fields→T7; speed EWMA→T8; LPT+combo→T9; blacklist→T10; planner+HF-authority gate+chunk-split→T11; source-proxy+INV2→T12; executor stream_source→T13; lifespan bootstrap+scheduling/rebalance loops+conftest+test_lifespan_state→T14; sha256-authority-on-report→T15; E2E-002→T16; OpenAPI+operator doc+CI+PR→T17. Deferred items (wisemodel/opencsg/plugin, Phase-B LP, UI, incremental=SP3, CLI=SP4) are explicitly out per spec §1.3/banner. + +**Placeholder scan:** every code step has complete code. The one bounded judgement note (T12 Step 4 monkeypatch-seam fixture-arg detail) names the exact seam (`dlw.api.source_proxy._make_source_client`) and the fix; not a placeholder. `` = alembic-generated hash (intentional). + +**Type/name consistency:** `SourceDriver`/`SourceManifest`/`SourceFile`/`SourceToken`/`SourceHealth` identical T2↔T3/4↔T5↔T11↔T12. `download_url(file)`+`auth_token(tenant_hf_token)` consistent T2↔T3/4↔T12. `load_registry(path, *, hf_token)`→`SourceRegistry.enabled_ids()/get()` consistent T5↔T11↔T12↔T14. `NameResolver.from_file(path)`/`.resolve(source_id, hf_repo_id)` consistent T6↔T11↔T14. `assign_files_lpt(files,speeds)`/`solve_optimal_combo(speeds,files,*,overhead_pct)` consistent T9↔T11. `plan_task_sources(session,task,*,registry,resolver,speeds,chunk_min_mb,overhead_pct)` consistent T11↔T14↔T16. `blacklist_file(...)`/`is_blacklisted(...)`/`active_blacklisted_sources(...)` consistent T10↔T14↔T15. `stream_source(*,subtask_id,assignment_token,range_header)` mirrors existing `stream_hf` T13↔T12. `SubtaskChunk`/`SourceSpeedSample`/`SourceBlacklist` columns consistent T7↔T10↔T11↔T14↔T15. + +## References +- Spec: `docs/superpowers/specs/2026-05-19-phase-3-sp2-multi-source-design.md` +- Design doc: `docs/v2.0/06-platform-and-ecosystem.md` §1; Invariants 11/12/13. +- Code anchors: `src/dlw/services/hf_metadata.py` (`list_repo_tree`/`RepoFile`), `src/dlw/api/hf_proxy.py` (W3b ownership chain copied by source_proxy), `src/dlw/executor/chunk_downloader.py`/`client.py` (`stream_hf`→`stream_source`), `src/dlw/services/scheduler.py` `complete_subtask` (W4 sha gate), `tools/lint_invariants.py` (`scheduling` already in `VALID_TASK_STATUS`; only 3 files scanned), SP1's `main.py` `_quota_loop`/`make_app_with_state`/`test_lifespan_state` patterns, alembic head `a4bed702cdb3`. +- Branch `feat/phase-3-sp2-multi-source` off `main` (`fa08e6d`), spec `ccdb9e8`. diff --git a/docs/superpowers/specs/2026-05-19-phase-3-sp2-multi-source-design.md b/docs/superpowers/specs/2026-05-19-phase-3-sp2-multi-source-design.md new file mode 100644 index 0000000..dc857f4 --- /dev/null +++ b/docs/superpowers/specs/2026-05-19-phase-3-sp2-multi-source-design.md @@ -0,0 +1,316 @@ +# Phase 3 SP2 — Multi-Source (SourceDriver + NameResolver + speed-test + LPT/chunk routing) Design + +> **Status:** Draft (brainstormed 2026-05-19). +> **Companion plan:** `docs/superpowers/plans/2026-05-19-phase-3-sp2-multi-source.md` (writing-plans, after spec approval). +> **Roadmap source:** `docs/v2.0/08-mvp-roadmap.md` §3 Phase 3 Week 2 ("多源"): SourceDriver 抽象 + HF + hf-mirror; ModelScope + NameResolver; 测速 + LPT routing; chunk-level routing + 局部重平衡. §3.5 exit: `U-SRC-*`/`I-SRC-*`, `E2E-002` (多源 auto_balance), 多源测速 5×4 ≤8s, LPT 加速 ≥2x. +> **Phase 3 decomposition:** SP1 multi-tenancy = merged (PR #15, `fa08e6d`). **This is SP2 (2nd of 4).** SP3 incremental download, SP4 CLI/SDK follow. +> **Design source:** `docs/v2.0/06-platform-and-ecosystem.md` §1 (multi-source, the authoritative section). +> **Invariant source:** `docs/v2.0/INVARIANTS.md` rows 11 (HF=SHA256 truth), 12 (cross-source verify→24h blacklist), 13 (HF unavailable→default refuse). Reuses Invariant 8 (tenant scoping, SP1). +> **Closes:** the v2.0 multi-source baseline (§1.1–§1.9 IN-scope subset). Phase-B adaptive optimizer (§1.8 cont.) stays v2.1 per the doc's own framing. + +> **⚠️ Scope decisions (authoritative — supersede any broader reading of doc 06 §1):** +> 1. **Drivers: `huggingface`, `hf_mirror`, `modelscope` only.** wisemodel/opencsg/s3_mirror + the §1.10 plugin loader are **deferred** (⚙️ default-off in the doc). +> 2. **LPT = "size-descending greedy heuristic"**, NOT a bounded-optimal algorithm (doc OR-V21-04). No LP-relaxation slow-path (v2.1). +> 3. **Rebalance is minimal**: a leader-gated 60s loop reassigns a degraded source's *pending* (not-started) chunks to healthy sources. Skipped-source *recovery* re-admission and Phase-B continuous LP recalibration are **deferred** (v2.1, doc §1.8 cont.). +> 4. **SHA256 authority = HF only** (INVARIANT 11/12/13). `trust_non_hf_sha256` is honored as a task flag (HF-down → still `paused_external` unless set); its **admin approval workflow is deferred to Phase 4** (doc 04 §5/§8). +> 5. **Out:** UI source-allocation view (§1.11, frontend), incremental download (§2 = SP3), CLI/SDK (§5 = SP4), webhook/MLflow/operator/HF-cache (§4, Phase 4), refcount/global-dedup (§3.1, SP3), cost-knob optimizer (`estimate_cost` stays on the Protocol, unused by SP2). +> 6. **Pre-execution-review correctness rulings (2-reviewer pass 2026-05-19, authoritative):** +> a. **INVARIANT 12 — no-sha files stay HF.** Any `FileSubTask` with `expected_sha256 IS NULL` (HF non-LFS small files) OR not covered by the HF manifest: the planner pins it to `source_id="huggingface"`, never chunk-splits it, never routes it to a non-HF source — unless `task.trust_non_hf_sha256`. If HF lacks the file entirely and not trusted → task `paused_external`/`no_sha256_authority`. (A non-HF source has no sha truth for it; the W4 gate can't verify a `None` expected.) +> b. **Chunk alignment.** For `is_chunked` subtasks the executor downloads **one Range per `subtask_chunks` row** (using the row's `byte_start/byte_end`), NOT its local `plan_chunks` split — so every Range maps to exactly one source. The existing sequential offset-order SHA256 in `_pass2_upload` then yields the correct whole-file hash, verified by the existing W4 gate against HF `expected_sha256` (this *is* the "post-merge whole-file rehash"; no separate rehash needed once boundaries align). +> c. **Zero/all-zero speed.** `plan_task_sources` filters `candidates` to `speed > 0`; empty → `paused_external`/`no_source_speed`. No division by zero. +> d. **Active speed-test = controller-side probe** (controller does a small ranged GET per source via the driver, writes `SourceSpeedSample`). Per-executor probe-through-proxy (doc §1.8 阶段A full N×M matrix) is **deferred to v2.1**. Passive completion samples + EWMA still apply. The §3.5-exit "≥2x"/"5×4 ≤8s" targets are interpreted against controller-relative source speed. +> e. **`source_strategy`/`source_blacklist` ARE enforced** by the planner: `pin_huggingface`/`pin_modelscope`/`list:a,b`/`fastest_only`/`auto_balance` filter `enabled_ids()`; `task.source_blacklist` source IDs removed; explicit-pin source unreachable → `paused_external`. +> f. **Blacklist transitions: sha-mismatch (24h) ONLY in SP2.** 5xx×3-degrade and health-timeout transitions are **deferred to v2.1** (doc §1.7). Minimal rebalance acts on the sha-mismatch blacklist table (degraded source's pending chunks reassigned). Spec §1.2/§3.7/§7's 5xx/health rows are v2.1. +> g. **HF token parity.** `source_proxy` injects `settings.hf_token` for the HF/hf_mirror drivers — identical to the existing W3b `src/dlw/api/hf_proxy.py:75` behavior. True per-tenant HF token is a *pre-existing* gap (W3b never implemented SP1's aspirational claim); **out of SP2 scope**, tracked for a later sub-project. Not a regression. +> 7. **Known limitation (final-review HIGH, accepted for SP2 — recorded per reviewer guidance):** ruling 6b's executor-Range↔`subtask_chunks`-row alignment is **asserted but not wired** (the assignment payload doesn't carry chunk boundaries; the executor still splits by its local `chunk_size_bytes`). Consequence is **bounded and sha-safe**: for a file genuinely split across *different* sources, an executor Range spanning two routing-chunks is fetched from one source → the whole-file SHA256 W4 gate catches the mismatch → subtask `failed` + that non-HF source blacklisted 24h + HF-refetch. No silent corruption, no security/INVARIANT hole; single-source (incl. all-HF) chunked downloads are unaffected. The only impact is that multi-source *chunk-level* acceleration may fall back to safe HF-refetch instead of succeeding, so the §3.5 "≥2x" target applies to file-level LPT (the dominant lever), not chunk-level, in SP2. Full per-chunk-row Range alignment (chunk boundaries in the poll payload) is **deferred to a follow-up sub-project**. +> The companion plan embeds all of these; it is the execution source of truth. + +--- + +## 1. Goal & Scope + +### 1.1 Goal + +Make a task downloadable from multiple mirror sources in parallel, choosing the fastest combination per the executing fleet, while HuggingFace remains the cryptographic source of truth. + +**Mechanism.** Today a task is single-source HF: `dlw.services.hf_metadata.list_repo_tree(repo_id, revision)` enumerates files controller-side; the executor streams every file's bytes through the W3b reverse-proxy `GET /api/v1/hf-proxy/subtask/{id}` (which reconstructs the HF URL from `task.repo_id/revision + sub.filename` and injects the HF token). SP2 introduces a `SourceDriver` abstraction (the existing HF path becomes the `huggingface` driver), a `NameResolver` (the same model has different IDs per source), a **task `scheduling` phase** that resolves manifests across enabled sources, speed-probes each `(executor, source)`, computes a Longest-Processing-Time greedy file→source assignment (plus chunk-level split for big files), and a generalized `/api/v1/source-proxy` that streams each subtask/chunk from its assigned source's driver with that source's controller-side credential. After download the controller verifies every file's sha256 against HF's authoritative value; a mismatch blacklists `(source, repo, filename)` for 24h and re-fetches HF-only. + +After SP2, a single-source-HF deployment still works: with only `huggingface` enabled (or only one source covering a repo), the scheduler degrades to single-source and behavior is identical to today. + +### 1.2 In scope + +| Item | Where | +|---|---| +| `SourceDriver` Protocol + `SourceManifest`/`SourceFile`/`SourceToken`/`SourceHealth` | `src/dlw/sources/base.py` (new) | +| `huggingface` driver (refactor existing HF enumerate+stream behind the Protocol) | `src/dlw/sources/huggingface.py` (new); `hf_metadata.py` reused internally | +| `hf_mirror` driver (HF-compatible, base-URL swap, no token, auto-skip gated) | `src/dlw/sources/hf_mirror.py` (new) | +| `modelscope` driver (own API + name mapping; no sha256) | `src/dlw/sources/modelscope.py` (new) | +| Source registry — load `config/sources.yaml` → enabled `{id: driver}` | `src/dlw/sources/registry.py` (new) | +| `NameResolver` 3-tier (identity-orgs / alias-rules / API-search, 24h cache) | `src/dlw/sources/name_resolver.py` (new); `config/resolver-rules.yaml` (new) | +| Scheduling-phase planner: resolve→probe→LPT/chunk-plan→persist | `src/dlw/services/source_scheduler.py` (new) | +| Speed probe + EWMA fusion + optimal-combo selection | `src/dlw/services/source_speed.py` (new) | +| Blacklist (5xx degrade, sha-mismatch 24h, health-timeout) | `src/dlw/services/source_blacklist.py` (new) | +| Generalized multi-source reverse-proxy | `src/dlw/api/source_proxy.py` (new) | +| `SubtaskChunk` + `SourceSpeedSample` + `SourceBlacklist` models + 1 migration | `src/dlw/db/models/source.py` (new), `src/dlw/alembic/versions/_p3sp2_multi_source.py` (new) | +| `download_tasks.source_strategy`/`source_blacklist`/`trust_non_hf_sha256`; `file_subtasks.source_id` | same migration | +| Task scheduling-phase wiring + chunk-aware claim + HF-sha256 authority gate | `src/dlw/services/task_service.py`, `src/dlw/services/scheduler.py` | +| Executor chunk downloader → `/source-proxy`, per-chunk range | `src/dlw/executor/chunk_downloader.py`, `src/dlw/executor/client.py` | +| Lifespan: unconditional registry/resolver bootstrap + leader-gated rebalance loop | `src/dlw/main.py` | +| Config: source/probe/blacklist/rebalance settings + yaml paths | `src/dlw/config.py` | +| `source_strategy`/`source_blacklist`/`trust_non_hf_sha256` on `TaskCreate` | `src/dlw/api/tasks.py`, `src/dlw/schemas/task.py` | +| Operator note: sources.yaml / resolver-rules.yaml / multi-source ops | `docs/operator/multi-source.md` (new) | + +### 1.3 Non-goals (deferred — explicit) + +| Item | Where | +|---|---| +| wisemodel / opencsg / s3_mirror drivers + §1.10 plugin loader | v2.1+ (⚙️ default-off in doc 06 §1.2) | +| Phase-B continuous LP recalibration; skipped-source recovery re-admission | v2.1 (doc 06 §1.8 cont. — "v2.0 反应式简化版... v2.1 升级"); SP2 keeps only the minimal degraded→reassign-pending rebalance | +| `trust_non_hf_sha256` admin approval workflow | Phase 4 (doc 04 §5/§8) — SP2 honors the boolean only | +| UI source-allocation view (§1.11) | frontend sub-project | +| Incremental/diff download (§2), global dedup/refcount (§3.1) | **SP3** | +| CLI `dlw` / Python SDK (§5) | **SP4** | +| Webhook / MLflow / K8s Operator / HF-cache (§4) | Phase 4 | +| `estimate_cost` cost-knob optimization (05 §8) | Protocol method exists; no optimizer in SP2 | +| BLAKE3 streaming hash for multi-source chunk mode | v2.2 (doc §9) — SP2 uses whole-file SHA256 re-scan after chunk merge | + +--- + +## 2. Tech Stack Additions + +| Dep | Why | Notes | +|---|---|---| +| `modelscope` SDK — **NOT added** | ModelScope driver uses raw `httpx` against its documented REST API (doc §1.9.3) | avoids a heavy SDK dep; `httpx` already present | +| `pyyaml>=6,<7` (runtime) | parse `sources.yaml` / `resolver-rules.yaml` | small, ubiquitous; not currently a direct dep — add to `pyproject.toml` + `uv lock` | + +Reused (no new dep): `huggingface_hub` (HF + hf_mirror drivers — hf_mirror = `HfApi(endpoint="https://hf-mirror.com")`), `httpx` (ModelScope + proxy + probe), SQLAlchemy async, FastAPI, structlog, SP1's `Principal`/`require_perm`/`tenant_filtered`/casbin. + +**One alembic migration**, `down_revision = "a4bed702cdb3"` (SP1 head). No new CI jobs. The real CI gates (verified in SP1): `pytest` (`uv sync --all-groups`, uv 0.11.9), `invariant_lint` (`tools/lint_invariants.py` AST-scans `api/tasks.py`/`services/task_service.py`/`services/scheduler.py` for invalid status literals — SP2 adds the `scheduling` task status and `subtask_chunks` statuses; **`"scheduling"` MUST be added to `tools/lint_invariants.py`'s `VALID_TASK_STATUS`, and chunk-status literals must live in `source_scheduler.py`/`source_proxy.py` which are NOT scanned, OR a chunk-status set added — confirm by running `python tools/lint_invariants.py`**), `openapi` (spectral `--fail-severity=error` + swagger-cli), `yamllint` (`api/` + note: `config/*.yaml` is **not** in the yamllint scan path `deploy/ api/`, so sources.yaml/resolver-rules.yaml are not CI-yamllinted — keep them valid anyway). `ruff`/`mypy` are local-only (not CI). + +--- + +## 3. Components + +### 3.1 `src/dlw/sources/base.py` — the Protocol + +```python +class SourceDriver(Protocol): + id: str + domain: str + provides_sha256: bool # True only for huggingface / hf_mirror + + async def resolve(self, repo_id: str, revision: str + ) -> SourceManifest | None: ... + async def download_range(self, file: SourceFile, + byte_range: tuple[int, int] | None + ) -> AsyncIterator[bytes]: ... + async def health_check(self) -> SourceHealth: ... + def estimate_cost(self, n_bytes: int, region: str) -> Decimal: ... +``` + +`SourceManifest(source_id, repo_id_in_source, revision_in_source, files: list[SourceFile], has_lfs_sha256)`; `SourceFile(filename, size, sha256, download_ref)` (`filename` normalized to HF-style path so cross-source files key identically); `SourceHealth(ok: bool, latency_ms: float)`; `SourceToken` (opaque per-source cred handle, resolved controller-side — never serialized to the executor; INVARIANT 2). `resolve()` returns `None` if the source doesn't cover `(repo, revision)` — not an error. + +### 3.2 Drivers (`src/dlw/sources/{huggingface,hf_mirror,modelscope}.py`) + +- **`huggingface`**: `resolve` wraps the existing `hf_metadata.list_repo_tree` (reused, not rewritten) → `SourceManifest(has_lfs_sha256 from LFS sha)`. `download_range` builds `{hf_endpoint}/{repo}/resolve/{rev}/{filename}` with Range + `Authorization: Bearer ` (per-tenant via `task.tenant_id` — SP1 made this real). `provides_sha256=True`. +- **`hf_mirror`**: identical protocol, `endpoint="https://hf-mirror.com"`, **no token**; if `resolve` hits 401/403 (gated) → return `None` (auto-skip, doc §1.9.2). `provides_sha256=True` (mirror passes HF LFS sha). +- **`modelscope`**: raw `httpx`. `resolve` → `GET {base}/api/v1/models/{ms_repo}/repo?Revision={rev}` (`ms_repo` from NameResolver); files have **no sha256** (`provides_sha256=False`, `SourceFile.sha256=None`). `download_range` → `GET {base}/api/v1/models/{ms_repo}/repo?Revision={rev}&FilePath={filename}` + Range. + +### 3.3 `registry.py` + `config/sources.yaml` + +`load_registry(path) -> SourceRegistry` parses the doc §1.12 `sources.yaml` (subset: `id/enabled/driver/config/cost_per_gb_egress`), instantiates only `enabled` drivers among the 3 supported, exposes `enabled_ids()`, `get(id)`, and `regional_defaults`. Bootstrapped **unconditionally** in `main.py` lifespan into `app.state.source_registry` (mirrors SP1's `app.state.casbin`/`settings` — the SP1 final review proved state used by request paths MUST be set in lifespan, not only test fixtures; SP2 adds a `test_lifespan_state`-style assertion). + +### 3.4 `name_resolver.py` + `config/resolver-rules.yaml` + +`NameResolver.resolve(source_id, hf_repo_id) -> str | None`: (1) identity if org ∈ `identity_organizations` or `source_id == huggingface`; (2) `aliases`/`per_model_overrides` rules (org swap + `transform` template, doc §1.5); (3) source search-API fallback, result cached 24h (in-memory TTL dict; persistence deferred). Miss → `None` → that source skipped for this repo (not fatal). Bootstrapped into `app.state.name_resolver` alongside the registry. + +### 3.5 `services/source_scheduler.py` — the planner (task `scheduling` phase) + +**Trigger:** a **leader-gated scheduling loop** (new, reusing SP1's `_quota_loop`/sweep leader-gating in `main.py` — runs only on the active controller) polls `pending` tasks, transitions `pending → scheduling`, calls `plan_task_sources`, then moves the task into the existing claimable state (or `paused_external` on the HF-authority gate). `create_task` stays fast (no inline resolve/probe). `async def plan_task_sources(session, task) -> None`: + +1. **Resolve**: for each enabled source permitted by `task.source_strategy`/`source_blacklist` (and tenant policy), `NameResolver` → source repo id → `driver.resolve()`. Collect `{source_id: SourceManifest}`. **HF authority gate**: if `huggingface` manifest absent/unreachable and `not task.trust_non_hf_sha256` → task → `paused_external`, `last_error="no_sha256_authority"` (INVARIANT 13), return. +2. **Probe**: `source_speed.probe_matrix(eligible_executors, candidate_sources, a real ~probe_size_mb file)` → `{(exec,src): bytes/s}`, fused with `source_speed_samples` EWMA (α=0.3, live weight 0.7). Soft-deadline `probe_timeout_s`. +3. **Combine**: `_solve_optimal_combination` (doc §1.8) — evaluate fastest-K subsets with a +2%/extra-source overhead penalty; pick min-ETA combo. +4. **Assign**: `assign_files_lpt(files, combo_speeds)` (size-desc greedy, doc §1.6) sets `file_subtasks.source_id`. For each file `≥ chunk_level_min_file_mb` with ≥2 covering sources: split into `subtask_chunks` rows (`byte_start/byte_end/source_id/status='pending'`, speed-proportional, chunk-size-aligned) and mark the subtask `is_chunked`. +5. Persist; task → the existing claimable state. Files whose only source is non-HF and HF lacks sha256 → flagged "no multi-source acceleration" (single-source HF, doc §1.13 risk 1). + +### 3.6 `api/source_proxy.py` — generalized reverse-proxy + +`GET /api/v1/source-proxy/subtask/{subtask_id}` (+ `Range`): same W3b ownership chain (`require_executor_jwt` + assignment_token + epoch + confused-deputy guard — copied from `hf_proxy.py`), then: load `sub.source_id` (or, for a chunked subtask, the chunk's `source_id` from the `Range`→chunk lookup), get the driver from `app.state.source_registry`, resolve that source's controller-side token, `async for bytes in driver.download_range(file, range)` → `StreamingResponse` (header allowlist preserved). The W3b `/hf-proxy` route is **kept** (back-compat for any single-source path) but task downloads now target `/source-proxy`. INVARIANT 2: the source token never leaves the controller. + +### 3.7 Blacklist & failure (`services/source_blacklist.py`) + +`SourceBlacklist` table `(scope, source_id, repo_id, filename, until, reason)`. Transitions (doc §1.7): 5xx×3 on `(executor,source)` → degraded 5min (exp→30min, in-memory + sample table); sha256 mismatch on completion → `(source_id, repo_id, filename)` row, `until=now+24h`, that file re-planned HF-only; `health_check` >30s → source globally degraded until next OK probe. The scheduler/proxy consult the blacklist before assigning/streaming. + +### 3.8 `main.py` lifespan + leader-gated rebalance + +Unconditional (with SP1's settings/casbin block): `app.state.source_registry = load_registry(...)`, `app.state.name_resolver = NameResolver(...)`. Leader-gated (extend SP1's `_quota_loop` pattern with **two** holders — `scheduling_task_holder`, `rebalance_task_holder` — same `_on_active`/`_on_step_down` cancel wiring): `_scheduling_loop` (drives `pending → scheduling → plan_task_sources →` claimable, §3.5) and `_rebalance_loop` every `rebalance_interval_seconds` — for `downloading` tasks, detect degraded `(exec,source)` (probe<30% of plan, or 5xx-degraded), `UPDATE subtask_chunks SET source_id=, status='pending' WHERE source_id= AND status='pending'` via `FOR UPDATE SKIP LOCKED` (doc §1.8). In-flight chunks not interrupted; skipped-source recovery deferred. + +### 3.9 Config (`config.py`) + +```python + # Phase 3 SP2 — multi-source + sources_yaml_path: str = Field(default="config/sources.yaml") + resolver_rules_path: str = Field(default="config/resolver-rules.yaml") + probe_size_mb: int = Field(default=32, ge=1, le=256) + probe_timeout_s: float = Field(default=8.0, ge=1.0, le=60.0) + probe_history_weight: float = Field(default=0.3, ge=0.0, le=1.0) + combo_overhead_per_source_pct: float = Field(default=2.0, ge=0.0, le=50.0) + chunk_level_min_file_mb: int = Field(default=100, ge=1) + speed_ewma_alpha: float = Field(default=0.3, ge=0.0, le=1.0) + blacklist_5xx_count: int = Field(default=3, ge=1) + blacklist_minutes: int = Field(default=5, ge=1) + blacklist_max_minutes: int = Field(default=30, ge=1) + sha_mismatch_blacklist_hours: int = Field(default=24, ge=1) + rebalance_interval_seconds: float = Field(default=60.0, ge=5.0, le=600.0) + degradation_trigger_threshold: float = Field(default=0.3, ge=0.0, le=1.0) +``` + +--- + +## 4. Approaches Considered + +- **A — Driver-registry + scheduling-phase planner (chosen).** Central controller plans source→file/chunk at task `scheduling`; the generalized proxy streams per-assignment. Existing HF path becomes one driver (smallest blast radius); each unit (driver, resolver, planner, proxy, blacklist) testable in isolation with `httpx.MockTransport`/fake drivers; INVARIANT 2/11/13 each enforced in exactly one place. +- **B — Executor-side source selection.** Executor probes/picks its own source per file. Less controller coordination but: breaks central LPT/quota/blacklist consistency, and the executor would need source credentials → violates INVARIANT 2. Rejected. +- **C — Literal full doc §1 (6 drivers + plugin loader + Phase-B LP optimizer).** Matches the doc verbatim but is 2–3 milestones, most of it explicitly v2.1 in the doc itself. Rejected (YAGNI/scope; §1.3 defers it). + +--- + +## 5. Schema Changes + +One migration `_p3sp2_multi_source`, `down_revision = "a4bed702cdb3"`. + +**Altered:** +```sql +ALTER TABLE download_tasks ADD COLUMN source_strategy VARCHAR(32) NOT NULL DEFAULT 'auto_balance'; +ALTER TABLE download_tasks ADD COLUMN source_blacklist JSONB NOT NULL DEFAULT '[]'; +ALTER TABLE download_tasks ADD COLUMN trust_non_hf_sha256 BOOLEAN NOT NULL DEFAULT FALSE; +ALTER TABLE file_subtasks ADD COLUMN source_id VARCHAR(32); -- nullable: filled at scheduling +ALTER TABLE file_subtasks ADD COLUMN is_chunked BOOLEAN NOT NULL DEFAULT FALSE; +``` +**New tables** (doc §1.4/§1.7): +```sql +CREATE TABLE subtask_chunks ( + id BIGSERIAL PRIMARY KEY, + subtask_id UUID NOT NULL REFERENCES file_subtasks(id) ON DELETE CASCADE, + chunk_index INT NOT NULL, + byte_start BIGINT NOT NULL, + byte_end BIGINT NOT NULL, -- inclusive + source_id VARCHAR(32) NOT NULL, + status VARCHAR(16) NOT NULL, -- pending|downloading|done|failed + sha256_partial VARCHAR(64), + bytes_done BIGINT NOT NULL DEFAULT 0, + UNIQUE (subtask_id, chunk_index) +); +CREATE INDEX idx_chunk_sub_status ON subtask_chunks(subtask_id, status); +CREATE TABLE source_speed_samples ( + id BIGSERIAL PRIMARY KEY, + executor_id VARCHAR(64) NOT NULL, + source_id VARCHAR(32) NOT NULL, + measured_at TIMESTAMPTZ NOT NULL DEFAULT now(), + bytes_per_sec FLOAT NOT NULL, + sample_size BIGINT NOT NULL, + is_active_probe BOOLEAN NOT NULL DEFAULT FALSE +); +CREATE INDEX idx_speed_recent ON source_speed_samples(executor_id, source_id, measured_at DESC); +CREATE TABLE source_blacklist ( + id BIGSERIAL PRIMARY KEY, + source_id VARCHAR(32) NOT NULL, + repo_id VARCHAR(256), + filename VARCHAR(512), + until TIMESTAMPTZ NOT NULL, + reason VARCHAR(64) NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() +); +CREATE INDEX idx_blacklist_lookup ON source_blacklist(source_id, repo_id, until); +``` +All additive (existing rows get defaults); clean downgrade (drop new tables/columns, reverse order). Models registered in `src/dlw/db/models/__init__.py` (SP1 lesson — never `base.py`). These are quota/infra-style tables; Invariant 8 (tenant scoping) applies to *business* rows — `subtask_chunks` inherits scoping via its `file_subtasks`→`download_tasks.tenant_id` FK chain (queries go through `tenant_filtered` on the parent task, consistent with SP1); `source_speed_samples`/`source_blacklist` are operational, not tenant data (documented, like SP1's `casbin_rule`). + +## 6. Wire Format Changes + +- **New:** `GET /api/v1/source-proxy/subtask/{subtask_id}` (+ optional `Range`) — executor-auth (W3b chain), 200 stream / 403 `NOT_YOUR_SUBTASK` / 409 `STALE_ASSIGNMENT`|`EPOCH_MISMATCH` / 503 source unreachable / 502 `SOURCE_BLACKLISTED` (re-plan triggered). `/api/v1/hf-proxy/...` retained. +- **Changed:** `POST /api/v1/tasks` body (`TaskCreate`) gains optional `source_strategy` (default `auto_balance`; `pin_huggingface`|`pin_modelscope`|`list:...`|`fastest_only`), `source_blacklist: list[str]`, `trust_non_hf_sha256: bool`. `GET /api/v1/tasks/{id}` (`TaskDetail`) exposes `source_strategy` + per-subtask `source_id` + chunk summary. Task status domain gains `scheduling`. +- **Config:** the §3.9 `Settings` fields (all `DLW_`-prefixed). New files `config/sources.yaml`, `config/resolver-rules.yaml`. +- **OpenAPI:** `api/openapi.yaml` gains the source-proxy op + the new `TaskCreate`/`TaskDetail` fields + `scheduling` status enum value; must pass spectral `--fail-severity=error` + swagger-cli + yamllint(`api/`). + +## 7. Error Handling Matrix + +| Situation | Behaviour | +|---|---| +| HF unreachable, `trust_non_hf_sha256=false` | task → `paused_external`, `last_error="no_sha256_authority"` (INVARIANT 13); 5min retry | +| HF unreachable, `trust_non_hf_sha256=true` | proceed using other sources; no sha256 verification (flag honored; approval workflow Phase 4) | +| A source doesn't cover repo (`resolve→None`) / NameResolver miss | that source skipped; not fatal; logged | +| hf_mirror hits gated repo (401/403 on resolve) | hf_mirror auto-skipped (doc §1.9.2) | +| All sources probe = 0 | task → `paused_external`, 5min re-probe | +| Only 1 healthy source | single-source full-speed (v1.x behavior; no chunk split) | +| 5xx ×3 on `(executor,source)` | degraded 5min (exp→`blacklist_max_minutes`); planner avoids it | +| `health_check` > 30s | source globally degraded → fall back to HF until next OK probe | +| sha256 mismatch on completed file | `(source_id,repo_id,filename)` 24h blacklist; that file re-planned HF-only; subtask re-queued | +| Non-HF source, HF has no LFS sha256 for file | that file = single-source HF, no acceleration; UI-flag (doc §1.13 risk 1) | +| Chunked file, one chunk's source blacklisted mid-flight | rebalance loop reassigns that chunk's `pending` siblings; the failed chunk re-queued to a healthy source | +| Chunked file complete | whole-file SHA256 re-scan vs HF authority (multi-source can't stream-hash, 03 §6); mismatch → blacklist+re-plan | +| `source_strategy=pin_modelscope` but ModelScope down | task `paused_external` (respects explicit pin; no silent HF fallback) | +| Standby instance | rebalance loop not running (leader-gated); planner runs only on the active (scheduling happens on active) | +| Single-source-HF deployment (only `huggingface` enabled) | planner degrades to one source; identical to pre-SP2 behavior | + +## 8. Testing Strategy + +TDD throughout. Unit tests use `httpx.MockTransport` (driver HTTP) and fake `SourceDriver`s (planner/proxy) — no live mirrors. + +| Area | File | Cases | +|---|---|---| +| Protocol/dataclasses | `tests/sources/test_base.py` | manifest/file normalization, SourceToken never-serialized | +| HF driver | `tests/sources/test_hf_driver.py` | resolve via mocked `list_repo_tree` / download_range Range / provides_sha256 | +| hf_mirror driver | `tests/sources/test_hf_mirror_driver.py` | base-url swap / no token / gated→None | +| ModelScope driver | `tests/sources/test_modelscope_driver.py` | API shape / name-mapped repo / sha256=None / download Range | +| Registry | `tests/sources/test_registry.py` | sources.yaml parse / only-enabled / unknown driver ignored / regional_defaults | +| NameResolver | `tests/sources/test_name_resolver.py` | identity / alias transform (`Meta-{name}`) / per-model override / search fallback+24h cache / miss→None | +| Speed | `tests/services/test_source_speed.py` | EWMA fusion / probe soft-timeout / all-zero handling | +| LPT | `tests/services/test_lpt.py` | size-desc greedy / U-SRC-005 many-small / U-SRC-006 few-large / single-source degen | +| Combo | `tests/services/test_combo.py` | slow source excluded by overhead penalty / 1-source / monotone-stop | +| Planner | `tests/services/test_source_scheduler.py` | resolve→probe→assign persists source_id/subtask_chunks / HF-absent+!trust→paused (INV 13) / chunk split ≥100MB / non-HF+no-sha→single-source | +| Blacklist | `tests/services/test_source_blacklist.py` | 5xx×3 degrade+expiry / sha-mismatch 24h / health-timeout global | +| Proxy | `tests/api/test_source_proxy.py` | routes to assigned driver / cred injected / INVARIANT-2 (no source token in any executor-bound payload) / ownership chain / blacklisted→502 | +| Sha authority | `tests/services/test_sha_authority.py` | non-HF verified vs HF / mismatch→blacklist+HF-refetch / chunked→whole-file rescan | +| Migration | `tests/db/test_p3sp2_migration.py` | up creates tables+cols / down clean / additive defaults | +| Rebalance | `tests/services/test_rebalance.py` | degraded source's pending chunks reassigned (SKIP LOCKED) / in-flight untouched / leader-gated | +| Lifespan | `tests/test_lifespan_state.py` (extend) | `app.state.source_registry`/`name_resolver` set by real lifespan (SP1 regression-class) | +| E2E-002 | `tests/e2e/test_multi_source.py` | auto_balance picks fastest, skips slow source; 5xx failover; sha-mismatch blacklists; HF-down pauses (mirrors doc §8) | + +**Test infra (SP1 lessons, mandatory):** every DB test fixture uses `drop_all → create_all` (clean slate, session-DB collision avoidance) and is function-scoped where state must not bleed; `from dlw.db import models` import where `create_all` is used; new test dirs (`tests/sources/`) get `__init__.py`; the `make_app_with_state` conftest helper extended to also seed `app.state.source_registry`/`name_resolver` (so ASGI tests don't 500 — the SP1 CRITICAL-class pitfall), AND a `test_lifespan_state` assertion for the real lifespan. Subagent "pre-existing/passes-in-isolation" claims must be controller-verified. + +## 9. Acceptance Criteria + +- [ ] `SourceDriver` Protocol + 3 drivers (HF refactor / hf_mirror / ModelScope) with mocked-transport unit tests; `provides_sha256` correct per source. +- [ ] `NameResolver` 3-tier from `resolver-rules.yaml`; identity/alias/search+cache; miss→skip. +- [ ] Registry from `sources.yaml`; only enabled+supported drivers; bootstrapped in **lifespan** (+ `make_app_with_state` + `test_lifespan_state`). +- [ ] Scheduling-phase planner: resolve→probe→`_solve_optimal_combination`→LPT files + chunk-split ≥`chunk_level_min_file_mb`; persists `source_id`/`subtask_chunks`. +- [ ] HF sha256 authority enforced (INVARIANT 11/12/13): non-HF verified vs HF; HF-down→`paused_external` unless `trust_non_hf_sha256`; sha-mismatch→24h `(source,repo,filename)` blacklist + HF-refetch. +- [ ] `/api/v1/source-proxy` streams per assigned source, controller-side cred, INVARIANT 2 preserved; `/hf-proxy` retained. +- [ ] Chunk mode → whole-file SHA256 re-scan post-merge (03 §6). +- [ ] Minimal leader-gated rebalance: degraded source's pending chunks reassigned via SKIP LOCKED. +- [ ] One additive alembic migration (down_revision `a4bed702cdb3`); clean up/down; models in `db/models/__init__.py`; `scheduling`/chunk statuses accepted by `tools/lint_invariants.py`. +- [ ] `TaskCreate` gains `source_strategy`/`source_blacklist`/`trust_non_hf_sha256`; tenant-scoped via SP1 `require_perm`/`tenant_filtered` (no new RBAC). +- [ ] Full suite green; `invariant_lint`/`openapi`(spectral+swagger-cli)/`yamllint` CI gates green; `pyyaml` in `pyproject.toml` + `uv.lock` committed. +- [ ] `E2E-002` (`tests/e2e/test_multi_source.py`) passes; operator note written. + +## 10. Implementation Phasing (preview for plan) + +5 milestones, ~16–18 TDD tasks. + +- **M1 — Source layer.** `base.py` Protocol+dataclasses; HF/hf_mirror/ModelScope drivers; registry + `sources.yaml`; NameResolver + `resolver-rules.yaml`; config fields; `pyyaml` dep. Pure unit (mock transport). +- **M2 — Schema + models.** migration (cols + 3 tables) + `db/models/source.py` + `__init__.py` reg + `lint_invariants` status additions + migration test. +- **M3 — Planner.** `source_speed` (probe+EWMA), `_solve_optimal_combination`, `assign_files_lpt`, chunk-split, `source_scheduler.plan_task_sources` + HF-authority gate; scheduling-phase wiring in `task_service`/`scheduler`; blacklist service. +- **M4 — Proxy + executor + lifespan.** `source_proxy.py`; `chunk_downloader`/`client` → `/source-proxy` per-chunk range; sha-authority verify (incl. chunked whole-file rescan); lifespan registry/resolver bootstrap + leader-gated `_scheduling_loop` + `_rebalance_loop`; `make_app_with_state` + `test_lifespan_state` extension. +- **M5 — E2E + docs + PR.** `tests/e2e/test_multi_source.py` (E2E-002), OpenAPI updates, `docs/operator/multi-source.md`, full suite + all CI gates, final whole-impl security/correctness review, PR, squash-merge. + +Branch: `feat/phase-3-sp2-multi-source` (off `main` @ `fa08e6d`). + +## 11. References + +- Design: `docs/v2.0/06-platform-and-ecosystem.md` §1 (authoritative), §8 (E2E). +- Roadmap: `docs/v2.0/08-mvp-roadmap.md` §3 W2 + §3.5 exit. +- Invariants: `docs/v2.0/INVARIANTS.md` 11/12/13 (HF authority/blacklist/refuse), 8 (tenant scope, SP1). +- Current-state anchors: `src/dlw/services/hf_metadata.py` (`list_repo_tree`/`RepoFile`), `src/dlw/api/hf_proxy.py` (W3b proxy, ownership chain to copy), `src/dlw/executor/chunk_downloader.py` (range download), `src/dlw/services/scheduler.py`/`task_service.py` (status machine; `tools/lint_invariants.py` status domains), alembic head `a4bed702cdb3`. +- Predecessor: `docs/superpowers/specs/2026-05-18-phase-3-sp1-multi-tenancy-design.md` (per-tenant `task.tenant_id`, `Principal`/`require_perm`/`tenant_filtered`, lifespan-state lesson, leader-gated-loop pattern reused for rebalance). +- SP1 merged: https://github.com/l17728/modelpull/pull/15 (squash `fa08e6d`). diff --git a/pyproject.toml b/pyproject.toml index 39eadcb..4edaa22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "casbin>=1.36,<2.0", "pydantic>=2.9,<2.11", "pydantic-settings>=2.6,<2.7", + "pyyaml>=6,<7", "structlog>=24.4,<24.5", "httpx>=0.27,<0.28", "tenacity>=9.0,<10.0", diff --git a/src/dlw/alembic/versions/bb1dd2c45a12_p3sp2_multi_source.py b/src/dlw/alembic/versions/bb1dd2c45a12_p3sp2_multi_source.py new file mode 100644 index 0000000..74d5d4b --- /dev/null +++ b/src/dlw/alembic/versions/bb1dd2c45a12_p3sp2_multi_source.py @@ -0,0 +1,95 @@ +"""p3sp2 multi source + +Revision ID: bb1dd2c45a12 +Revises: a4bed702cdb3 +Create Date: 2026-05-19 01:22:07.324751 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision: str = 'bb1dd2c45a12' +down_revision: Union[str, None] = 'a4bed702cdb3' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column("download_tasks", sa.Column( + "source_strategy", sa.String(32), nullable=False, + server_default="auto_balance")) + op.add_column("download_tasks", sa.Column( + "source_blacklist", postgresql.JSONB(), nullable=False, + server_default="[]")) + op.add_column("download_tasks", sa.Column( + "trust_non_hf_sha256", sa.Boolean(), nullable=False, + server_default=sa.false())) + op.add_column("file_subtasks", sa.Column( + "source_id", sa.String(32), nullable=True)) + op.add_column("file_subtasks", sa.Column( + "is_chunked", sa.Boolean(), nullable=False, + server_default=sa.false())) + op.create_table( + "subtask_chunks", + sa.Column("id", sa.BigInteger(), primary_key=True), + sa.Column("subtask_id", postgresql.UUID(as_uuid=True), + sa.ForeignKey("file_subtasks.id", ondelete="CASCADE"), + nullable=False), + sa.Column("chunk_index", sa.Integer(), nullable=False), + sa.Column("byte_start", sa.BigInteger(), nullable=False), + sa.Column("byte_end", sa.BigInteger(), nullable=False), + sa.Column("source_id", sa.String(32), nullable=False), + sa.Column("status", sa.String(16), nullable=False), + sa.Column("sha256_partial", sa.String(64), nullable=True), + sa.Column("bytes_done", sa.BigInteger(), nullable=False, + server_default="0"), + sa.UniqueConstraint("subtask_id", "chunk_index"), + ) + op.create_index("idx_chunk_sub_status", "subtask_chunks", + ["subtask_id", "status"]) + op.create_table( + "source_speed_samples", + sa.Column("id", sa.BigInteger(), primary_key=True), + sa.Column("executor_id", sa.String(64), nullable=False), + sa.Column("source_id", sa.String(32), nullable=False), + sa.Column("measured_at", sa.DateTime(timezone=True), + server_default=sa.func.now(), nullable=False), + sa.Column("bytes_per_sec", sa.Float(), nullable=False), + sa.Column("sample_size", sa.BigInteger(), nullable=False), + sa.Column("is_active_probe", sa.Boolean(), nullable=False, + server_default=sa.false()), + ) + op.create_index("idx_speed_recent", "source_speed_samples", + ["executor_id", "source_id", "measured_at"]) + op.create_table( + "source_blacklist", + sa.Column("id", sa.BigInteger(), primary_key=True), + sa.Column("source_id", sa.String(32), nullable=False), + sa.Column("repo_id", sa.String(256), nullable=True), + sa.Column("filename", sa.String(512), nullable=True), + sa.Column("until", sa.DateTime(timezone=True), nullable=False), + sa.Column("reason", sa.String(64), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), + server_default=sa.func.now(), nullable=False), + ) + op.create_index("idx_blacklist_lookup", "source_blacklist", + ["source_id", "repo_id", "until"]) + + +def downgrade() -> None: + op.drop_index("idx_blacklist_lookup", "source_blacklist") + op.drop_table("source_blacklist") + op.drop_index("idx_speed_recent", "source_speed_samples") + op.drop_table("source_speed_samples") + op.drop_index("idx_chunk_sub_status", "subtask_chunks") + op.drop_table("subtask_chunks") + op.drop_column("file_subtasks", "is_chunked") + op.drop_column("file_subtasks", "source_id") + op.drop_column("download_tasks", "trust_non_hf_sha256") + op.drop_column("download_tasks", "source_blacklist") + op.drop_column("download_tasks", "source_strategy") diff --git a/src/dlw/api/source_proxy.py b/src/dlw/api/source_proxy.py new file mode 100644 index 0000000..e9b65b1 --- /dev/null +++ b/src/dlw/api/source_proxy.py @@ -0,0 +1,130 @@ +"""Generalized multi-source reverse-proxy (Phase 3 SP2). Mirrors the W3b +hf_proxy ownership chain; routes each subtask/chunk to its assigned +SourceDriver and injects that source's controller-side credential. The +source token NEVER leaves the controller (INVARIANT 2).""" +from __future__ import annotations + +import uuid + +import httpx +from fastapi import APIRouter, Depends, Header, HTTPException, Request +from fastapi.responses import StreamingResponse +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dlw.api.tasks import _session +from dlw.auth.executor_jwt_dep import require_executor_jwt +from dlw.config import get_settings +from dlw.db.models.executor import Executor +from dlw.db.models.source import SubtaskChunk +from dlw.db.models.task import DownloadTask, FileSubTask +from dlw.sources.base import SourceFile + +router = APIRouter(prefix="/api/v1/source-proxy", tags=["executors"]) + +_HDR_ALLOW = frozenset({ + "content-length", "content-range", "content-type", + "accept-ranges", "etag", +}) + + +def _make_source_client(timeout_seconds: int) -> httpx.AsyncClient: + """Test seam — monkeypatched to inject httpx.MockTransport.""" + return httpx.AsyncClient(follow_redirects=True, timeout=timeout_seconds) + + +@router.get("/subtask/{subtask_id}") +async def source_proxy_subtask( + subtask_id: uuid.UUID, + request: Request, + x_assignment_token: str = Header(..., alias="X-Assignment-Token"), + auth_ex: Executor = Depends(require_executor_jwt), + session: AsyncSession = Depends(_session), +) -> StreamingResponse: + sub = await session.get(FileSubTask, subtask_id) + if sub is None: + raise HTTPException(status_code=404, detail="subtask not found") + if sub.executor_id != auth_ex.id: + raise HTTPException( + status_code=403, + detail={"code": "NOT_YOUR_SUBTASK", + "subtask_executor": sub.executor_id, + "authenticated": auth_ex.id}, + ) + if sub.assignment_token is None or str(sub.assignment_token) != x_assignment_token: + raise HTTPException( + status_code=409, detail={"code": "STALE_ASSIGNMENT"}, + ) + if sub.executor_epoch != auth_ex.epoch: + raise HTTPException( + status_code=409, + detail={"code": "EPOCH_MISMATCH", + "expected": sub.executor_epoch, "got": auth_ex.epoch}, + ) + + task = await session.get(DownloadTask, sub.task_id) + if task is None: + raise HTTPException(status_code=500, detail="parent task missing") + + settings = get_settings() + range_header = request.headers.get("Range") + + source_id = sub.source_id + if sub.is_chunked and range_header and range_header.startswith("bytes="): + start = int(range_header.split("=", 1)[1].split("-", 1)[0]) + chunk = await session.scalar(select(SubtaskChunk).where( + SubtaskChunk.subtask_id == sub.id, + SubtaskChunk.byte_start <= start, + SubtaskChunk.byte_end >= start)) + if chunk is not None: + source_id = chunk.source_id + # Back-compat: a subtask that never went through SP2 source scheduling + # (legacy / single-source / ASGI-e2e paths) has source_id=None — default + # to the huggingface driver so /source-proxy is a strict superset of the + # old /hf-proxy behaviour (spec §1.1: single-source-HF still works). + if source_id is None: + source_id = "huggingface" + + registry = request.app.state.source_registry + drv = registry.get(source_id) + if drv is None: + raise HTTPException(status_code=502, detail={"code": "SOURCE_UNAVAILABLE"}) + + src_file = SourceFile(filename=sub.filename, size=sub.file_size, + sha256=sub.expected_sha256, + download_ref=f"{task.repo_id}/resolve/" + f"{task.revision}/{sub.filename}") + url = drv.download_url(src_file) + tok = drv.auth_token(settings.hf_token) + headers: dict[str, str] = {} + if tok.scheme == "bearer" and tok.value: + headers["Authorization"] = f"Bearer {tok.value}" + if range_header: + headers["Range"] = range_header + + client = _make_source_client(settings.hf_proxy_timeout_seconds) + req = client.build_request("GET", url, headers=headers) + try: + resp = await client.send(req, stream=True) + except (httpx.TimeoutException, httpx.NetworkError) as e: + await client.aclose() + raise HTTPException( + status_code=503, detail=f"source unreachable: {e}", + ) from e + except BaseException: + await client.aclose() + raise + + fwd = {k: v for k, v in resp.headers.items() + if k.lower() in _HDR_ALLOW} + + async def _body(): + try: + async for chunk in resp.aiter_bytes(64 * 1024): + yield chunk + finally: + await resp.aclose() + await client.aclose() + + return StreamingResponse(_body(), status_code=resp.status_code, + headers=fwd) diff --git a/src/dlw/config.py b/src/dlw/config.py index 452378d..81c9901 100644 --- a/src/dlw/config.py +++ b/src/dlw/config.py @@ -56,6 +56,22 @@ class Settings(BaseSettings): ) auth_tenant_rules_json: str = Field(default="[]") + # Phase 3 SP2 — multi-source + sources_yaml_path: str = Field(default="config/sources.yaml") + resolver_rules_path: str = Field(default="config/resolver-rules.yaml") + probe_size_mb: int = Field(default=32, ge=1, le=256) + probe_timeout_s: float = Field(default=8.0, ge=1.0, le=60.0) + probe_history_weight: float = Field(default=0.3, ge=0.0, le=1.0) + combo_overhead_per_source_pct: float = Field(default=2.0, ge=0.0, le=50.0) + chunk_level_min_file_mb: int = Field(default=100, ge=1) + speed_ewma_alpha: float = Field(default=0.3, ge=0.0, le=1.0) + blacklist_5xx_count: int = Field(default=3, ge=1) + blacklist_minutes: int = Field(default=5, ge=1) + blacklist_max_minutes: int = Field(default=30, ge=1) + sha_mismatch_blacklist_hours: int = Field(default=24, ge=1) + rebalance_interval_seconds: float = Field(default=60.0, ge=5.0, le=600.0) + degradation_trigger_threshold: float = Field(default=0.3, ge=0.0, le=1.0) + @property def db_url(self) -> str: auth = f"{self.db_user}:{self.db_password}" if self.db_password else self.db_user diff --git a/src/dlw/db/models/__init__.py b/src/dlw/db/models/__init__.py index 2ae0618..1f7a432 100644 --- a/src/dlw/db/models/__init__.py +++ b/src/dlw/db/models/__init__.py @@ -4,6 +4,7 @@ from dlw.db.models.casbin_rule import CasbinRule from dlw.db.models.executor import Executor from dlw.db.models.executor_status_history import ExecutorStatusHistory +from dlw.db.models.source import SourceBlacklist, SourceSpeedSample, SubtaskChunk from dlw.db.models.storage import StorageBackend from dlw.db.models.task import DownloadTask, FileSubTask from dlw.db.models.tenant import Project, Tenant, User @@ -11,5 +12,6 @@ __all__ = [ "AuditLog", "CasbinRule", "DownloadTask", "Executor", "ExecutorStatusHistory", - "FileSubTask", "Project", "QuotaSnapshot", "StorageBackend", "Tenant", "UsageRecord", "User", + "FileSubTask", "Project", "QuotaSnapshot", "SourceBlacklist", "SourceSpeedSample", + "StorageBackend", "SubtaskChunk", "Tenant", "UsageRecord", "User", ] diff --git a/src/dlw/db/models/source.py b/src/dlw/db/models/source.py new file mode 100644 index 0000000..638d6eb --- /dev/null +++ b/src/dlw/db/models/source.py @@ -0,0 +1,66 @@ +"""Multi-source models (Phase 3 SP2; doc 06 §1.4/§1.7).""" +from __future__ import annotations + +import uuid +from datetime import datetime + +from sqlalchemy import ( + BigInteger, + DateTime, + Float, + ForeignKey, + Integer, + String, + UniqueConstraint, + func, +) +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from dlw.db.base import Base + + +class SubtaskChunk(Base): + __tablename__ = "subtask_chunks" + __table_args__ = (UniqueConstraint("subtask_id", "chunk_index"),) + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + subtask_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("file_subtasks.id", ondelete="CASCADE"), nullable=False) + chunk_index: Mapped[int] = mapped_column(Integer, nullable=False) + byte_start: Mapped[int] = mapped_column(BigInteger, nullable=False) + byte_end: Mapped[int] = mapped_column(BigInteger, nullable=False) + source_id: Mapped[str] = mapped_column(String(32), nullable=False) + status: Mapped[str] = mapped_column(String(16), nullable=False) + sha256_partial: Mapped[str | None] = mapped_column(String(64), nullable=True) + bytes_done: Mapped[int] = mapped_column(BigInteger, default=0, + nullable=False) + + +class SourceSpeedSample(Base): + __tablename__ = "source_speed_samples" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + executor_id: Mapped[str] = mapped_column(String(64), nullable=False) + source_id: Mapped[str] = mapped_column(String(32), nullable=False) + measured_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), nullable=False) + bytes_per_sec: Mapped[float] = mapped_column(Float, nullable=False) + sample_size: Mapped[int] = mapped_column(BigInteger, nullable=False) + is_active_probe: Mapped[bool] = mapped_column(default=False, + nullable=False) + + +class SourceBlacklist(Base): + __tablename__ = "source_blacklist" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + source_id: Mapped[str] = mapped_column(String(32), nullable=False) + repo_id: Mapped[str | None] = mapped_column(String(256), nullable=True) + filename: Mapped[str | None] = mapped_column(String(512), nullable=True) + until: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False) + reason: Mapped[str] = mapped_column(String(64), nullable=False) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), server_default=func.now(), nullable=False) diff --git a/src/dlw/db/models/task.py b/src/dlw/db/models/task.py index d249c31..e7b7d85 100644 --- a/src/dlw/db/models/task.py +++ b/src/dlw/db/models/task.py @@ -20,7 +20,7 @@ UniqueConstraint, func, ) -from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.orm import Mapped, mapped_column, relationship from dlw.db.base import Base @@ -55,13 +55,18 @@ class DownloadTask(Base): error_message: Mapped[str | None] = mapped_column(Text, nullable=True) trace_id: Mapped[str | None] = mapped_column(String(32), nullable=True) + # Phase 3 SP2: multi-source columns + source_strategy: Mapped[str] = mapped_column(String(32), default="auto_balance", nullable=False) + source_blacklist: Mapped[list] = mapped_column(JSONB, default=list, nullable=False) + trust_non_hf_sha256: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + # ORM relationship — Phase 1 Week 3 UI scaffold consumes via # selectinload(DownloadTask.subtasks) in api/tasks.get_task. # Cascade is intentionally narrow: the FK already has ondelete=CASCADE, # so DB-level cleanup handles deletion. Adding ORM-level "delete-orphan" # would risk scheduling orphan deletes if any code path triggers a lazy # load of an empty in-memory subtasks collection on a flushed parent. - subtasks: Mapped[list["FileSubTask"]] = relationship( + subtasks: Mapped[list[FileSubTask]] = relationship( "FileSubTask", back_populates="task", cascade="save-update, merge", @@ -121,12 +126,16 @@ class FileSubTask(Base): retry_count: Mapped[int] = mapped_column(Integer, default=0, nullable=False) last_error: Mapped[str | None] = mapped_column(Text, nullable=True) + # Phase 3 SP2: multi-source columns + source_id: Mapped[str | None] = mapped_column(String(32), nullable=True) + is_chunked: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) - task: Mapped["DownloadTask"] = relationship( + task: Mapped[DownloadTask] = relationship( "DownloadTask", back_populates="subtasks", lazy="select", diff --git a/src/dlw/executor/chunk_downloader.py b/src/dlw/executor/chunk_downloader.py index ab00608..6e67f06 100644 --- a/src/dlw/executor/chunk_downloader.py +++ b/src/dlw/executor/chunk_downloader.py @@ -71,7 +71,7 @@ async def _resolve_size(self, a: Assignment) -> Assignment: """W3b: the proxy is GET-only, so probe size with a bytes=0-0 range request and read Content-Range (`bytes 0-0/`). Fall back to Content-Length if HF answered a full 200 instead of a 206.""" - async with self._controller.stream_hf( + async with self._controller.stream_source( subtask_id=a.subtask_id, assignment_token=a.assignment_token, range_header="bytes=0-0", @@ -125,7 +125,7 @@ async def _download_one_chunk( self, a: Assignment, plan: ChunkPlan, dest_dir: Path, ) -> None: range_header = f"bytes={plan.offset}-{plan.offset + plan.length - 1}" - async with self._controller.stream_hf( + async with self._controller.stream_source( subtask_id=a.subtask_id, assignment_token=a.assignment_token, range_header=range_header, diff --git a/src/dlw/executor/client.py b/src/dlw/executor/client.py index 6875e04..648c539 100644 --- a/src/dlw/executor/client.py +++ b/src/dlw/executor/client.py @@ -30,7 +30,6 @@ from dlw.auth.hmac_nonce import compute_hmac from dlw.executor.auth_lifecycle import AuthState - _retry = retry( retry=retry_if_exception_type( (httpx.HTTPStatusError, httpx.TimeoutException, httpx.NetworkError) @@ -236,3 +235,28 @@ async def stream_hf( headers=headers, ) as resp: yield resp + + @asynccontextmanager + async def stream_source( + self, + *, + subtask_id: uuid.UUID, + assignment_token: uuid.UUID, + range_header: str | None = None, + ) -> AsyncIterator[httpx.Response]: + """SP2: stream a file/chunk from its assigned source via the + controller's generalized reverse-proxy. Same contract as stream_hf + (caller inspects resp.status_code; no raise_for_status).""" + headers = { + **self._auth_headers(), + "X-Assignment-Token": str(assignment_token), + } + if range_header: + headers["Range"] = range_header + async with self._make_client() as client: + async with client.stream( + "GET", + f"/api/v1/source-proxy/subtask/{subtask_id}", + headers=headers, + ) as resp: + yield resp diff --git a/src/dlw/executor/downloader.py b/src/dlw/executor/downloader.py index 7ef913b..f040cbc 100644 --- a/src/dlw/executor/downloader.py +++ b/src/dlw/executor/downloader.py @@ -19,8 +19,12 @@ from dlw.executor._io import ( _HTTP_CHUNK_BYTES, _TRANSIENT_RETRY, - compose_key as _compose_key_io, make_s3_client, +) +from dlw.executor._io import ( + compose_key as _compose_key_io, +) +from dlw.executor._io import ( upload_part as _upload_part_io, ) from dlw.executor.client import ControllerClient @@ -68,7 +72,7 @@ async def _download_once(self, *, assignment: Assignment) -> DownloadResult: part_no = 1 try: - async with self._controller.stream_hf( + async with self._controller.stream_source( subtask_id=assignment.subtask_id, assignment_token=assignment.assignment_token, ) as resp: diff --git a/src/dlw/main.py b/src/dlw/main.py index 865a28d..7a94791 100644 --- a/src/dlw/main.py +++ b/src/dlw/main.py @@ -90,6 +90,13 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: _grants = await load_grants(_cas_session) app.state.casbin = build_enforcer(grants=_grants) + from dlw.sources.name_resolver import NameResolver + from dlw.sources.registry import load_registry + app.state.source_registry = load_registry( + _settings.sources_yaml_path, hf_token=_settings.hf_token) + app.state.name_resolver = NameResolver.from_file( + _settings.resolver_rules_path) + # W3c: controller state + leader loop. app.state.controller_state = "standby" shutdown = asyncio.Event() @@ -98,6 +105,8 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: ) sweep_task_holder: dict[str, asyncio.Task | None] = {"t": None} quota_task_holder: dict[str, asyncio.Task | None] = {"t": None} + sched_task_holder: dict[str, asyncio.Task | None] = {"t": None} + rebalance_task_holder: dict[str, asyncio.Task | None] = {"t": None} async def _quota_loop() -> None: from dlw.services.quota import aggregate_snapshots @@ -112,6 +121,34 @@ async def _quota_loop() -> None: except Exception: logger.exception("quota aggregator tick failed; retrying") + async def _scheduling_loop() -> None: + from dlw.services.source_scheduler import run_scheduling_tick + while True: + try: + await asyncio.sleep(5) + async with factory() as session: + await run_scheduling_tick( + session, app.state.source_registry, + app.state.name_resolver, _gs()) + await session.commit() + except asyncio.CancelledError: + raise + except Exception: + logger.exception("scheduling tick failed; retrying") + + async def _rebalance_loop() -> None: + from dlw.services.source_scheduler import run_rebalance_tick + while True: + try: + await asyncio.sleep(_gs().rebalance_interval_seconds) + async with factory() as session: + await run_rebalance_tick(session, _gs()) + await session.commit() + except asyncio.CancelledError: + raise + except Exception: + logger.exception("rebalance tick failed; retrying") + def _set_state(s: str) -> None: app.state.controller_state = s @@ -124,6 +161,8 @@ async def _on_promote() -> None: async def _on_active() -> None: sweep_task_holder["t"] = asyncio.create_task(_sweep_loop_main(factory)) quota_task_holder["t"] = asyncio.create_task(_quota_loop()) + sched_task_holder["t"] = asyncio.create_task(_scheduling_loop()) + rebalance_task_holder["t"] = asyncio.create_task(_rebalance_loop()) async def _on_step_down() -> None: t = sweep_task_holder["t"] @@ -142,6 +181,22 @@ async def _on_step_down() -> None: except (TimeoutError, asyncio.CancelledError): pass quota_task_holder["t"] = None + t = sched_task_holder["t"] + if t is not None: + t.cancel() + try: + await asyncio.wait_for(t, timeout=2) + except (TimeoutError, asyncio.CancelledError): + pass + sched_task_holder["t"] = None + t = rebalance_task_holder["t"] + if t is not None: + t.cancel() + try: + await asyncio.wait_for(t, timeout=2) + except (TimeoutError, asyncio.CancelledError): + pass + rebalance_task_holder["t"] = None leader_task = asyncio.create_task(run_leader_loop( elector=elector, @@ -213,6 +268,8 @@ def create_app() -> FastAPI: app.include_router(hf_proxy_router) from dlw.api.quota import router as quota_router app.include_router(quota_router) + from dlw.api.source_proxy import router as source_proxy_router + app.include_router(source_proxy_router) return app diff --git a/src/dlw/schemas/task.py b/src/dlw/schemas/task.py index f52de1a..39b9a93 100644 --- a/src/dlw/schemas/task.py +++ b/src/dlw/schemas/task.py @@ -16,6 +16,9 @@ class TaskCreate(BaseModel): storage_id: int = Field(gt=0) path_template: str = Field(default="{tenant}/{repo_id}/{revision}", max_length=512) priority: int = Field(default=1, ge=0, le=10) + source_strategy: str = Field(default="auto_balance", max_length=32) + source_blacklist: list[str] = Field(default_factory=list) + trust_non_hf_sha256: bool = Field(default=False) class TaskRead(BaseModel): diff --git a/src/dlw/services/scheduler.py b/src/dlw/services/scheduler.py index 36b6747..471acc7 100644 --- a/src/dlw/services/scheduler.py +++ b/src/dlw/services/scheduler.py @@ -20,6 +20,7 @@ from dlw.db.models.executor import Executor from dlw.db.models.task import DownloadTask, FileSubTask from dlw.services.quota import record_usage +from dlw.services.source_blacklist import blacklist_file async def claim_one_subtask( @@ -199,6 +200,13 @@ async def complete_subtask( parent = await session.get( DownloadTask, sub.task_id, with_for_update=True ) + if (final_status == "failed" and sub.source_id + and sub.source_id != "huggingface" + and sub.expected_sha256 is not None + and actual_sha256 != sub.expected_sha256): + await blacklist_file( + session, source_id=sub.source_id, repo_id=parent.repo_id, + filename=sub.filename, hours=24, reason="sha_mismatch") siblings = (await session.execute( select(FileSubTask).where(FileSubTask.task_id == sub.task_id) )).scalars().all() diff --git a/src/dlw/services/source_blacklist.py b/src/dlw/services/source_blacklist.py new file mode 100644 index 0000000..33e2fa2 --- /dev/null +++ b/src/dlw/services/source_blacklist.py @@ -0,0 +1,37 @@ +"""Source/(source,repo,file) blacklist (Phase 3 SP2; doc §1.7). +Caller commits (service-layer convention).""" +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dlw.db.models.source import SourceBlacklist + + +async def blacklist_file( + session: AsyncSession, *, source_id: str, repo_id: str, + filename: str, hours: int, reason: str, +) -> None: + session.add(SourceBlacklist( + source_id=source_id, repo_id=repo_id, filename=filename, + until=datetime.now(UTC) + timedelta(hours=hours), reason=reason)) + + +async def is_blacklisted( + session: AsyncSession, source_id: str, repo_id: str, filename: str +) -> bool: + row = await session.scalar( + select(SourceBlacklist.id).where( + SourceBlacklist.source_id == source_id, + SourceBlacklist.repo_id == repo_id, + SourceBlacklist.filename == filename, + SourceBlacklist.until > datetime.now(UTC)).limit(1)) + return row is not None + + +async def active_blacklisted_sources(session: AsyncSession) -> list[str]: + rows = await session.execute(select(SourceBlacklist.source_id).where( + SourceBlacklist.until > datetime.now(UTC)).distinct()) + return [r[0] for r in rows] diff --git a/src/dlw/services/source_combo.py b/src/dlw/services/source_combo.py new file mode 100644 index 0000000..84890b1 --- /dev/null +++ b/src/dlw/services/source_combo.py @@ -0,0 +1,46 @@ +"""File->source assignment: size-descending greedy heuristic (NOT bounded- +optimal LPT — doc OR-V21-04) + fastest-K combo with overhead penalty.""" +from __future__ import annotations + + +def assign_files_lpt( + files: dict[str, int], source_speeds: dict[str, float] +) -> dict[str, str]: + """files: {filename: size}; source_speeds: {source_id: bytes/sec}. + Returns {filename: source_id}. Largest-first; each file to the source + with the earliest projected completion (load+size)/speed.""" + load = {sid: 0.0 for sid in source_speeds} + out: dict[str, str] = {} + for fn in sorted(files, key=lambda k: -files[k]): + size = files[fn] + best = min(source_speeds, + key=lambda sid: (load[sid] + size) / source_speeds[sid]) + out[fn] = best + load[best] += size + return out + + +def _eta(files: dict[str, int], speeds: dict[str, float]) -> float: + assign = assign_files_lpt(files, speeds) + load = {sid: 0.0 for sid in speeds} + for fn, sid in assign.items(): + load[sid] += files[fn] + return max((load[sid] / speeds[sid] for sid in speeds), default=0.0) + + +def solve_optimal_combo( + source_speeds: dict[str, float], files: dict[str, int], + *, overhead_pct: float +) -> list[str]: + ranked = sorted(source_speeds, key=lambda s: -source_speeds[s]) + best_eta = float("inf") + best: list[str] = ranked[:1] + for k in range(1, len(ranked) + 1): + combo = ranked[:k] + sub = {s: source_speeds[s] for s in combo} + eta = _eta(files, sub) * (1 + 0.01 * overhead_pct * (k - 1)) + if eta < best_eta: + best_eta, best = eta, combo + elif k > 1 and eta > best_eta * 1.05: + break + return best diff --git a/src/dlw/services/source_scheduler.py b/src/dlw/services/source_scheduler.py new file mode 100644 index 0000000..16e3c26 --- /dev/null +++ b/src/dlw/services/source_scheduler.py @@ -0,0 +1,195 @@ +"""Task scheduling-phase source planner (Phase 3 SP2; doc §1.6/§1.8). +Caller commits.""" +from __future__ import annotations + +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dlw.db.models.source import SubtaskChunk +from dlw.db.models.task import DownloadTask, FileSubTask +from dlw.services.source_combo import assign_files_lpt, solve_optimal_combo + +_CHUNK_BYTES = 64 * 1024 * 1024 # source-routing chunk granularity + + +def _strategy_filter(enabled: list[str], strategy: str, + blacklist: list[str]) -> tuple[list[str], str | None]: + """Apply task.source_strategy + task.source_blacklist (spec ruling 6e). + Returns (allowed_ids, pinned_or_None). pinned!=None means an explicit + single-source pin that must be honored (pause if unreachable).""" + allowed = [s for s in enabled if s not in blacklist] + if strategy == "auto_balance" or not strategy: + return allowed, None + if strategy == "fastest_only": + return allowed, None + if strategy.startswith("pin_"): + pin = strategy.removeprefix("pin_") + return ([pin] if pin in allowed else []), pin + if strategy.startswith("list:"): + wanted = [x.strip() for x in strategy.removeprefix("list:").split(",")] + return [s for s in allowed if s in wanted], None + return allowed, None + + +async def plan_task_sources( + session: AsyncSession, task: DownloadTask, *, + registry: Any, resolver: Any, speeds: dict[str, float], + chunk_min_mb: int, overhead_pct: float = 2.0, +) -> None: + allowed, pinned = _strategy_filter( + registry.enabled_ids(), task.source_strategy or "auto_balance", + list(task.source_blacklist or [])) + + manifests: dict[str, Any] = {} + for sid in allowed: + drv = registry.get(sid) + src_repo = resolver.resolve(sid, task.repo_id) + if src_repo is None: + continue + m = await drv.resolve(src_repo, task.revision) + if m is not None: + manifests[sid] = (drv, m) + + if pinned is not None and pinned not in manifests: + task.status = "paused_external" + task.error_message = "pinned_source_unavailable" + return + + hf_ok = "huggingface" in manifests + if not hf_ok and not task.trust_non_hf_sha256: + task.status = "paused_external" + task.error_message = "no_sha256_authority" + return + + candidates = {sid: speeds[sid] for sid in manifests + if sid in speeds and speeds[sid] > 0} + if not candidates: + task.status = "paused_external" + task.error_message = "no_source_speed" + return + subs = (await session.execute(select(FileSubTask).where( + FileSubTask.task_id == task.id))).scalars().all() + sizes = {x.filename: (x.file_size or 0) for x in subs} + combo = solve_optimal_combo(candidates, sizes, overhead_pct=overhead_pct) + combo_speeds = {s: candidates[s] for s in combo} + + assign = assign_files_lpt(sizes, combo_speeds) + hf_files: set[str] = set() + if "huggingface" in manifests: + hf_files = {f.filename for f in manifests["huggingface"][1].files} + chunk_min = chunk_min_mb * 1024 * 1024 + for sub in subs: + no_hf_authority = (sub.expected_sha256 is None + or sub.filename not in hf_files) + if no_hf_authority and not task.trust_non_hf_sha256: + if "huggingface" not in manifests: + task.status = "paused_external" + task.error_message = "no_sha256_authority" + return + sub.source_id = "huggingface" + continue + sid = assign[sub.filename] + sub.source_id = sid + covering = [s for s in candidates + if any(f.filename == sub.filename + for f in manifests[s][1].files)] + if (sub.file_size or 0) >= chunk_min and len(covering) >= 2: + sub.is_chunked = True + cov_speeds = {s: candidates[s] for s in covering} + await _split_chunks(session, sub, sub.file_size, covering, + cov_speeds) + + +async def run_scheduling_tick(session, registry, resolver, settings) -> None: + """Pick `pending` tasks; controller-side probe each source; plan; move + to claimable (spec ruling 6d).""" + from dlw.db.models.source import SourceSpeedSample + from dlw.services.source_speed import ( + fuse_ewma, + pick_probe_size_bytes, + probe_source_speed, + ) + pend = (await session.execute(select(DownloadTask).where( + DownloadTask.status == "pending").limit(20))).scalars().all() + probe_bytes = pick_probe_size_bytes(probe_size_mb=settings.probe_size_mb) + for task in pend: + task.status = "scheduling" + speeds: dict[str, float] = {} + for sid in registry.enabled_ids(): + drv = registry.get(sid) + src_repo = resolver.resolve(sid, task.repo_id) + live = 0.0 + if src_repo is not None: + try: + m = await drv.resolve(src_repo, task.revision) + except Exception: + m = None + if m is not None and m.files: + probe_f = min(m.files, key=lambda f: f.size or 1 << 62) + live = await probe_source_speed( + drv, probe_f, probe_bytes=probe_bytes, + timeout_s=settings.probe_timeout_s, + hf_token=settings.hf_token) + hist = await session.scalar( + select(SourceSpeedSample.bytes_per_sec) + .where(SourceSpeedSample.source_id == sid) + .order_by(SourceSpeedSample.measured_at.desc()).limit(1)) + fused = fuse_ewma(live=live, hist=float(hist) if hist else None, + hist_weight=settings.probe_history_weight) + if live > 0: + session.add(SourceSpeedSample( + executor_id="controller", source_id=sid, + bytes_per_sec=live, sample_size=probe_bytes, + is_active_probe=True)) + speeds[sid] = fused if fused > 0 else 0.0 + await plan_task_sources( + session, task, registry=registry, resolver=resolver, + speeds=speeds, chunk_min_mb=settings.chunk_level_min_file_mb, + overhead_pct=settings.combo_overhead_per_source_pct) + if task.status == "scheduling": + task.status = "downloading" + + +async def run_rebalance_tick(session, settings) -> None: + """Reassign a degraded (blacklisted) source's PENDING chunks to a + healthy sibling source on the same subtask (in-flight untouched).""" + from sqlalchemy import text + + from dlw.services.source_blacklist import active_blacklisted_sources + bad = await active_blacklisted_sources(session) + for src in bad: + await session.execute(text( + "UPDATE subtask_chunks c SET source_id = (" + " SELECT source_id FROM subtask_chunks d " + " WHERE d.subtask_id=c.subtask_id AND d.source_id!=:bad " + " LIMIT 1) " + "WHERE c.source_id=:bad AND c.status='pending' " + "AND EXISTS (SELECT 1 FROM subtask_chunks e " + " WHERE e.subtask_id=c.subtask_id AND e.source_id!=:bad)" + ), {"bad": src}) + + +async def _split_chunks( + session: AsyncSession, sub: FileSubTask, size: int, + sources: list[str], speeds: dict[str, float], +) -> None: + total = sum(speeds[s] for s in sources) or 1.0 + offset = 0 + idx = 0 + for i, sid in enumerate(sources): + if i == len(sources) - 1: + length = size - offset + else: + portion = int(size * speeds[sid] / total) + length = max(_CHUNK_BYTES, + (portion // _CHUNK_BYTES) * _CHUNK_BYTES) + length = min(length, size - offset) + if length <= 0: + continue + session.add(SubtaskChunk( + subtask_id=sub.id, chunk_index=idx, byte_start=offset, + byte_end=offset + length - 1, source_id=sid, status="pending")) + offset += length + idx += 1 diff --git a/src/dlw/services/source_speed.py b/src/dlw/services/source_speed.py new file mode 100644 index 0000000..d34b8ca --- /dev/null +++ b/src/dlw/services/source_speed.py @@ -0,0 +1,48 @@ +"""Source speed: controller-side probe + EWMA fusion (Phase 3 SP2).""" +from __future__ import annotations + +import time +from typing import Any + +import httpx + + +def fuse_ewma(*, live: float, hist: float | None, + hist_weight: float) -> float: + if hist is None: + return live + return (1.0 - hist_weight) * live + hist_weight * hist + + +def pick_probe_size_bytes(*, probe_size_mb: int) -> int: + return probe_size_mb * 1024 * 1024 + + +async def probe_source_speed( + driver: Any, file: Any, *, probe_bytes: int, timeout_s: float, + hf_token: str | None, + transport: httpx.AsyncBaseTransport | None = None, +) -> float: + """One ranged GET (controller->source) timing bytes/sec. 0.0 on any + failure (effect: that source is treated as unavailable for this task).""" + url = driver.download_url(file) + tok = driver.auth_token(hf_token) + headers = {"Range": f"bytes=0-{max(0, probe_bytes - 1)}"} + if tok.scheme == "bearer" and tok.value: + headers["Authorization"] = f"Bearer {tok.value}" + try: + async with httpx.AsyncClient(timeout=timeout_s, transport=transport, + follow_redirects=True) as c: + start = time.monotonic() + recv = 0 + async with c.stream("GET", url, headers=headers) as resp: + if resp.status_code >= 400: + return 0.0 + async for buf in resp.aiter_bytes(64 * 1024): + recv += len(buf) + elapsed = time.monotonic() - start + if recv <= 0: + return 0.0 + return recv / max(elapsed, 1e-9) + except Exception: + return 0.0 diff --git a/src/dlw/sources/__init__.py b/src/dlw/sources/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/dlw/sources/base.py b/src/dlw/sources/base.py new file mode 100644 index 0000000..49b5c4f --- /dev/null +++ b/src/dlw/sources/base.py @@ -0,0 +1,54 @@ +"""SourceDriver abstraction (Phase 3 SP2; design doc 06 §1.3).""" +from __future__ import annotations + +from dataclasses import dataclass, field +from decimal import Decimal +from typing import Protocol, runtime_checkable + + +@dataclass(frozen=True) +class SourceFile: + filename: str # normalized HF-style path (cross-source key) + size: int | None + sha256: str | None # only HF / hf_mirror populate this + download_ref: str # source-specific URL or object key + + +@dataclass(frozen=True) +class SourceManifest: + source_id: str + repo_id_in_source: str + revision_in_source: str + files: list[SourceFile] + has_lfs_sha256: bool + + +@dataclass(frozen=True) +class SourceHealth: + ok: bool + latency_ms: float + + +@dataclass(frozen=True) +class SourceToken: + scheme: str # "bearer" | "none" + value: str = field(default="", repr=False) # never in repr/logs (INV 2) + + +@runtime_checkable +class SourceDriver(Protocol): + id: str + domain: str + provides_sha256: bool + + async def resolve( + self, repo_id: str, revision: str + ) -> SourceManifest | None: ... + + def download_url(self, file: SourceFile) -> str: ... + + def auth_token(self, tenant_hf_token: str | None) -> SourceToken: ... + + async def health_check(self) -> SourceHealth: ... + + def estimate_cost(self, n_bytes: int, region: str) -> Decimal: ... diff --git a/src/dlw/sources/hf_mirror.py b/src/dlw/sources/hf_mirror.py new file mode 100644 index 0000000..ca24198 --- /dev/null +++ b/src/dlw/sources/hf_mirror.py @@ -0,0 +1,58 @@ +"""hf-mirror.com SourceDriver — HF-compatible, no token, gated->skip (SP2).""" +from __future__ import annotations + +from decimal import Decimal + +from dlw.services.hf_metadata import ( + HfNetworkError, + HfPrivateOrAuthRequired, + RepoNotFound, + list_repo_tree, +) +from dlw.sources.base import ( + SourceFile, + SourceHealth, + SourceManifest, + SourceToken, +) + + +class HfMirrorDriver: + id = "hf_mirror" + domain = "hf-mirror.com" + provides_sha256 = True + + def __init__(self, *, base_url: str) -> None: + self._base = base_url.rstrip("/") + + async def resolve( + self, repo_id: str, revision: str + ) -> SourceManifest | None: + try: + files = await list_repo_tree( + repo_id, revision, hf_endpoint=self._base, hf_token=None) + except RepoNotFound: + return None + except HfPrivateOrAuthRequired: + return None + except HfNetworkError: + raise + sf = [SourceFile(filename=f.path, size=f.size, sha256=f.sha256, + download_ref=f"{repo_id}/resolve/{revision}/{f.path}") + for f in files] + return SourceManifest( + source_id=self.id, repo_id_in_source=repo_id, + revision_in_source=revision, files=sf, + has_lfs_sha256=any(f.sha256 for f in sf)) + + def download_url(self, file: SourceFile) -> str: + return f"{self._base}/{file.download_ref}" + + def auth_token(self, tenant_hf_token: str | None) -> SourceToken: + return SourceToken(scheme="none") + + async def health_check(self) -> SourceHealth: + return SourceHealth(ok=True, latency_ms=0.0) + + def estimate_cost(self, n_bytes: int, region: str) -> Decimal: + return Decimal(0) diff --git a/src/dlw/sources/huggingface.py b/src/dlw/sources/huggingface.py new file mode 100644 index 0000000..c4e909c --- /dev/null +++ b/src/dlw/sources/huggingface.py @@ -0,0 +1,60 @@ +"""HuggingFace SourceDriver — wraps the existing hf_metadata path (SP2).""" +from __future__ import annotations + +from decimal import Decimal + +from dlw.services.hf_metadata import ( + HfNetworkError, + HfPrivateOrAuthRequired, + RepoNotFound, + list_repo_tree, +) +from dlw.sources.base import ( + SourceFile, + SourceHealth, + SourceManifest, + SourceToken, +) + + +class HuggingFaceDriver: + id = "huggingface" + domain = "huggingface.co" + provides_sha256 = True + + def __init__(self, *, base_url: str, hf_token: str | None) -> None: + self._base = base_url.rstrip("/") + self._token = hf_token + + async def resolve( + self, repo_id: str, revision: str + ) -> SourceManifest | None: + try: + files = await list_repo_tree( + repo_id, revision, + hf_endpoint=self._base, hf_token=self._token) + except RepoNotFound: + return None + except (HfPrivateOrAuthRequired, HfNetworkError): + raise + sf = [SourceFile(filename=f.path, size=f.size, sha256=f.sha256, + download_ref=f"{repo_id}/resolve/{revision}/{f.path}") + for f in files] + return SourceManifest( + source_id=self.id, repo_id_in_source=repo_id, + revision_in_source=revision, files=sf, + has_lfs_sha256=any(f.sha256 for f in sf)) + + def download_url(self, file: SourceFile) -> str: + return f"{self._base}/{file.download_ref}" + + def auth_token(self, tenant_hf_token: str | None) -> SourceToken: + tok = tenant_hf_token or self._token + return (SourceToken(scheme="bearer", value=tok) if tok + else SourceToken(scheme="none")) + + async def health_check(self) -> SourceHealth: + return SourceHealth(ok=True, latency_ms=0.0) + + def estimate_cost(self, n_bytes: int, region: str) -> Decimal: + return Decimal("0.09") * Decimal(n_bytes) / Decimal(1_000_000_000) diff --git a/src/dlw/sources/modelscope.py b/src/dlw/sources/modelscope.py new file mode 100644 index 0000000..c113570 --- /dev/null +++ b/src/dlw/sources/modelscope.py @@ -0,0 +1,60 @@ +"""ModelScope SourceDriver — raw httpx, no official sha256 (SP2; doc §1.9.3).""" +from __future__ import annotations + +from decimal import Decimal +from urllib.parse import quote + +import httpx + +from dlw.sources.base import ( + SourceFile, + SourceHealth, + SourceManifest, + SourceToken, +) + + +class ModelScopeDriver: + id = "modelscope" + domain = "modelscope.cn" + provides_sha256 = False + + def __init__(self, *, base_url: str, + transport: httpx.AsyncBaseTransport | None = None) -> None: + self._base = base_url.rstrip("/") + self._transport = transport + + def _client(self) -> httpx.AsyncClient: + return httpx.AsyncClient(timeout=30, transport=self._transport) + + async def resolve( + self, repo_id: str, revision: str + ) -> SourceManifest | None: + url = f"{self._base}/api/v1/models/{repo_id}/repo?Revision={revision}" + async with self._client() as c: + r = await c.get(url) + if r.status_code == 404: + return None + r.raise_for_status() + data = r.json().get("Data", {}).get("Files", []) + sf = [SourceFile(filename=d["Path"], size=d.get("Size"), + sha256=None, + download_ref=f"{repo_id}|{revision}|{d['Path']}") + for d in data] + return SourceManifest( + source_id=self.id, repo_id_in_source=repo_id, + revision_in_source=revision, files=sf, has_lfs_sha256=False) + + def download_url(self, file: SourceFile) -> str: + repo, rev, path = file.download_ref.split("|", 2) + return (f"{self._base}/api/v1/models/{repo}/repo" + f"?Revision={rev}&FilePath={quote(path)}") + + def auth_token(self, tenant_hf_token: str | None) -> SourceToken: + return SourceToken(scheme="none") + + async def health_check(self) -> SourceHealth: + return SourceHealth(ok=True, latency_ms=0.0) + + def estimate_cost(self, n_bytes: int, region: str) -> Decimal: + return Decimal(0) diff --git a/src/dlw/sources/name_resolver.py b/src/dlw/sources/name_resolver.py new file mode 100644 index 0000000..04efa8d --- /dev/null +++ b/src/dlw/sources/name_resolver.py @@ -0,0 +1,50 @@ +"""3-tier source name resolution (Phase 3 SP2; doc §1.5). + +Tier 1 identity (HF, or org in identity_organizations); tier 2 alias / +per-model rules from resolver-rules.yaml; tier 3 source search-API (deferred +to a stub that returns None — wiring point for v2.1; cache scaffold present).""" +from __future__ import annotations + +from dataclasses import dataclass + +import yaml + + +@dataclass +class _Alias: + hf_org: str + ms_org: str + transform: str # e.g. "Meta-{name}" + + +class NameResolver: + def __init__(self, *, identity_orgs: set[str], aliases: list[_Alias], + overrides: dict[str, str]) -> None: + self._identity = identity_orgs + self._aliases = {a.hf_org: a for a in aliases} + self._overrides = overrides # "hf_repo" -> "src_repo" + self._search_cache: dict[tuple[str, str], str] = {} + + @classmethod + def from_file(cls, path: str) -> NameResolver: + with open(path, encoding="utf-8") as fh: + doc = yaml.safe_load(fh) or {} + aliases = [_Alias(a["hf_org"], a["modelscope_org"], a["transform"]) + for a in doc.get("aliases", [])] + overrides = {o["hf"]: o["modelscope"] + for o in doc.get("per_model_overrides", [])} + return cls(identity_orgs=set(doc.get("identity_organizations", [])), + aliases=aliases, overrides=overrides) + + def resolve(self, source_id: str, hf_repo_id: str) -> str | None: + if source_id == "huggingface" or source_id == "hf_mirror": + return hf_repo_id + if hf_repo_id in self._overrides: + return self._overrides[hf_repo_id] + org, _, name = hf_repo_id.partition("/") + if org in self._identity: + return hf_repo_id + a = self._aliases.get(org) + if a is not None: + return f"{a.ms_org}/{a.transform.format(name=name)}" + return self._search_cache.get((source_id, hf_repo_id)) diff --git a/src/dlw/sources/registry.py b/src/dlw/sources/registry.py new file mode 100644 index 0000000..0383785 --- /dev/null +++ b/src/dlw/sources/registry.py @@ -0,0 +1,57 @@ +"""sources.yaml -> enabled SourceDriver registry (Phase 3 SP2).""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import yaml + +from dlw.sources.base import SourceDriver +from dlw.sources.hf_mirror import HfMirrorDriver +from dlw.sources.huggingface import HuggingFaceDriver +from dlw.sources.modelscope import ModelScopeDriver + +_SUPPORTED = {"huggingface", "hf_mirror", "modelscope"} + + +@dataclass +class SourceRegistry: + _drivers: dict[str, SourceDriver] + regional_defaults: dict[str, list[str]] = field(default_factory=dict) + + def enabled_ids(self) -> list[str]: + return list(self._drivers.keys()) + + def get(self, source_id: str) -> SourceDriver | None: + return self._drivers.get(source_id) + + +def _build(driver: str, cfg: dict[str, Any], + hf_token: str | None) -> SourceDriver | None: + if driver == "huggingface": + return HuggingFaceDriver( + base_url=cfg.get("base_url", "https://huggingface.co"), + hf_token=hf_token) + if driver == "hf_mirror": + return HfMirrorDriver( + base_url=cfg.get("base_url", "https://hf-mirror.com")) + if driver == "modelscope": + return ModelScopeDriver( + base_url=cfg.get("base_url", "https://www.modelscope.cn")) + return None + + +def load_registry(path: str, *, hf_token: str | None) -> SourceRegistry: + with open(path, encoding="utf-8") as fh: + doc = yaml.safe_load(fh) or {} + drivers: dict[str, SourceDriver] = {} + for entry in doc.get("sources", []): + if not entry.get("enabled"): + continue + if entry.get("driver") not in _SUPPORTED: + continue + d = _build(entry["driver"], entry.get("config") or {}, hf_token) + if d is not None: + drivers[entry["id"]] = d + return SourceRegistry(_drivers=drivers, + regional_defaults=doc.get("regional_defaults", {})) diff --git a/tests/api/test_source_proxy.py b/tests/api/test_source_proxy.py new file mode 100644 index 0000000..6d51f9d --- /dev/null +++ b/tests/api/test_source_proxy.py @@ -0,0 +1,104 @@ +"""source-proxy routes to the assigned driver, INVARIANT 2 (SP2).""" +from __future__ import annotations + +import uuid + +import httpx +import pytest +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import async_sessionmaker + +import dlw.db.models # noqa: F401 +from dlw.db.base import Base +from tests.conftest import ( + executor_request_headers, + make_app_with_state, + register_test_executor, +) + +pytestmark = pytest.mark.slow + +SECRET = "unit-secret" + + +@pytest.fixture(autouse=True) +def _env(monkeypatch): + monkeypatch.setenv("DLW_SYSTEM_JWT_SECRET", SECRET) + monkeypatch.setenv("DLW_TLS_TRUSTED_PROXY", "1") + from dlw.config import get_settings + get_settings.cache_clear() + yield + get_settings.cache_clear() + + +@pytest.fixture +async def app_client(ephemeral_ca, engine, monkeypatch): + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + await c.run_sync(Base.metadata.create_all) + f = async_sessionmaker(engine, expire_on_commit=False) + from dlw.db.models.storage import StorageBackend + from dlw.db.models.tenant import Project, Tenant, User + async with f() as s: + s.add(Tenant(id=1, slug="t", display_name="T")) + await s.flush() + s.add_all([Project(id=1, tenant_id=1, name="d"), + User(id=1, tenant_id=1, oidc_subject="u", email="e", + role="tenant_operator"), + StorageBackend(id=1, tenant_id=1, name="s", + backend_type="s3", config_encrypted=b"")]) + await s.commit() + app = make_app_with_state(ephemeral_ca, enrollment_token="e") + + class _D: + id = "modelscope" + + def download_url(self, file): + return "https://www.modelscope.cn/x" + + def auth_token(self, t): + from dlw.sources.base import SourceToken + return SourceToken(scheme="none") + + class _Reg: + def get(self, sid): + return _D() if sid == "modelscope" else None + + app.state.source_registry = _Reg() + import dlw.api.source_proxy as sp + monkeypatch.setattr(sp, "_make_source_client", lambda _t: httpx.AsyncClient( + transport=httpx.MockTransport( + lambda r: httpx.Response(200, content=b"HELLO", + headers={"Content-Length": "5"})))) + async with AsyncClient(transport=ASGITransport(app=app), + base_url="http://test") as c: + yield app, c, f + # Clean up so this fixture's seeded Tenant(id=1) etc. don't leak into a + # later module's non-clean-slate _bootstrap (session-scoped DB). + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + + +async def test_proxy_streams_from_assigned_source(app_client): + app, client, f = app_client + from dlw.db.models.task import DownloadTask, FileSubTask + reg = await register_test_executor(client, enrollment_token="e") + async with f() as s: + t = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="downloading") + s.add(t) + await s.flush() + tok = uuid.uuid4() + sub = FileSubTask(task_id=t.id, tenant_id=1, filename="m", + file_size=5, status="assigned", + executor_id=reg["executor_id"], + executor_epoch=reg["epoch"], assignment_token=tok, + source_id="modelscope") + s.add(sub) + await s.commit() + sub_id = sub.id + h = {**executor_request_headers(reg), "X-Assignment-Token": str(tok)} + r = await client.get(f"/api/v1/source-proxy/subtask/{sub_id}", headers=h) + assert r.status_code == 200 + assert r.content == b"HELLO" diff --git a/tests/api/test_subtasks.py b/tests/api/test_subtasks.py index 16aab76..10d274a 100644 --- a/tests/api/test_subtasks.py +++ b/tests/api/test_subtasks.py @@ -25,6 +25,7 @@ async def _bootstrap(engine): from dlw.db.models.storage import StorageBackend from dlw.db.models.tenant import Project, Tenant, User async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.create_all) factory = async_sessionmaker(engine, expire_on_commit=False) async with factory() as s: diff --git a/tests/api/test_tasks.py b/tests/api/test_tasks.py index 823bf8a..e5f24d7 100644 --- a/tests/api/test_tasks.py +++ b/tests/api/test_tasks.py @@ -19,6 +19,7 @@ async def _bootstrap(engine): from dlw.db.models.tenant import Project, Tenant, User async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.create_all) factory = async_sessionmaker(engine, expire_on_commit=False) diff --git a/tests/api/test_tasks_tenant_scope.py b/tests/api/test_tasks_tenant_scope.py index 70c0116..b4ef6be 100644 --- a/tests/api/test_tasks_tenant_scope.py +++ b/tests/api/test_tasks_tenant_scope.py @@ -17,6 +17,7 @@ async def _bootstrap(engine): from dlw.db.models.storage import StorageBackend from dlw.db.models.tenant import Project, Tenant, User async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.create_all) factory = async_sessionmaker(engine, expire_on_commit=False) async with factory() as s: diff --git a/tests/conftest.py b/tests/conftest.py index 4d54ab5..a38da0f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -306,6 +306,21 @@ async def stream_hf(self, *, subtask_id, assignment_token, ) as resp: yield resp + @_acm + async def stream_source(self, *, subtask_id, assignment_token, + range_header=None): + headers = {"X-Assignment-Token": str(assignment_token)} + if range_header: + headers["Range"] = range_header + async with _httpx.AsyncClient( + transport=self._transport, base_url="http://fake-controller", + ) as client: + async with client.stream( + "GET", f"/api/v1/source-proxy/subtask/{subtask_id}", + headers=headers, + ) as resp: + yield resp + return _FakeControllerClient() @@ -342,6 +357,12 @@ def make_app_with_state( app.state.settings = _gs() from dlw.authz.enforcer import build_enforcer app.state.casbin = build_enforcer(grants=[]) + from dlw.sources.name_resolver import NameResolver + from dlw.sources.registry import load_registry + _s2 = app.state.settings + app.state.source_registry = load_registry( + _s2.sources_yaml_path, hf_token=_s2.hf_token) + app.state.name_resolver = NameResolver.from_file(_s2.resolver_rules_path) app.state.ca = ephemeral_ca["ca"] app.state.jwt_keypair = ephemeral_ca["jwt_keypair"] app.state.nonce_store = NonceStore(maxsize=1000, ttl_seconds=300) diff --git a/tests/db/test_alembic.py b/tests/db/test_alembic.py index 4e34230..d348fc7 100644 --- a/tests/db/test_alembic.py +++ b/tests/db/test_alembic.py @@ -59,7 +59,10 @@ def _find_uv() -> str: "file_subtasks", "projects", "quota_snapshots", + "source_blacklist", + "source_speed_samples", "storage_backends", + "subtask_chunks", "tenants", "usage_records", "users", diff --git a/tests/db/test_p3sp2_migration.py b/tests/db/test_p3sp2_migration.py new file mode 100644 index 0000000..8731b54 --- /dev/null +++ b/tests/db/test_p3sp2_migration.py @@ -0,0 +1,32 @@ +"""SP2 migration: 3 tables + task/subtask source columns.""" +from __future__ import annotations + +import pytest +from sqlalchemy import text + +import dlw.db.models # noqa: F401 + +pytestmark = pytest.mark.slow + + +async def test_tables_and_columns(engine): + from dlw.db.base import Base + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) + await conn.run_sync(Base.metadata.create_all) + names = {r[0] for r in await conn.execute(text( + "SELECT table_name FROM information_schema.tables " + "WHERE table_schema='public'"))} + assert {"subtask_chunks", "source_speed_samples", + "source_blacklist"} <= names + cols = {r[0] for r in await conn.execute(text( + "SELECT column_name FROM information_schema.columns " + "WHERE table_name='download_tasks'"))} + assert {"source_strategy", "source_blacklist", + "trust_non_hf_sha256"} <= cols + scols = {r[0] for r in await conn.execute(text( + "SELECT column_name FROM information_schema.columns " + "WHERE table_name='file_subtasks'"))} + assert {"source_id", "is_chunked"} <= scols + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) diff --git a/tests/e2e/test_executor_e2e.py b/tests/e2e/test_executor_e2e.py index 2815b9a..874413d 100644 --- a/tests/e2e/test_executor_e2e.py +++ b/tests/e2e/test_executor_e2e.py @@ -175,16 +175,18 @@ def hf_handler(request: httpx.Request) -> httpx.Response: downloader = HfS3StreamDownloader( settings=settings, client=executor_client, ) - # W3b: the executor fetches HF bytes through the controller's - # reverse proxy. Install the HF MockTransport on the controller's - # HF client factory (not the downloader). + # SP2: the executor now fetches via the generalized /source-proxy + # (a subtask with no SP2-scheduled source_id defaults to the + # huggingface driver — same HF URL as the old /hf-proxy). Install + # the HF MockTransport on the source-proxy outbound client factory + # (and keep the legacy hf-proxy seam patched for any direct use). import dlw.api.hf_proxy as _hf_proxy_mod - monkeypatch.setattr( - _hf_proxy_mod, "_make_hf_client", - lambda timeout_seconds: httpx.AsyncClient( - transport=hf_transport, follow_redirects=True, - ), + import dlw.api.source_proxy as _source_proxy_mod + _mk = lambda timeout_seconds: httpx.AsyncClient( # noqa: E731 + transport=hf_transport, follow_redirects=True, ) + monkeypatch.setattr(_hf_proxy_mod, "_make_hf_client", _mk) + monkeypatch.setattr(_source_proxy_mod, "_make_source_client", _mk) runner = ExecutorRunner( settings=settings, client=executor_client, diff --git a/tests/e2e/test_multi_source.py b/tests/e2e/test_multi_source.py new file mode 100644 index 0000000..d046554 --- /dev/null +++ b/tests/e2e/test_multi_source.py @@ -0,0 +1,113 @@ +"""E2E-002: auto_balance planning + HF-authority pause (Phase 3 SP2). + +End-to-end at the planner+DB level (no live mirrors): a task with HF + a +faster ModelScope-style fake source gets files assigned to the faster +source, and an HF-absent task without trust pauses (INVARIANT 13).""" +from __future__ import annotations + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import async_sessionmaker + +import dlw.db.models # noqa: F401 +from dlw.db.base import Base +from dlw.db.models.task import DownloadTask, FileSubTask +from dlw.services.source_scheduler import plan_task_sources +from dlw.sources.base import SourceFile, SourceManifest + +pytestmark = pytest.mark.slow + + +class _Drv: + def __init__(self, sid, files): + self.id = sid + self.provides_sha256 = sid in ("huggingface", "hf_mirror") + self._f = files + + async def resolve(self, repo, rev): + return SourceManifest(self.id, repo, rev, self._f, + has_lfs_sha256=any(f.sha256 for f in self._f)) + + +class _Reg: + def __init__(self, d): + self._d = d + + def enabled_ids(self): + return list(self._d) + + def get(self, s): + return self._d.get(s) + + +class _Id: + def resolve(self, sid, repo): + return repo + + +@pytest.fixture +async def factory(engine): + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + await c.run_sync(Base.metadata.create_all) + f = async_sessionmaker(engine, expire_on_commit=False) + from dlw.db.models.storage import StorageBackend + from dlw.db.models.tenant import Project, Tenant, User + async with f() as s: + s.add(Tenant(id=1, slug="t", display_name="T")) + await s.flush() + s.add_all([Project(id=1, tenant_id=1, name="d"), + User(id=1, tenant_id=1, oidc_subject="u", email="e", + role="tenant_operator"), + StorageBackend(id=1, tenant_id=1, name="s", + backend_type="s3", config_encrypted=b"")]) + await s.commit() + yield f + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + + +async def test_auto_balance_prefers_fast_source(factory): + files = [SourceFile("a.safetensors", 50, "a" * 64, "r"), + SourceFile("b.safetensors", 50, "b" * 64, "r")] + reg = _Reg({"huggingface": _Drv("huggingface", files), + "modelscope": _Drv("modelscope", files)}) + async with factory() as s: + t = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="scheduling") + s.add(t) + await s.flush() + for f in files: + s.add(FileSubTask(task_id=t.id, tenant_id=1, filename=f.filename, + file_size=f.size, expected_sha256=f.sha256, + status="pending")) + await s.commit() + await plan_task_sources(s, t, registry=reg, resolver=_Id(), + speeds={"huggingface": 50.0, + "modelscope": 5000.0}, + chunk_min_mb=100) + await s.commit() + subs = (await s.execute(select(FileSubTask).where( + FileSubTask.task_id == t.id))).scalars().all() + assert all(x.source_id == "modelscope" for x in subs) # HF too slow + + +async def test_hf_unavailable_pauses(factory): + files = [SourceFile("a", 10, None, "r")] + reg = _Reg({"modelscope": _Drv("modelscope", files)}) + async with factory() as s: + t = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="scheduling", + trust_non_hf_sha256=False) + s.add(t) + await s.flush() + s.add(FileSubTask(task_id=t.id, tenant_id=1, filename="a", + file_size=10, status="pending")) + await s.commit() + await plan_task_sources(s, t, registry=reg, resolver=_Id(), + speeds={"modelscope": 900.0}, chunk_min_mb=100) + await s.commit() + assert t.status == "paused_external" + assert t.error_message == "no_sha256_authority" diff --git a/tests/executor/test_stream_source.py b/tests/executor/test_stream_source.py new file mode 100644 index 0000000..fb46ee6 --- /dev/null +++ b/tests/executor/test_stream_source.py @@ -0,0 +1,34 @@ +"""ControllerClient.stream_source targets /source-proxy (Phase 3 SP2).""" +from __future__ import annotations + +import uuid + +import httpx + +from dlw.executor.client import ControllerClient +from tests.conftest import make_fake_auth_state + + +async def test_stream_source_hits_source_proxy(tmp_path): + seen = {} + + def handler(request: httpx.Request) -> httpx.Response: + seen["path"] = request.url.path + seen["range"] = request.headers.get("Range") + return httpx.Response(200, content=b"DATA") + + c = ControllerClient( + "http://ctrl", + auth_state=make_fake_auth_state(tmp_path), + _transport=httpx.MockTransport(handler)) + sid = uuid.uuid4() + tok = uuid.uuid4() + async with c.stream_source(subtask_id=sid, assignment_token=tok, + range_header="bytes=0-3") as resp: + assert resp.status_code == 200 + body = b"" + async for b in resp.aiter_bytes(): + body += b + assert body == b"DATA" + assert seen["path"] == f"/api/v1/source-proxy/subtask/{sid}" + assert seen["range"] == "bytes=0-3" diff --git a/tests/services/test_sha_authority.py b/tests/services/test_sha_authority.py new file mode 100644 index 0000000..5661f3f --- /dev/null +++ b/tests/services/test_sha_authority.py @@ -0,0 +1,87 @@ +"""Non-HF completion verified vs HF expected_sha256 → blacklist on mismatch.""" +from __future__ import annotations + +import uuid + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import async_sessionmaker + +import dlw.db.models # noqa: F401 +from dlw.db.base import Base +from dlw.db.models.source import SourceBlacklist +from dlw.db.models.task import DownloadTask, FileSubTask +from dlw.services.scheduler import complete_subtask + +pytestmark = pytest.mark.slow + + +@pytest.fixture +async def factory(engine): + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + await c.run_sync(Base.metadata.create_all) + f = async_sessionmaker(engine, expire_on_commit=False) + from dlw.db.models.storage import StorageBackend + from dlw.db.models.tenant import Project, Tenant, User + async with f() as s: + s.add(Tenant(id=1, slug="t", display_name="T")) + await s.flush() + s.add_all([Project(id=1, tenant_id=1, name="d"), + User(id=1, tenant_id=1, oidc_subject="u", email="e", + role="tenant_operator"), + StorageBackend(id=1, tenant_id=1, name="s", + backend_type="s3", config_encrypted=b"")]) + await s.commit() + yield f + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + + +async def test_non_hf_sha_mismatch_blacklists(factory): + async with factory() as s: + t = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="downloading") + s.add(t) + await s.flush() + tok = uuid.uuid4() + sub = FileSubTask(task_id=t.id, tenant_id=1, filename="m", + file_size=4, expected_sha256="c" * 64, + status="assigned", assignment_token=tok, + source_id="modelscope") + s.add(sub) + await s.flush() + sid = sub.id + done, _ = await complete_subtask( + s, sid, final_status="succeeded", actual_sha256="d" * 64, + bytes_downloaded=4, error=None, assignment_token=tok) + await s.commit() + assert done.status == "failed" + bl = (await s.execute(select(SourceBlacklist).where( + SourceBlacklist.source_id == "modelscope"))).scalars().all() + assert len(bl) == 1 and bl[0].filename == "m" + + +async def test_hf_source_mismatch_not_blacklisted(factory): + """A mismatch on the huggingface source itself must NOT blacklist HF.""" + async with factory() as s: + t = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="downloading") + s.add(t) + await s.flush() + tok = uuid.uuid4() + sub = FileSubTask(task_id=t.id, tenant_id=1, filename="m", + file_size=4, expected_sha256="c" * 64, + status="assigned", assignment_token=tok, + source_id="huggingface") + s.add(sub) + await s.flush() + sid = sub.id + await complete_subtask( + s, sid, final_status="succeeded", actual_sha256="d" * 64, + bytes_downloaded=4, error=None, assignment_token=tok) + await s.commit() + bl = (await s.execute(select(SourceBlacklist))).scalars().all() + assert bl == [] diff --git a/tests/services/test_source_blacklist.py b/tests/services/test_source_blacklist.py new file mode 100644 index 0000000..195fef9 --- /dev/null +++ b/tests/services/test_source_blacklist.py @@ -0,0 +1,48 @@ +"""Source blacklist transitions (Phase 3 SP2; doc §1.7).""" +from __future__ import annotations + +from datetime import UTC, datetime, timedelta + +import pytest +from sqlalchemy.ext.asyncio import async_sessionmaker + +import dlw.db.models # noqa: F401 +from dlw.db.base import Base +from dlw.db.models.source import SourceBlacklist +from dlw.services.source_blacklist import ( + blacklist_file, + is_blacklisted, +) + +pytestmark = pytest.mark.slow + + +@pytest.fixture +async def factory(engine): + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + await c.run_sync(Base.metadata.create_all) + yield async_sessionmaker(engine, expire_on_commit=False) + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + + +async def test_blacklist_and_check(factory): + async with factory() as s: + await blacklist_file(s, source_id="modelscope", repo_id="o/r", + filename="m.safetensors", hours=24, + reason="sha_mismatch") + await s.commit() + assert await is_blacklisted(s, "modelscope", "o/r", + "m.safetensors") is True + assert await is_blacklisted(s, "modelscope", "o/r", + "other.bin") is False + + +async def test_expired_not_blacklisted(factory): + async with factory() as s: + s.add(SourceBlacklist(source_id="modelscope", repo_id="o/r", + filename="m", reason="x", + until=datetime.now(UTC) - timedelta(hours=1))) + await s.commit() + assert await is_blacklisted(s, "modelscope", "o/r", "m") is False diff --git a/tests/services/test_source_combo.py b/tests/services/test_source_combo.py new file mode 100644 index 0000000..26cdbfc --- /dev/null +++ b/tests/services/test_source_combo.py @@ -0,0 +1,31 @@ +"""LPT greedy + optimal-combo (Phase 3 SP2; doc §1.6/§1.8, OR-V21-04).""" +from __future__ import annotations + +from dlw.services.source_combo import assign_files_lpt, solve_optimal_combo + + +def test_lpt_balances_by_completion_time(): + files = {"a": 100, "b": 100, "c": 50} + speeds = {"s1": 10.0, "s2": 5.0} + assign = assign_files_lpt(files, speeds) + assert set(assign.values()) <= {"s1", "s2"} + assert assign["a"] == "s1" + + +def test_lpt_single_source_degenerate(): + assign = assign_files_lpt({"a": 1, "b": 2}, {"only": 7.0}) + assert assign == {"a": "only", "b": "only"} + + +def test_combo_excludes_slow_source_by_overhead(): + files = {"f": 1_000_000_000} + speeds = {"fast": 1_000_000_000.0, "slow": 1.0} + combo = solve_optimal_combo(speeds, files, overhead_pct=2.0) + assert combo == ["fast"] + + +def test_combo_uses_both_when_comparable(): + files = {"a": 100, "b": 100} + speeds = {"s1": 10.0, "s2": 10.0} + combo = solve_optimal_combo(speeds, files, overhead_pct=2.0) + assert set(combo) == {"s1", "s2"} diff --git a/tests/services/test_source_scheduler.py b/tests/services/test_source_scheduler.py new file mode 100644 index 0000000..62bff45 --- /dev/null +++ b/tests/services/test_source_scheduler.py @@ -0,0 +1,171 @@ +"""plan_task_sources: resolve→assign→persist + HF-authority gate (SP2).""" +from __future__ import annotations + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import async_sessionmaker + +import dlw.db.models # noqa: F401 +from dlw.db.base import Base +from dlw.db.models.source import SubtaskChunk +from dlw.db.models.task import DownloadTask, FileSubTask +from dlw.services.source_scheduler import plan_task_sources +from dlw.sources.base import SourceFile, SourceManifest + +pytestmark = pytest.mark.slow + + +class _FakeDriver: + def __init__(self, sid, files, sha): + self.id = sid + self.provides_sha256 = sha + self._files = files + + async def resolve(self, repo_id, revision): + return SourceManifest(self.id, repo_id, revision, self._files, + has_lfs_sha256=any( + f.sha256 for f in self._files)) + + +class _FakeReg: + def __init__(self, drivers): + self._d = drivers + + def enabled_ids(self): + return list(self._d) + + def get(self, sid): + return self._d.get(sid) + + +class _IdResolver: + def resolve(self, source_id, hf_repo_id): + return hf_repo_id + + +@pytest.fixture +async def factory(engine): + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + await c.run_sync(Base.metadata.create_all) + f = async_sessionmaker(engine, expire_on_commit=False) + from dlw.db.models.storage import StorageBackend + from dlw.db.models.tenant import Project, Tenant, User + async with f() as s: + s.add(Tenant(id=1, slug="t", display_name="T")) + await s.flush() + s.add_all([Project(id=1, tenant_id=1, name="d"), + User(id=1, tenant_id=1, oidc_subject="u", email="e", + role="tenant_operator"), + StorageBackend(id=1, tenant_id=1, name="s", + backend_type="s3", config_encrypted=b"")]) + await s.commit() + yield f + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) + + +def _files(): + return [SourceFile("model.safetensors", 200 * 1024 * 1024, "a" * 64, + "ref"), + SourceFile("config.json", 10, None, "ref2")] + + +async def test_plan_assigns_and_persists(factory): + async with factory() as s: + task = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="scheduling") + s.add(task) + await s.flush() + for f in _files(): + s.add(FileSubTask(task_id=task.id, tenant_id=1, filename=f.filename, + file_size=f.size, expected_sha256=f.sha256, + status="pending")) + await s.commit() + reg = _FakeReg({"huggingface": _FakeDriver("huggingface", _files(), + True), + "modelscope": _FakeDriver("modelscope", _files(), + False)}) + await plan_task_sources( + s, task, registry=reg, resolver=_IdResolver(), + speeds={"huggingface": 50.0, "modelscope": 900.0}, + chunk_min_mb=100) + await s.commit() + subs = (await s.execute(select(FileSubTask).where( + FileSubTask.task_id == task.id))).scalars().all() + assert all(x.source_id in {"huggingface", "modelscope"} for x in subs) + big = next(x for x in subs if x.filename == "model.safetensors") + assert big.is_chunked is True + chunks = (await s.execute(select(SubtaskChunk).where( + SubtaskChunk.subtask_id == big.id))).scalars().all() + assert len(chunks) >= 2 + + +async def test_hf_absent_pauses_when_not_trusted(factory): + async with factory() as s: + task = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="scheduling", + trust_non_hf_sha256=False) + s.add(task) + await s.flush() + s.add(FileSubTask(task_id=task.id, tenant_id=1, filename="c.json", + file_size=10, expected_sha256=None, + status="pending")) + await s.commit() + reg = _FakeReg({"modelscope": _FakeDriver("modelscope", _files(), + False)}) + await plan_task_sources(s, task, registry=reg, resolver=_IdResolver(), + speeds={"modelscope": 900.0}, chunk_min_mb=100) + await s.commit() + assert task.status == "paused_external" + assert task.error_message == "no_sha256_authority" + + +async def test_no_sha_file_pinned_to_huggingface(factory): + """INVARIANT 12 (spec ruling 6a): a file with expected_sha256=None must + stay on huggingface even when a faster non-HF source covers it.""" + async with factory() as s: + task = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="scheduling") + s.add(task) + await s.flush() + s.add(FileSubTask(task_id=task.id, tenant_id=1, + filename="config.json", file_size=10, + expected_sha256=None, status="pending")) + await s.commit() + reg = _FakeReg({"huggingface": _FakeDriver("huggingface", _files(), + True), + "modelscope": _FakeDriver("modelscope", _files(), + False)}) + await plan_task_sources(s, task, registry=reg, resolver=_IdResolver(), + speeds={"huggingface": 1.0, + "modelscope": 9000.0}, + chunk_min_mb=100) + await s.commit() + sub = (await s.execute(select(FileSubTask).where( + FileSubTask.task_id == task.id))).scalar_one() + assert sub.source_id == "huggingface" and sub.is_chunked is False + + +async def test_pin_modelscope_unreachable_pauses(factory): + async with factory() as s: + task = DownloadTask(tenant_id=1, project_id=1, owner_user_id=1, + repo_id="o/r", revision="abc", storage_id=1, + path_template="t", status="scheduling", + source_strategy="pin_modelscope") + s.add(task) + await s.flush() + s.add(FileSubTask(task_id=task.id, tenant_id=1, filename="m", + file_size=10, expected_sha256="a" * 64, + status="pending")) + await s.commit() + reg = _FakeReg({"huggingface": _FakeDriver("huggingface", _files(), + True)}) + await plan_task_sources(s, task, registry=reg, resolver=_IdResolver(), + speeds={"huggingface": 50.0}, chunk_min_mb=100) + await s.commit() + assert task.status == "paused_external" + assert task.error_message == "pinned_source_unavailable" diff --git a/tests/services/test_source_speed.py b/tests/services/test_source_speed.py new file mode 100644 index 0000000..3f56570 --- /dev/null +++ b/tests/services/test_source_speed.py @@ -0,0 +1,51 @@ +"""Speed EWMA fusion + controller-side probe (Phase 3 SP2; doc §1.7/§1.8).""" +from __future__ import annotations + +import httpx + +from dlw.services.source_speed import ( + fuse_ewma, + pick_probe_size_bytes, + probe_source_speed, +) +from dlw.sources.base import SourceFile + + +def test_fuse_no_history_uses_live(): + assert fuse_ewma(live=1000.0, hist=None, hist_weight=0.3) == 1000.0 + + +def test_fuse_blends(): + assert fuse_ewma(live=1000.0, hist=500.0, hist_weight=0.3) == 850.0 + + +def test_probe_size(): + assert pick_probe_size_bytes(probe_size_mb=32) == 32 * 1024 * 1024 + + +class _Drv: + def download_url(self, f): + return "https://src/x" + + def auth_token(self, t): + from dlw.sources.base import SourceToken + return SourceToken(scheme="none") + + +async def test_probe_returns_positive_speed(): + transport = httpx.MockTransport( + lambda r: httpx.Response(206, content=b"x" * 4096)) + bps = await probe_source_speed( + _Drv(), SourceFile("m", 4096, None, "ref"), + probe_bytes=4096, timeout_s=5.0, hf_token=None, transport=transport) + assert bps > 0.0 + + +async def test_probe_failure_returns_zero(): + def boom(r): + raise httpx.ConnectError("down") + bps = await probe_source_speed( + _Drv(), SourceFile("m", 4096, None, "ref"), + probe_bytes=4096, timeout_s=5.0, hf_token=None, + transport=httpx.MockTransport(boom)) + assert bps == 0.0 diff --git a/tests/sources/__init__.py b/tests/sources/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/sources/test_base.py b/tests/sources/test_base.py new file mode 100644 index 0000000..620d51c --- /dev/null +++ b/tests/sources/test_base.py @@ -0,0 +1,29 @@ +"""SourceDriver Protocol + dataclasses (Phase 3 SP2).""" +from __future__ import annotations + +from dlw.sources.base import ( + SourceFile, + SourceHealth, + SourceManifest, + SourceToken, +) + + +def test_sourcefile_defaults(): + f = SourceFile(filename="model.safetensors", size=10, sha256=None, + download_ref="r") + assert f.filename == "model.safetensors" and f.sha256 is None + + +def test_manifest_holds_files(): + m = SourceManifest(source_id="huggingface", repo_id_in_source="o/r", + revision_in_source="abc", files=[ + SourceFile("a", 1, "x" * 64, "ref")], + has_lfs_sha256=True) + assert m.source_id == "huggingface" and len(m.files) == 1 + + +def test_health_and_token(): + assert SourceHealth(ok=True, latency_ms=12.0).ok is True + t = SourceToken(scheme="bearer", value="secret") + assert t.value == "secret" and "secret" not in repr(t) diff --git a/tests/sources/test_hf_drivers.py b/tests/sources/test_hf_drivers.py new file mode 100644 index 0000000..829ba47 --- /dev/null +++ b/tests/sources/test_hf_drivers.py @@ -0,0 +1,48 @@ +"""HF + hf_mirror drivers (Phase 3 SP2).""" +from __future__ import annotations + +import pytest + +from dlw.services.hf_metadata import RepoFile +from dlw.sources.hf_mirror import HfMirrorDriver +from dlw.sources.huggingface import HuggingFaceDriver + + +@pytest.fixture +def _patch_list(monkeypatch): + async def fake(repo_id, revision, *, hf_endpoint, hf_token): + assert revision == "abc" + return [RepoFile(path="model.safetensors", size=64, sha256="a" * 64), + RepoFile(path="config.json", size=4, sha256=None)] + monkeypatch.setattr("dlw.sources.huggingface.list_repo_tree", fake) + monkeypatch.setattr("dlw.sources.hf_mirror.list_repo_tree", fake) + + +async def test_hf_resolve(_patch_list): + d = HuggingFaceDriver(base_url="https://huggingface.co", hf_token="tok") + m = await d.resolve("o/r", "abc") + assert m is not None + assert m.source_id == "huggingface" and m.has_lfs_sha256 is True + assert {f.filename for f in m.files} == {"model.safetensors", "config.json"} + assert d.provides_sha256 is True + assert d.download_url(m.files[0]).endswith( + "/o/r/resolve/abc/model.safetensors") + assert d.auth_token("tok").value == "tok" + + +async def test_hf_mirror_no_token_and_base(_patch_list): + d = HfMirrorDriver(base_url="https://hf-mirror.com") + m = await d.resolve("o/r", "abc") + assert m.source_id == "hf_mirror" + assert d.download_url(m.files[0]).startswith("https://hf-mirror.com/") + assert d.auth_token("tok").scheme == "none" + + +async def test_hf_mirror_gated_returns_none(monkeypatch): + from dlw.services.hf_metadata import HfPrivateOrAuthRequired + + async def gated(*a, **k): + raise HfPrivateOrAuthRequired("gated") + monkeypatch.setattr("dlw.sources.hf_mirror.list_repo_tree", gated) + d = HfMirrorDriver(base_url="https://hf-mirror.com") + assert await d.resolve("o/gated", "abc") is None diff --git a/tests/sources/test_modelscope_driver.py b/tests/sources/test_modelscope_driver.py new file mode 100644 index 0000000..60f2eaf --- /dev/null +++ b/tests/sources/test_modelscope_driver.py @@ -0,0 +1,45 @@ +"""ModelScope driver (Phase 3 SP2).""" +from __future__ import annotations + +import httpx +import pytest + +from dlw.sources.modelscope import ModelScopeDriver + + +def _handler(request: httpx.Request) -> httpx.Response: + assert "modelscope.cn" in str(request.url) + if "/repo?Revision=" in str(request.url) and "FilePath" not in str(request.url): + return httpx.Response(200, json={"Data": {"Files": [ + {"Path": "model.safetensors", "Size": 64}, + {"Path": "config.json", "Size": 4}]}}) + return httpx.Response(404) + + +@pytest.fixture +def _drv(): + return ModelScopeDriver( + base_url="https://www.modelscope.cn", + transport=httpx.MockTransport(_handler)) + + +async def test_modelscope_resolve_no_sha(_drv): + m = await _drv.resolve("qwen/Qwen3-7B", "v1") + assert m is not None + assert m.source_id == "modelscope" and m.has_lfs_sha256 is False + assert all(f.sha256 is None for f in m.files) + assert {f.filename for f in m.files} == {"model.safetensors", "config.json"} + assert _drv.provides_sha256 is False + + +async def test_modelscope_download_url(_drv): + m = await _drv.resolve("qwen/Qwen3-7B", "v1") + url = _drv.download_url(m.files[0]) + assert "FilePath=model.safetensors" in url and "Revision=v1" in url + + +async def test_modelscope_missing_repo_returns_none(): + d = ModelScopeDriver( + base_url="https://www.modelscope.cn", + transport=httpx.MockTransport(lambda r: httpx.Response(404))) + assert await d.resolve("no/such", "v1") is None diff --git a/tests/sources/test_name_resolver.py b/tests/sources/test_name_resolver.py new file mode 100644 index 0000000..f1a7a44 --- /dev/null +++ b/tests/sources/test_name_resolver.py @@ -0,0 +1,46 @@ +"""NameResolver 3-tier (Phase 3 SP2; doc §1.5).""" +from __future__ import annotations + +from dlw.sources.name_resolver import NameResolver + +_RULES = """ +identity_organizations: [deepseek-ai, Qwen, THUDM] +aliases: + - hf_org: meta-llama + modelscope_org: LLM-Research + transform: "Meta-{name}" +per_model_overrides: + - hf: "weird-org/weird-model" + modelscope: "diff-org/diff-name" +""" + + +def _r(tmp_path): + p = tmp_path / "rr.yaml" + p.write_text(_RULES, encoding="utf-8") + return NameResolver.from_file(str(p)) + + +def test_huggingface_is_always_identity(tmp_path): + r = _r(tmp_path) + assert r.resolve("huggingface", "any-org/any-model") == "any-org/any-model" + + +def test_identity_org(tmp_path): + r = _r(tmp_path) + assert r.resolve("modelscope", "deepseek-ai/DeepSeek-V3") == "deepseek-ai/DeepSeek-V3" + + +def test_alias_transform(tmp_path): + r = _r(tmp_path) + assert r.resolve("modelscope", "meta-llama/Llama-3.1-8B") == "LLM-Research/Meta-Llama-3.1-8B" + + +def test_per_model_override(tmp_path): + r = _r(tmp_path) + assert r.resolve("modelscope", "weird-org/weird-model") == "diff-org/diff-name" + + +def test_unknown_returns_none(tmp_path): + r = _r(tmp_path) + assert r.resolve("modelscope", "rando-org/rando-model") is None diff --git a/tests/sources/test_registry.py b/tests/sources/test_registry.py new file mode 100644 index 0000000..fd360d0 --- /dev/null +++ b/tests/sources/test_registry.py @@ -0,0 +1,45 @@ +"""Source registry from sources.yaml (Phase 3 SP2).""" +from __future__ import annotations + +from dlw.sources.registry import load_registry + +_YAML = """ +sources: + - id: huggingface + enabled: true + driver: huggingface + config: {base_url: https://huggingface.co} + - id: hf_mirror + enabled: true + driver: hf_mirror + config: {base_url: https://hf-mirror.com} + - id: modelscope + enabled: false + driver: modelscope + config: {base_url: https://www.modelscope.cn} + - id: corp + enabled: true + driver: s3_mirror + config: {} +regional_defaults: + cn-north: [hf_mirror, modelscope, huggingface] +""" + + +def test_only_enabled_supported(tmp_path): + p = tmp_path / "s.yaml" + p.write_text(_YAML, encoding="utf-8") + reg = load_registry(str(p), hf_token="tk") + assert set(reg.enabled_ids()) == {"huggingface", "hf_mirror"} + assert reg.get("huggingface").id == "huggingface" + assert reg.get("missing") is None + assert reg.regional_defaults["cn-north"][0] == "hf_mirror" + + +def test_modelscope_enabled(tmp_path): + p = tmp_path / "s.yaml" + p.write_text(_YAML.replace("id: modelscope\n enabled: false", + "id: modelscope\n enabled: true"), + encoding="utf-8") + reg = load_registry(str(p), hf_token=None) + assert "modelscope" in reg.enabled_ids() diff --git a/tests/test_config.py b/tests/test_config.py index 366b1e8..f323930 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -98,3 +98,18 @@ def test_sp1_auth_settings_env_override(monkeypatch): assert s.system_jwt_secret == "s3cr3t" assert s.system_admin_token == "svc-tok" get_settings.cache_clear() + + +def test_sp2_source_settings_defaults(): + from dlw.config import get_settings + get_settings.cache_clear() + s = get_settings() + assert s.sources_yaml_path == "config/sources.yaml" + assert s.resolver_rules_path == "config/resolver-rules.yaml" + assert s.probe_size_mb == 32 + assert s.probe_timeout_s == 8.0 + assert s.chunk_level_min_file_mb == 100 + assert s.speed_ewma_alpha == 0.3 + assert s.sha_mismatch_blacklist_hours == 24 + assert s.rebalance_interval_seconds == 60.0 + get_settings.cache_clear() diff --git a/tests/test_lifespan_state.py b/tests/test_lifespan_state.py index cb1c34e..6265b59 100644 --- a/tests/test_lifespan_state.py +++ b/tests/test_lifespan_state.py @@ -35,6 +35,7 @@ async def test_real_lifespan_sets_casbin_and_settings( # the enforcer must actually enforce (deny a viewer POST) assert app.state.casbin.enforce( "role:tenant_viewer", 1, "/api/v1/tasks", "POST", 1) is False + assert app.state.source_registry is not None get_settings.cache_clear() async with engine.begin() as conn: diff --git a/tests/test_sp2_lifespan.py b/tests/test_sp2_lifespan.py new file mode 100644 index 0000000..d4630fd --- /dev/null +++ b/tests/test_sp2_lifespan.py @@ -0,0 +1,28 @@ +"""Real lifespan bootstraps source_registry + name_resolver (SP2).""" +from __future__ import annotations + +import pytest + +import dlw.db.models # noqa: F401 +from dlw.db.base import Base + +pytestmark = pytest.mark.slow + + +async def test_lifespan_sets_source_state(engine, tmp_path, monkeypatch): + async with engine.begin() as c: + await c.run_sync(Base.metadata.create_all) + monkeypatch.setenv("DLW_AUTH_DEV_MODE", "true") + monkeypatch.setenv("DLW_CA_DIR", str(tmp_path / "ca")) + from dlw.config import get_settings + get_settings.cache_clear() + from dlw.main import create_app, lifespan + from dlw.sources.registry import SourceRegistry + app = create_app() + async with lifespan(app): + assert isinstance(app.state.source_registry, SourceRegistry) + assert app.state.name_resolver is not None + assert "huggingface" in app.state.source_registry.enabled_ids() + get_settings.cache_clear() + async with engine.begin() as c: + await c.run_sync(Base.metadata.drop_all) diff --git a/uv.lock b/uv.lock index 86a8e1a..dcdcc2a 100644 --- a/uv.lock +++ b/uv.lock @@ -400,6 +400,7 @@ dependencies = [ { name = "pydantic" }, { name = "pydantic-settings" }, { name = "pyjwt", extra = ["crypto"] }, + { name = "pyyaml" }, { name = "sqlalchemy", extra = ["asyncio"] }, { name = "structlog" }, { name = "tenacity" }, @@ -431,6 +432,7 @@ requires-dist = [ { name = "pydantic", specifier = ">=2.9,<2.11" }, { name = "pydantic-settings", specifier = ">=2.6,<2.7" }, { name = "pyjwt", extras = ["crypto"], specifier = ">=2.9,<3.0" }, + { name = "pyyaml", specifier = ">=6,<7" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0,<2.1" }, { name = "structlog", specifier = ">=24.4,<24.5" }, { name = "tenacity", specifier = ">=9.0,<10.0" },